Skip to content

Commit d8fb2d2

Browse files
committed
fix the comment by ai review
Signed-off-by: sewon.jeon <sewon.jeon@connecteve.com>
1 parent a231740 commit d8fb2d2

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

monai/transforms/post/array.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,11 +843,14 @@ def __call__(self, points: NdarrayOrTensor, spatial_shape: Sequence[int] | None
843843
continue
844844
if not ((center >= 0).all() and (center < bounds_t).all()):
845845
continue
846-
# Round to nearest integer for impulse placement
846+
# Round to nearest integer for impulse placement, then clamp to valid index range
847847
center_int = center.round().long()
848+
# Clamp indices to [0, size-1] to avoid out-of-bounds (e.g., 9.7 rounds to 10 in size-10 array)
849+
bounds_max = (bounds_t - 1).long()
850+
center_int = torch.minimum(torch.maximum(center_int, torch.zeros_like(center_int)), bounds_max)
848851
# Place impulse (use maximum in case of overlapping landmarks)
849852
current_val = heatmap[idx][tuple(center_int)]
850-
heatmap[idx][tuple(center_int)] = max(current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device))
853+
heatmap[idx][tuple(center_int)] = torch.maximum(current_val, torch.tensor(1.0, dtype=self.torch_dtype, device=device))
851854

852855
# Apply Gaussian blur using GaussianFilter
853856
# Reshape to (num_points, 1, *spatial) for per-channel filtering

monai/transforms/post/dictionary.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ def _determine_shape(
680680
self, points: Any, static_shape: tuple[int, ...] | None, data: Mapping[Hashable, Any], ref_key: Hashable | None
681681
) -> tuple[int, ...]:
682682
points_t = convert_to_tensor(points, dtype=torch.float32, track_meta=False)
683-
if points_t.ndim not in (2, 3):
683+
if points_t.ndim != 2:
684684
raise ValueError(f"{self._ERR_INVALID_POINTS} Got {points_t.ndim}D tensor.")
685685
spatial_dims = int(points_t.shape[-1])
686686
if static_shape is not None:

0 commit comments

Comments
 (0)