feat: add activation checkpointing to unet #8554
base: dev
Conversation
Walkthrough: Adds monai/networks/blocks/activation_checkpointing.py introducing `ActivationCheckpointWrapper`, which applies `torch.utils.checkpoint.checkpoint(..., use_reentrant=False)` to a wrapped `nn.Module`. Adds `CheckpointUNet(UNet)` in monai/networks/nets/unet.py, which overrides `_get_connection_block` to wrap the connection subblock, `down_path`, and `up_path` with `ActivationCheckpointWrapper`, and updates `__all__` accordingly.

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
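For orientation, a minimal sketch of the pattern the walkthrough describes (condensed from the PR; the actual implementation also includes full docstrings and type casts):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from monai.networks.nets import UNet


class ActivationCheckpointWrapper(nn.Module):
    """Wrap a module so its activations are recomputed in backward instead of stored."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # non-reentrant checkpointing, as used throughout this PR
        return checkpoint(self.module, x, use_reentrant=False)


class CheckpointUNet(UNet):
    """UNet variant whose connection blocks are wrapped with activation checkpointing."""

    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        # wrap each component before delegating to the parent implementation
        return super()._get_connection_block(
            ActivationCheckpointWrapper(down_path),
            ActivationCheckpointWrapper(up_path),
            ActivationCheckpointWrapper(subblock),
        )
```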
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 1
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-33: Add a brief class docstring to the wrapper.
Improves discoverability and meets docstring guidelines. Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    def __init__(self, module: nn.Module) -> None:
+    """Apply activation checkpointing to the wrapped module during training."""
+
+    def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
```
134-135: Document the new `use_checkpointing` arg in the class docstring and user docs.
State trade-offs (memory vs compute), that it's training-only, incompatible with `torch.no_grad`, and preserves RNG by default. Proposed docstring snippet to add under "Args":

    use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.

I can open a docs patch and add a short example enabling the flag.
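For illustration, such an example might look like this (a sketch of the intermediate API at this stage of the PR; the flag was later replaced by the `CheckpointUNet` subclass):

```python
import torch

from monai.networks.nets import UNet

# `use_checkpointing` is the flag proposed at this stage of the PR (later superseded)
net = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128),
    strides=(2, 2, 2),
    use_checkpointing=True,
)
net.train()
x = torch.randn(1, 1, 64, 64, 64, requires_grad=True)
y = net(x)  # sub-block activations are recomputed in backward, reducing peak memory
y.sum().backward()
```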
163-164: Static wrapping caveat: runtime flips won’t take effect.
After init, changing `self.use_checkpointing` won't rewrap existing blocks. Either document this or add a small helper to (re)build the model if you expect runtime toggling. Do you expect users to toggle this at runtime? If yes, I can sketch a safe rewrap helper.
210-212: Checkpointing scope is subblock-only; consider an optional broader policy.
Current placement is a good default. If more memory is needed, offer a policy to also wrap `down_path`/`up_path` (with a warning about extra compute). I can add a `checkpoint_policy: Literal["subblock", "all"] = "subblock"` in `__init__` and wire it here on request.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (1)
monai/networks/nets/unet.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/nets/unet.py
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
16-21: Imports for checkpointing look good.
`cast` and `checkpoint` are appropriate for the new wrapper.
Actionable comments posted: 1
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Good guard + compatibility fallback.
Training/grad-enabled checks and `use_reentrant=False` with a `TypeError` fallback are the right call. This addresses the prior review note.
🧹 Nitpick comments (5)
monai/networks/nets/unet.py (5)
29-43: Avoid per-iteration TypeError cost: detect `use_reentrant` support once.
Resolve support at import/init time to prevent raising an exception every forward on older torch. Apply:

```diff
@@
-class _ActivationCheckpointWrapper(nn.Module):
+_SUPPORTS_USE_REENTRANT: bool | None = None
+
+class _ActivationCheckpointWrapper(nn.Module):
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
-            try:
-                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
-            except TypeError:
-                # Fallback for older PyTorch without `use_reentrant`
-                return cast(torch.Tensor, checkpoint(self.module, x))
-        return cast(torch.Tensor, self.module(x))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            global _SUPPORTS_USE_REENTRANT
+            if _SUPPORTS_USE_REENTRANT is None:
+                try:
+                    # probe once
+                    checkpoint(self.module, x, use_reentrant=False)  # type: ignore[arg-type]
+                    _SUPPORTS_USE_REENTRANT = True
+                except TypeError:
+                    _SUPPORTS_USE_REENTRANT = False
+                except Exception:
+                    # do not change behavior on unexpected errors; fall back below
+                    _SUPPORTS_USE_REENTRANT = False
+            if _SUPPORTS_USE_REENTRANT:
+                return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+            return cast(torch.Tensor, checkpoint(self.module, x))
+        return cast(torch.Tensor, self.module(x))
```

Add outside the hunk (file header):

```python
import inspect  # if you switch to signature probing instead of try/except
```

Note: PyTorch recommends passing `use_reentrant` explicitly going forward. (docs.pytorch.org)
29-43: TorchScript: make wrapper script-safe.
`try/except` and dynamic checkpoint calls won't script. Short-circuit under scripting. Apply:

```diff
@@
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if torch.jit.is_scripting():
+            # Avoid checkpoint in scripted graphs
+            return cast(torch.Tensor, self.module(x))
```
29-43: Docstring completeness.
Add Google-style docstrings for the wrapper's class/init/forward (inputs, returns, raises). Example:

```diff
@@
-class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+class _ActivationCheckpointWrapper(nn.Module):
+    """Wrap a module and apply activation checkpointing during training.
+
+    Args:
+        module: The submodule to checkpoint.
+
+    Returns:
+        torch.Tensor: Output tensor from the wrapped submodule.
+
+    Raises:
+        RuntimeError: If checkpoint fails at runtime.
+    """
```
90-92: Tighten the `use_checkpointing` docstring and add a BN caveat.
Keep it on one Args entry and note the BatchNorm limitation.

```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval and when gradients are disabled. Note: avoid with
+            BatchNorm layers due to running-stat updates during recomputation. Defaults to False.
```

Reference on RNG determinism (dropout is handled by default). (docs.pytorch.org)
217-219: Wrap site is fine; consider BN detection here instead of inside the wrapper.
Keeps policy close to where wrapping occurs and avoids per-instance warnings.

```diff
-        if self.use_checkpointing:
-            subblock = _ActivationCheckpointWrapper(subblock)
+        if self.use_checkpointing:
+            has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in subblock.modules())
+            if has_bn:
+                warnings.warn(
+                    "Skipping activation checkpointing for this subblock (contains BatchNorm).",
+                    RuntimeWarning,
+                )
+            else:
+                subblock = _ActivationCheckpointWrapper(subblock)
         return nn.Sequential(down_path, SkipConnection(subblock), up_path)
```

Rationale: avoids double-updating BN stats during recomputation. (github.com)
📒 Files selected for processing (1): monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-16: LGTM: imports are correct and scoped.
`cast` and `checkpoint` are needed by the wrapper; no issues. Also applies to: 20-20
141-141: Public API addition: ensure tests and docs cover new flag.
Add unit tests for parity (on/off), eval bypass, and `no_grad()` bypass; document in release notes/configs. Proposed minimal tests:

- Forward/backward equivalence within tolerance for `use_checkpointing={False, True}` on a tiny UNet.
- `model.eval()` and `torch.no_grad()` paths skip checkpoint (no `checkpoint` mocks invoked).
- Autocast path under CUDA runs without dtype mismatches.

Do you want a patch with pytest tests added under `tests/test_unet_checkpointing.py`? A rough sketch of the first two cases follows below.
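A rough pytest sketch of the first two cases (hypothetical file name and config; it targets the intermediate `use_checkpointing` flag and seeds both constructions identically rather than copying state dicts, since wrapping changes parameter key names):

```python
# tests/test_unet_checkpointing.py (hypothetical location)
import torch

from monai.networks.nets import UNet

CONFIG = dict(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2))


def _build(use_checkpointing: bool) -> UNet:
    torch.manual_seed(0)  # identical seeding -> identical weights for both variants
    return UNet(**CONFIG, use_checkpointing=use_checkpointing)


def test_forward_backward_parity():
    ref, net = _build(False), _build(True)
    ref.train(), net.train()
    x = torch.randn(2, 1, 32, 32, requires_grad=True)
    out_ref, out = ref(x), net(x)
    assert torch.allclose(out_ref, out, atol=1e-6)
    (out_ref.sum() + out.sum()).backward()
    assert x.grad is not None


def test_eval_and_no_grad_bypass():
    net = _build(True).eval()
    with torch.no_grad():
        out = net(torch.randn(2, 1, 32, 32))
    assert out.shape == (2, 2, 32, 32)
```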
170-170: LGTM: stores flag on the instance.
No concerns.
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/networks/nets/unet.py (4)
29-51: Warn once for BN, and add param/return docstring per guidelines.
Avoid repeated RuntimeWarnings and document the wrapper's contract.

```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Args:
+        module: submodule to wrap.
+    Returns:
+        torch.Tensor: output of the wrapped module.
+    Warnings:
+        Skips checkpointing and emits a RuntimeWarning if the submodule contains
+        BatchNorm to avoid double-updating running stats.
+    """
 
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         # Pre-detect BatchNorm presence for fast path
         self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
         self.module = module
+        self._bn_warned = False
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.training and torch.is_grad_enabled() and x.requires_grad:
             if self._has_bn:
-                warnings.warn(
-                    "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
-                    "running statistics during recomputation.",
-                    RuntimeWarning,
-                )
+                if not self._bn_warned:
+                    warnings.warn(
+                        "Activation checkpointing skipped for a subblock containing BatchNorm to avoid double-updating "
+                        "running statistics during recomputation.",
+                        RuntimeWarning,
+                    )
+                    self._bn_warned = True
                 return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
```

Minimal tests to add:
- Training vs eval parity (values match with/without checkpointing).
- BN subblock emits RuntimeWarning and bypasses checkpointing.
- Guard under `torch.no_grad()` and when `requires_grad=False`.
99-101: Clarify `use_checkpointing` behavior in docs (BN, grad, training-only, build-time). Make expectations explicit for users.

```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Checkpointing is active only when `self.training` is True, gradients
+            are enabled, and inputs require gradients; it is bypassed in eval and when grads are disabled.
+            Sub-blocks containing BatchNorm are not checkpointed to avoid double-updating running stats.
+            Note: this is a build-time option; changing it after initialization will not retroactively wrap existing
+            sub-blocks. Defaults to False.
```
179-180: Flag is build-time only; consider asserting or documenting.
Changing `self.use_checkpointing` post-init has no effect since wrappers are created during construction. The doc update above covers this; alternatively, convert to a read-only attribute to avoid confusion (see the sketch below).
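A tiny illustrative sketch of that read-only alternative (not part of the PR):

```python
import torch.nn as nn


class _Example(nn.Module):
    """Illustrative only: expose the build-time flag as a read-only property."""

    def __init__(self, use_checkpointing: bool = False) -> None:
        super().__init__()
        self._use_checkpointing = use_checkpointing  # fixed at construction time

    @property
    def use_checkpointing(self) -> bool:
        # read-only: reassigning after __init__ would not re-wrap existing sub-blocks anyway
        return self._use_checkpointing
```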
226-228: Wrapping only the subblock is OK; consider optional coverage toggle.
If desired later, expose an opt-in to also wrap `down`/`up` blocks for additional savings.
📒 Files selected for processing (1): monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-21: Imports look good. Scoped import of `checkpoint` plus `cast` is appropriate.
150-151: API change is sensible and non-breaking. Parameter added at the end; default preserves behavior.
Actionable comments posted: 0
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
29-43: Skip checkpointing for subblocks with BatchNorm to avoid double-updating running stats.
Checkpoint recompute updates BN running stats twice under training. Detect BN in the wrapped module and bypass checkpointing with a warning. Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """Apply activation checkpointing to the wrapped module during training.
+
+    Skips checkpointing for submodules containing BatchNorm to avoid double-updating
+    running statistics during recomputation.
+    """
 
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if self.training and torch.is_grad_enabled() and x.requires_grad:
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for a subblock containing BatchNorm "
+                    "to avoid double-updating running statistics during recomputation.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
             try:
                 return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
             except TypeError:
                 # Fallback for older PyTorch without `use_reentrant`
                 return cast(torch.Tensor, checkpoint(self.module, x))
         return cast(torch.Tensor, self.module(x))
```
🧹 Nitpick comments (3)
monai/networks/nets/unet.py (3)
90-92: Clarify arg docs and surface BN caveat.
Tighten wording and document BN behavior for transparency. Apply this diff:

```diff
-        use_checkpointing: if True, apply activation checkpointing to internal sub-blocks during training to reduce memory
-            at the cost of extra compute. Checkpointing is bypassed in eval and when gradients are disabled. Defaults to False.
+        use_checkpointing: If True, applies activation checkpointing to internal sub-blocks during training to reduce
+            memory at the cost of extra compute. Bypassed in eval mode and when gradients are disabled.
+            Note: sub-blocks containing BatchNorm are executed without checkpointing to avoid double-updating
+            running statistics. Defaults to False.
```
217-219: Placement of wrapper is sensible; consider optional breadth control.
Future enhancement: expose a knob to checkpoint down/up paths too for deeper memory savings on very deep nets.
141-142: Add tests to lock behavior.
- Parity: forward/backward equivalence (outputs/grad norms) with vs. without checkpointing.
- Modes: train vs. eval; torch.no_grad().
- Norms: with InstanceNorm and with BatchNorm (assert BN path skips with warning).
I can draft unit tests targeting UNet’s smallest config to keep runtime minimal—want me to open a follow-up?
📒 Files selected for processing (1): monai/networks/nets/unet.py (6 hunks)
🔇 Additional comments (3)
monai/networks/nets/unet.py (3)
16-21: LGTM: imports for cast/checkpoint are correct.
Direct import of checkpoint and use of typing.cast are appropriate.
35-42: Validate AMP behavior under fallback (reentrant) checkpointing.
Older Torch (fallback path) may not replay autocast exactly; please verify mixed-precision parity.
Minimal check: run a forward/backward with torch.autocast and compare loss/grad norms with and without checkpointing on a small UNet to ensure deltas are within numerical noise.
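One way that check could look, as a sketch only (assumes a CUDA device and the `use_checkpointing` flag from this stage of the PR; seeding keeps weights and inputs identical across the two runs):

```python
import torch

from monai.networks.nets import UNet

CONFIG = dict(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16, 32), strides=(2, 2))


def loss_and_grad_norm(use_checkpointing: bool) -> tuple[float, float]:
    torch.manual_seed(0)  # same seed -> same weights and same input for both runs
    net = UNet(**CONFIG, use_checkpointing=use_checkpointing).cuda().train()
    x = torch.randn(2, 1, 32, 32, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = net(x).float().sum()
    loss.backward()
    grad_norm = torch.norm(torch.stack([p.grad.norm() for p in net.parameters() if p.grad is not None]))
    return loss.item(), grad_norm.item()


baseline = loss_and_grad_norm(False)
checkpointed = loss_and_grad_norm(True)
# deltas should be within numerical noise for fp16 autocast
print(f"loss delta: {abs(baseline[0] - checkpointed[0]):.3e}, grad-norm delta: {abs(baseline[1] - checkpointed[1]):.3e}")
```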
141-142: API addition looks good.
Name and default match MONAI conventions.
Hi @ferreirafabio80 thanks for the contribution, but I would suggest this isn't necessarily the way to go with adapting this class. Perhaps instead you can create a subclass of `UNet`:

```python
class CheckpointUNet(UNet):
    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = _ActivationCheckpointWrapper(subblock)
        return super()._get_connection_block(down_path, up_path, subblock)
```

This would suffice for your own use if you just wanted such a definition. I think the … I see also that …
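For reference, using such a subclass would look roughly like this (assuming the `CheckpointUNet` and wrapper definitions from this PR):

```python
import torch

net = CheckpointUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64),
    strides=(2, 2),
    num_res_units=2,
)
net.train()
x = torch.randn(2, 1, 64, 64, requires_grad=True)
y = net(x)  # wrapped sub-blocks are recomputed during backward instead of storing activations
y.sum().backward()
```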
Hi @ericspod, thank you for your comments. Yes, that also works. I've defined a subclass and overridden the method as you suggested. Regarding the …, I was probably being extremely careful with the checks in …
Actionable comments posted: 5
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
35-36: Missing training and gradient guards cause eval overhead and no_grad crashes.
The forward unconditionally calls checkpoint. This will:
- Apply checkpointing during inference (eval mode) → unnecessary compute overhead.
- Fail under `torch.no_grad()` → runtime error.

Apply this diff:

```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
29-37: BatchNorm in checkpointed subblocks will double-update running stats.
Checkpoint recomputes the forward pass during backward, causing BatchNorm layers to update `running_mean`/`running_var` twice per training step, skewing statistics. Consider detecting BatchNorm in `__init__` and either warning or skipping checkpoint:

```diff
 class _ActivationCheckpointWrapper(nn.Module):
     """Apply activation checkpointing to the wrapped module during training."""
 
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module
+        self._has_bn = any(isinstance(m, nn.modules.batchnorm._BatchNorm) for m in module.modules())
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            if self._has_bn:
+                warnings.warn(
+                    "Activation checkpointing skipped for subblock with BatchNorm to avoid double-update of running stats.",
+                    RuntimeWarning,
+                )
+                return cast(torch.Tensor, self.module(x))
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
📒 Files selected for processing (1): monai/networks/nets/unet.py (4 hunks)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
16-16: LGTM on imports. Both `cast` and `checkpoint` are used in the new wrapper and are correctly imported. Also applies to: 20-20
316-316: Clarify checkpointing scope: only `subblock` vs. entire connection block.
Only `subblock` (the recursive nested structure) is wrapped, while `down_path` and `up_path` (encoder/decoder convolutions at each level) are not checkpointed. Is this intentional? Typical UNet checkpointing strategies checkpoint entire encoder/decoder blocks for maximum memory savings. Consider whether `down_path` and `up_path` should also be wrapped, or document the rationale for checkpointing only the recursive substructure.

```python
# Alternative: checkpoint all three components
def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
    down_path = _ActivationCheckpointWrapper(down_path)
    up_path = _ActivationCheckpointWrapper(up_path)
    subblock = _ActivationCheckpointWrapper(subblock)
    return super()._get_connection_block(down_path, up_path, subblock)
```
monai/networks/nets/unet.py (Outdated)

```python
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🛠️ Refactor suggestion | 🟠 Major
Add comprehensive Google-style docstring.
Per coding guidelines, document all parameters, return values, and behavior.
As per coding guidelines.
Apply this diff:
```diff
 class _ActivationCheckpointWrapper(nn.Module):
-    """Apply activation checkpointing to the wrapped module during training."""
+    """
+    Wrapper applying activation checkpointing to a module during training.
+
+    During the forward pass in training mode, intermediate activations are not stored;
+    they are recomputed during the backward pass to reduce memory usage.
+
+    Args:
+        module: The module to wrap with activation checkpointing.
+    """
     def __init__(self, module: nn.Module) -> None:
         super().__init__()
         self.module = module

     def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with optional activation checkpointing.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Output tensor from the wrapped module.
+        """
         return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
class _ActivationCheckpointWrapper(nn.Module):
    """
    Wrapper applying activation checkpointing to a module during training.

    During the forward pass in training mode, intermediate activations are not stored;
    they are recomputed during the backward pass to reduce memory usage.

    Args:
        module: The module to wrap with activation checkpointing.
    """

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with optional activation checkpointing.

        Args:
            x: Input tensor.

        Returns:
            Output tensor from the wrapped module.
        """
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack a comprehensive
Google-style docstring; add a docstring for the class describing its purpose
(applies activation checkpointing to a wrapped nn.Module during training),
document the constructor parameter module (type and role) and any stored
attributes, and add a Google-style docstring on forward documenting Args (x:
torch.Tensor — shape/expected dtype/context), Returns (torch.Tensor — shape and
that it is the wrapped module's output), and Behavior/Raises (explain that
checkpoint(...) is used with use_reentrant=False, when it will be active—only
during forward in training—and any exceptions propagated from the wrapped
module); keep wording concise and follow Google docstring sections (Args,
Returns, Raises, and a short Example/Notes if useful).
Actionable comments posted: 3
♻️ Duplicate comments (2)
monai/networks/nets/unet.py (2)
26-26: Export `CheckpointUNet` via `__all__`.
Public class isn't exported; importing it from `monai.networks.nets` will fail. Add it to `__all__`:

```diff
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
```
36-37: Guard checkpoint under grad-enabled training only.
Forward unconditionally checkpointing will crash under `torch.no_grad()` (common in eval) because checkpoint requires gradients. Gate on `self.training`, `torch.is_grad_enabled()`, and a grad-requiring input before calling checkpoint.

```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```
📒 Files selected for processing (1): monai/networks/nets/unet.py (3 hunks)
monai/networks/nets/unet.py (Outdated)

```python
class _ActivationCheckpointWrapper(nn.Module):
    """Apply activation checkpointing to the wrapped module during training."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
🛠️ Refactor suggestion | 🟠 Major
Add Google-style docstrings.
Class and forward docstrings need Args/Returns sections per guidelines. Document the wrapped module, checkpoint guard details, and returned tensor.
As per coding guidelines.
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 29 to 37, the
_ActivationCheckpointWrapper class and its forward method lack Google-style
docstrings; add a class-level docstring that briefly describes purpose, document
the module parameter as "module: nn.Module — module to wrap for activation
checkpointing", mention that checkpointing is applied during training to save
memory and that use_reentrant=False is used as the checkpoint guard, and add a
forward method docstring with Args: x (torch.Tensor): input tensor to the
wrapped module and Returns: torch.Tensor: output tensor from the wrapped module
(with activations checkpointed); keep wording concise and follow Google-style
"Args/Returns" formatting.
Sorry for the delay, I think we should put this into a new file in the `monai/networks/blocks` directory.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/nets/unet.py (1)
25-25: Export the new public variant.
Expose `CheckpointUNet` alongside `UNet`/`Unet` so `from monai.networks.nets import CheckpointUNet` works consistently.

```diff
-__all__ = ["UNet", "Unet"]
+__all__ = ["UNet", "Unet", "CheckpointUNet"]
```
♻️ Duplicate comments (1)
monai/networks/nets/unet.py (1)
302-307: Docstring required for `CheckpointUNet`.
Add a Google-style class docstring describing the checkpointing behavior, inherited args, and trade-offs per our guidelines.

```diff
 class CheckpointUNet(UNet):
+    """
+    UNet variant that wraps internal connection blocks with activation checkpointing.
+
+    See `UNet` for constructor arguments. During training with gradients enabled,
+    intermediate activations inside encoder–decoder connections are recomputed in
+    the backward pass to reduce peak memory usage at the cost of extra compute.
+    """
```
📒 Files selected for processing (2): monai/networks/blocks/activation_checkpointing.py (1 hunks), monai/networks/nets/unet.py (2 hunks)
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass with optional activation checkpointing.

    Args:
        x: Input tensor.

    Returns:
        Output tensor from the wrapped module.
    """
    return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
```
Gate checkpointing to active training passes.
The docstring promises training-only checkpointing, but forward always recomputes, so eval/no-grad still pays the checkpoint dispatch. Wrap the call with self.training, torch.is_grad_enabled(), and an x.requires_grad check, falling back to the plain module call otherwise, to avoid needless recompute overhead while preserving the memory trade-off during training. (docs.pytorch.org)
```diff
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass with optional activation checkpointing.

         Args:
             x: Input tensor.

         Returns:
             Output tensor from the wrapped module.
         """
-        return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        if self.training and torch.is_grad_enabled() and x.requires_grad:
+            return cast(torch.Tensor, checkpoint(self.module, x, use_reentrant=False))
+        return cast(torch.Tensor, self.module(x))
```

🤖 Prompt for AI Agents
+ return cast(torch.Tensor, self.module(x))🤖 Prompt for AI Agents
In monai/networks/blocks/activation_checkpointing.py around lines 32 to 41,
forward always calls checkpoint(self.module, x, use_reentrant=False) even during
eval/no-grad, causing unnecessary recompute; change it to only use
torch.utils.checkpoint when running training and gradients are enabled: check
self.training and torch.is_grad_enabled() and that input tensor x.requires_grad
before calling checkpoint(..., use_reentrant=False); otherwise call and return
self.module(x) directly to avoid unnecessary checkpoint overhead while
preserving training memory savings.
I think just change the docstring here instead of the code change.
@ericspod Understood. Update the docstrings to reflect unconditional checkpointing:
Class docstring:
"""Wrapper applying activation checkpointing to a module.
Checkpointing trades compute for memory by recomputing activations during
the backward pass instead of storing them. Applied during both training and eval.
Args:
module: The module to wrap with activation checkpointing.
"""Forward method docstring:
"""Forward pass with activation checkpointing.
Args:
x: Input tensor.
Returns:
Output tensor from the wrapped module.
"""Key changes: removed "during training" and "optional" language to match the unconditional behavior.
Actionable comments posted: 3
🧹 Nitpick comments (1)
monai/networks/nets/unet.py (1)
25-25: Sort `__all__` alphabetically.

```diff
-__all__ = ["UNet", "Unet", "CheckpointUNet"]
+__all__ = ["CheckpointUNet", "UNet", "Unet"]
```
📒 Files selected for processing (1): monai/networks/nets/unet.py (2 hunks)
🪛 Ruff (0.14.3)
monai/networks/nets/unet.py
25-25: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
306-306: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
🔇 Additional comments (1)
monai/networks/nets/unet.py (1)
20-20: LGTM. Import is clean and the wrapper is properly used in `CheckpointUNet`.
```python
class CheckpointUNet(UNet):
    """UNet variant that wraps internal connection blocks with activation checkpointing.

    See `UNet` for constructor arguments. During training with gradients enabled,
    intermediate activations inside encoder–decoder connections are recomputed in
    the backward pass to reduce peak memory usage at the cost of extra compute.
    """

    def _get_connection_block(self, down_path: nn.Module, up_path: nn.Module, subblock: nn.Module) -> nn.Module:
        subblock = ActivationCheckpointWrapper(subblock)
        down_path = ActivationCheckpointWrapper(down_path)
        up_path = ActivationCheckpointWrapper(up_path)
        return super()._get_connection_block(down_path, up_path, subblock)
```
🧩 Analysis chain
Verify unit test coverage for CheckpointUNet.
Past review comments requested tests that confirm: (a) forward parity with UNet in eval/no-grad mode, (b) checkpointing engages only during training.
🏁 Script executed:

```shell
#!/bin/bash
# Search for test files covering CheckpointUNet
rg -n --type=py -C3 'CheckpointUNet|CheckpointUnet|checkpointunet' -g '*test*.py'
```

Length of output: 0

---

🏁 Script executed:

```shell
#!/bin/bash
# Broader search for test files in common test directories
find . -type f -name "*test*.py" -path "*/test*" | head -20
```

Length of output: 1066

🏁 Script executed:

```shell
#!/bin/bash
# Search for CheckpointUNet anywhere in the codebase (not just test files)
rg -n 'CheckpointUNet' --type=py
```

Length of output: 135

🏁 Script executed:

```shell
#!/bin/bash
# Check for UNet tests
rg -n 'class.*UNet.*Test|def test.*unet' --type=py -i
```

Length of output: 3075

---

🏁 Script executed:

```shell
#!/bin/bash
# Inspect test_unet.py for any CheckpointUNet references
rg -n 'CheckpointUNet|checkpoint' tests/networks/nets/test_unet.py -i
```

Length of output: 0
Add unit tests for CheckpointUNet to validate wrapper behavior.
CheckpointUNet has zero test coverage. Tests must confirm: (a) forward pass produces identical results to parent UNet in eval mode or with torch.no_grad(), and (b) activation checkpointing engages only during training with gradients enabled.
🧰 Tools
🪛 Ruff (0.14.3)
306-306: Docstring contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF002)
🤖 Prompt for AI Agents
In monai/networks/nets/unet.py around lines 302 to 314, add unit tests (e.g.,
monai/networks/nets/test_unet_checkpoint.py) that create a UNet and a
CheckpointUNet with identical seeded initialization and the same input tensor;
assert their forward outputs are equal in eval mode and when wrapped with
torch.no_grad(); then verify activation checkpointing is active only during
training with gradients by monkeypatching or wrapping
ActivationCheckpointWrapper.forward to count invocations: run a training
forward+backward (output.sum().backward()) with requires_grad enabled and assert
the wrapper.forward is invoked more than once (indicating recomputation), and
run the same in eval or torch.no_grad() and assert it is invoked exactly once.
Ensure deterministic seeding, zero gradients between runs, and use
torch.allclose with a tight tolerance for output comparisons.
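A standalone sketch of such a test (import paths reflect the PR branch and are assumptions; it hooks the wrapped sub-module rather than the wrapper itself, because non-reentrant checkpointing re-invokes the wrapped module during backward):

```python
import torch

from monai.networks.blocks.activation_checkpointing import ActivationCheckpointWrapper
from monai.networks.nets import UNet
from monai.networks.nets.unet import CheckpointUNet

CONFIG = dict(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))


def test_eval_parity_and_training_recompute():
    # identical seeding gives identical weights; copying state dicts would fail because
    # the wrapper changes parameter key names
    torch.manual_seed(0)
    ref = UNet(**CONFIG)
    torch.manual_seed(0)
    net = CheckpointUNet(**CONFIG)

    x = torch.randn(2, 1, 32, 32)
    ref.eval()
    net.eval()
    with torch.no_grad():
        assert torch.allclose(ref(x), net(x), atol=1e-6)

    # count forward calls on one wrapped block: checkpointing recomputes it during backward
    calls = {"n": 0}
    wrapped = next(m for m in net.modules() if isinstance(m, ActivationCheckpointWrapper)).module
    wrapped.register_forward_hook(lambda *args: calls.__setitem__("n", calls["n"] + 1))

    net.train()
    xt = torch.randn(2, 1, 32, 32, requires_grad=True)
    net(xt).sum().backward()
    assert calls["n"] > 1  # initial forward plus at least one recomputation
```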
@ericspod I've moved the wrapper to a different script, added docstrings and a test (which is literally a copy of the unet one). Let me know if this is sensible.
Actionable comments posted: 1
🧹 Nitpick comments (1)
monai/networks/nets/unet.py (1)
25-25: Consider sorting `__all__` alphabetically. Static analysis suggests sorting: `["CheckpointUNet", "UNet", "Unet"]`.

```diff
-__all__ = ["UNet", "Unet", "CheckpointUNet"]
+__all__ = ["CheckpointUNet", "UNet", "Unet"]
```
📒 Files selected for processing (2): monai/networks/nets/unet.py (2 hunks), tests/networks/nets/test_checkpointunet.py (1 hunks)
🪛 Ruff (0.14.3)
monai/networks/nets/unet.py
25-25: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🔇 Additional comments (2)
monai/networks/nets/unet.py (2)
20-20: LGTM! Import is clean and appropriately placed.
302-324: Implementation and docstrings are solid. The subclass correctly wraps all connection-block components before delegating to the parent. Docstrings follow Google style per coding guidelines.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tests/networks/nets/test_checkpointunet.py (1)
167-204: Add tests verifying checkpointing activates during training.
All tests run in eval mode (via `eval_mode` or `test_script_save`), so checkpointing is never engaged. Must verify: (a) forward pass matches UNet in eval mode, and (b) checkpointing works during training with gradients.

Add two tests as suggested in the past review:
```python
def test_checkpoint_parity_eval(self):
    """Verify CheckpointUNet matches UNet output in eval mode."""
    torch.manual_seed(0)
    from monai.networks.nets.unet import UNet

    config = {
        "spatial_dims": 2,
        "in_channels": 1,
        "out_channels": 3,
        "channels": (16, 32, 64),
        "strides": (2, 2),
        "num_res_units": 1,
    }
    unet = UNet(**config).to(device)
    checkpoint_unet = CheckpointUNet(**config).to(device)
    checkpoint_unet.load_state_dict(unet.state_dict())
    test_input = torch.randn(2, 1, 32, 32).to(device)
    with eval_mode(unet), eval_mode(checkpoint_unet):
        out_unet = unet(test_input)
        out_checkpoint = checkpoint_unet(test_input)
    self.assertTrue(torch.allclose(out_unet, out_checkpoint, atol=1e-6))


def test_checkpoint_engages_training(self):
    """Verify checkpointing activates during training."""
    net = CheckpointUNet(
        spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=1
    ).to(device)
    net.train()
    test_input = torch.randn(2, 1, 32, 32, requires_grad=True, device=device)
    output = net(test_input)
    loss = output.sum()
    loss.backward()
    self.assertIsNotNone(test_input.grad)
```
📒 Files selected for processing (1): tests/networks/nets/test_checkpointunet.py (1 hunks)
```python
def test_shape(self, input_param, input_shape, expected_shape):
    net = CheckpointUNet(**input_param).to(device)
    with eval_mode(net):
        result = net.forward(torch.randn(input_shape).to(device))
        self.assertEqual(result.shape, expected_shape)

def test_script(self):
    net = CheckpointUNet(
        spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
    )
    test_data = torch.randn(16, 1, 32, 32)
    test_script_save(net, test_data)

def test_script_without_running_stats(self):
    net = CheckpointUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=3,
        channels=(16, 32, 64),
        strides=(2, 2),
        num_res_units=0,
        norm=("batch", {"track_running_stats": False}),
    )
    test_data = torch.randn(16, 1, 16, 4)
    test_script_save(net, test_data)

def test_ill_input_shape(self):
    net = CheckpointUNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
    with eval_mode(net):
        with self.assertRaisesRegex(RuntimeError, "Sizes of tensors must match"):
            net.forward(torch.randn(2, 1, 16, 5))

@parameterized.expand(ILL_CASES)
def test_ill_input_hyper_params(self, input_param):
    with self.assertRaises(ValueError):
        _ = CheckpointUNet(**input_param)
```
Add Google-style docstrings to all test methods.
Every test method lacks a docstring. Per coding guidelines, docstrings must describe the test's purpose, parameters, and expected behavior.
Example:
```diff
 @parameterized.expand(CASES)
 def test_shape(self, input_param, input_shape, expected_shape):
+    """Verify CheckpointUNet produces expected output shapes across configurations.
+
+    Args:
+        input_param: Dictionary of UNet constructor arguments.
+        input_shape: Tuple specifying input tensor dimensions.
+        expected_shape: Tuple specifying expected output tensor dimensions.
+    """
     net = CheckpointUNet(**input_param).to(device)
```

Apply similar docstrings to `test_script`, `test_script_without_running_stats`, `test_ill_input_shape`, and `test_ill_input_hyper_params`.
🤖 Prompt for AI Agents
In tests/networks/nets/test_checkpointunet.py around lines 169 to 204, the test
methods test_script, test_script_without_running_stats, test_ill_input_shape,
and test_ill_input_hyper_params lack Google-style docstrings; add a short
triple-quoted Google-style docstring to each method that states the test
purpose, any important parameters or setup (e.g., network config or test data
shapes), and the expected behavior or assertion (e.g., successful scripting,
scripting with no running stats, raising RuntimeError for mismatched input
shape, raising ValueError for invalid hyper-params). Keep docstrings concise
(one to three lines) and follow the Google style: a short description, optional
Args: line for notable inputs, and a Returns: or Raises: line only if relevant.
Actionable comments posted: 1
♻️ Duplicate comments (3)
tests/networks/nets/test_checkpointunet.py (3)
81-86: Add Google-style docstring.

Per coding guidelines, test methods require docstrings describing purpose, parameters, and expected behavior. Apply this diff:

```diff
 @parameterized.expand(CASES)
 def test_shape(self, input_param, input_shape, expected_shape):
+    """Verify CheckpointUNet produces expected output shapes.
+
+    Args:
+        input_param: UNet constructor arguments.
+        input_shape: Input tensor dimensions.
+        expected_shape: Expected output tensor dimensions.
+    """
     net = CheckpointUNet(**input_param).to(device)
```
88-93: Add Google-style docstring.

Per coding guidelines, test methods require docstrings. Apply this diff:

```diff
 def test_script(self):
+    """Verify CheckpointUNet is scriptable via TorchScript."""
     net = CheckpointUNet(
```
95-99: Add Google-style docstring.

Per coding guidelines, test methods require docstrings. Apply this diff:

```diff
 def test_ill_input_shape(self):
+    """Verify RuntimeError raised for mismatched input shape."""
     net = CheckpointUNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2))
```
🧹 Nitpick comments (1)
tests/networks/nets/test_checkpointunet.py (1)
119-139: Test logic is sound; minor redundancy in assertion.

The test verifies gradient propagation during training, indirectly confirming checkpointing works. Line 139's `assertIsNotNone(grad_norm)` is redundant since `grad_norm` is always a tensor after line 134. Optionally remove the redundant assertion:

```diff
     # gradient flow check
     grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
     self.assertGreater(grad_norm.item(), 0.0)
     # checkpointing should reduce activation memory use; we can't directly assert memory savings
     # but we can confirm no runtime errors and gradients propagate correctly
-    self.assertIsNotNone(grad_norm)
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
tests/networks/nets/test_checkpointunet.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/networks/nets/test_checkpointunet.py
🔇 Additional comments (2)
tests/networks/nets/test_checkpointunet.py (2)
1-23: Imports and setup look correct. All necessary components are imported and device selection follows standard patterns.

25-77: Test case definitions provide good coverage. Cases cover 2D/3D variants, different channel counts, and residual unit configurations.
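For reference, one such parameterized entry might look like the sketch below; the values are illustrative, not copied from the test file:

```python
# illustrative (input_param, input_shape, expected_shape) triple for parameterized.expand;
# a 2D CheckpointUNet with one residual unit should preserve spatial size and emit 3 channels
EXAMPLE_CASE_2D = [
    {"spatial_dims": 2, "in_channels": 1, "out_channels": 3, "channels": (16, 32, 64), "strides": (2, 2), "num_res_units": 1},
    (2, 1, 32, 32),  # input: batch of 2 single-channel 32x32 images
    (2, 3, 32, 32),  # expected output: same spatial size, 3 output channels
]
```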
```python
# checkpointing should reduce activation memory use; we can't directly assert memory savings
# but we can confirm no runtime errors and gradients propagate correctly
self.assertIsNotNone(grad_norm)
```
This check will never be reached when `grad_norm` is None, because the previous assertion would already raise an exception when evaluating `grad_norm.item()`.
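A minimal standalone sketch of that point, using a throwaway `nn.Linear` in place of the real network (names here are illustrative only):

```python
import torch
from torch import nn

net = nn.Linear(4, 2)
net(torch.randn(3, 4)).sum().backward()

# built the same way as in the test: a sum of tensors, so it is a tensor, never None
grad_norm = sum(p.grad.abs().sum() for p in net.parameters() if p.grad is not None)
assert grad_norm.item() > 0.0  # if grad_norm could not be evaluated, this line would fail first
assert grad_norm is not None   # so this second check can never be the one that trips
```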
This looks much better, thanks. Please do work on the testing issues; the DCO issue can be left until last.
Description
Introduces an optional `use_checkpointing` flag in the `UNet` implementation. When enabled, intermediate activations in the encoder–decoder blocks are recomputed during the backward pass instead of being stored in memory. An `_ActivationCheckpointWrapper` is applied around sub-blocks; a sketch of the idea follows.
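A minimal sketch of that wrapper, assuming it simply defers to `torch.utils.checkpoint` with non-reentrant checkpointing and falls back to a plain forward pass when gradients are not needed (an illustration of the idea, not the merged implementation):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class _ActivationCheckpointWrapper(nn.Module):
    """Recompute the wrapped sub-block's activations during backward instead of storing them."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # only pay the recomputation cost when gradients are actually being tracked
        if self.training and torch.is_grad_enabled() and x.requires_grad:
            return checkpoint(self.module, x, use_reentrant=False)
        return self.module(x)
```

Wrapping a sub-block `b` as `_ActivationCheckpointWrapper(b)` at construction time then trades extra forward compute during backward for lower peak activation memory.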
Types of changes

- Tests passed locally: `./runtests.sh -f -u --net --coverage` and `./runtests.sh --quick --unittests --disttests`.
- Documentation updated, tested with the `make html` command in the `docs/` folder.