diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py index 50a680336d..9d843c6898 100644 --- a/monai/networks/schedulers/ddim.py +++ b/monai/networks/schedulers/ddim.py @@ -117,18 +117,14 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - if self.steps_offset >= step_ratio: - raise ValueError( - f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to " - f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed" - f" the max train timestep." - ) - - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) + if self.steps_offset < 0 or self.steps_offset >= self.num_train_timesteps: + raise ValueError(f"`steps_offset`: {self.steps_offset} must be in range [0, {self.num_train_timesteps}).") + + self.timesteps = ( + torch.linspace((self.num_train_timesteps - 1) - self.steps_offset, 0, num_inference_steps, device=device) + .round() + .long() + ) self.timesteps += self.steps_offset def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor: diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py index e2b7ab55f5..73480346b0 100644 --- a/monai/networks/schedulers/ddpm.py +++ b/monai/networks/schedulers/ddpm.py @@ -31,7 +31,6 @@ from __future__ import annotations -import numpy as np import torch from monai.utils import StrEnum @@ -122,11 +121,9 @@ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | N ) self.num_inference_steps = num_inference_steps - step_ratio = self.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64) - self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps = ( + torch.linspace(self.num_train_timesteps - 1, 0, self.num_inference_steps, device=device).round().long() + ) def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor: """