diff --git a/sdkit/__init__.py b/sdkit/__init__.py index 26e6c96..e81fbe5 100644 --- a/sdkit/__init__.py +++ b/sdkit/__init__.py @@ -1,9 +1,15 @@ from threading import local - +from torch.cuda import is_available as cuda_available +from logging import getLogger class Context(local): def __init__(self) -> None: - self._device: str = "cuda:0" + + self._device: str = 'cuda:0' + if not cuda_available(): + getLogger('sdkit').warning("CUDA device not found, fallback to cpu device.") + self._device: str = 'cpu' + self._half_precision: bool = True self._vram_usage_level = None