From df244cc0f69b8cdf4a80735b4da693bd4c1e9561 Mon Sep 17 00:00:00 2001 From: Aroj Hada Date: Thu, 27 Mar 2025 16:30:20 +0100 Subject: [PATCH 1/3] torch.load fix for all pytorch versions. --- src/pytorch_tabular/tabular_datamodule.py | 7 ++++++- src/pytorch_tabular/utils/python_utils.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 7fd68dbf..ca165d16 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -758,7 +758,12 @@ def _load_dataset_from_cache(self, tag: str = "train"): ) elif self.cache_mode is self.CACHE_MODES.DISK: try: - dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) + # get the torch version + torch_version = torch.__version__ + if torch_version < "2.6": + dataset = torch.load(self.cache_dir / f"{tag}_dataset") # fix for torch version change of torch.load + elif torch_version >= "2.6": + dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) except FileNotFoundError: raise FileNotFoundError( f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader" diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index 57176fdc..98208ae4 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -74,7 +74,12 @@ def pl_load( """ if not isinstance(path_or_url, (str, Path)): # any sort of BytesIO or similar - return torch.load(path_or_url, map_location=map_location, weights_only=False) + # get the torch version + torch_version = torch.__version__ + if torch_version < "2.6": + return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6 + elif torch_version >= "2.6": + return torch.load(path_or_url, map_location=map_location, weights_only=False) if str(path_or_url).startswith("http"): return torch.hub.load_state_dict_from_url( str(path_or_url), @@ -82,7 +87,11 @@ def pl_load( ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: - return torch.load(f, map_location=map_location, weights_only=False) + if torch_version < "2.6": + return torch.load(f, map_location=map_location) # for torch version < 2.6 + elif torch_version >= "2.6": + return torch.load(f, map_location=map_location, weights_only=False) + def check_numpy(x): From 89aefc7e4f6377e993ac1fc06db9343ab48f392c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Mar 2025 15:41:25 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_tabular/tabular_datamodule.py | 6 ++++-- src/pytorch_tabular/utils/python_utils.py | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index ca165d16..7b1e89f8 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -761,9 +761,11 @@ def _load_dataset_from_cache(self, tag: str = "train"): # get the torch version torch_version = torch.__version__ if torch_version < "2.6": - dataset = torch.load(self.cache_dir / f"{tag}_dataset") # fix for torch version change of torch.load + dataset = torch.load( + self.cache_dir / f"{tag}_dataset" + ) # fix for torch version change of torch.load elif torch_version >= "2.6": - dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) + dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False) except FileNotFoundError: raise FileNotFoundError( f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader" diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index 98208ae4..795ebf22 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -77,7 +77,7 @@ def pl_load( # get the torch version torch_version = torch.__version__ if torch_version < "2.6": - return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6 + return torch.load(path_or_url, map_location=map_location) # for torch version < 2.6 elif torch_version >= "2.6": return torch.load(path_or_url, map_location=map_location, weights_only=False) if str(path_or_url).startswith("http"): @@ -88,10 +88,9 @@ def pl_load( fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: if torch_version < "2.6": - return torch.load(f, map_location=map_location) # for torch version < 2.6 + return torch.load(f, map_location=map_location) # for torch version < 2.6 elif torch_version >= "2.6": return torch.load(f, map_location=map_location, weights_only=False) - def check_numpy(x): From d359fc2b9d517ed0a4080f15614e9d39f2bf7624 Mon Sep 17 00:00:00 2001 From: Aroj Hada Date: Thu, 27 Mar 2025 23:00:16 +0100 Subject: [PATCH 3/3] Update python_utils.py to fix minor bug --- src/pytorch_tabular/utils/python_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_tabular/utils/python_utils.py b/src/pytorch_tabular/utils/python_utils.py index 795ebf22..e08503ed 100644 --- a/src/pytorch_tabular/utils/python_utils.py +++ b/src/pytorch_tabular/utils/python_utils.py @@ -87,6 +87,7 @@ def pl_load( ) fs = get_filesystem(path_or_url) with fs.open(path_or_url, "rb") as f: + torch_version = torch.__version__ if torch_version < "2.6": return torch.load(f, map_location=map_location) # for torch version < 2.6 elif torch_version >= "2.6":