From f5d667f70b8f73a0b6d9800c621c1d5ae18a842b Mon Sep 17 00:00:00 2001 From: singjc Date: Thu, 28 Nov 2024 20:52:24 -0500 Subject: [PATCH 01/14] add: Polars Accessor class and copy or clone method --- pyopenms_viz/_core.py | 105 +++++++++++++++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 21 deletions(-) diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 50782f40..c52995a5 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -25,6 +25,8 @@ from .constants import IS_SPHINX_BUILD import warnings +import polars as pl + _common_kinds = ("line", "vline", "scatter") _msdata_kinds = ("chromatogram", "mobilogram", "spectrum", "peakmap") @@ -150,7 +152,7 @@ def __init__( ) -> None: # Data attributes - self.data = data.copy() + self.data = self._copy_or_clone(data) self.kind = kind self.by = by self.plot_3d = plot_3d @@ -222,6 +224,15 @@ def __init__( self._load_extension() self._create_figure() + def _copy_or_clone(self, data): + """Return a copy or clone of the provided DataFrame based on its type.""" + if isinstance(data, ABCDataFrame): + return data.copy() + elif isinstance(data, pl.DataFrame): + return data.clone() + else: + raise TypeError("Unsupported data type. Must be either pandas DataFrame or Polars DataFrame.") + def _check_and_aggregate_duplicates(self): """ Check if duplicate data is present and aggregate if specified. @@ -238,18 +249,30 @@ def _check_and_aggregate_duplicates(self): col for col in self.known_columns if col != self.y ] - if self.data[known_columns_without_int].duplicated().any(): - if self.aggregate_duplicates: - self.data = ( - self.data[self.known_columns] - .groupby(known_columns_without_int) - .sum() - .reset_index() - ) - else: - warnings.warn( - "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." - ) + if isinstance(self.data, ABCDataFrame): + if self.data[known_columns_without_int].duplicated().any(): + if self.aggregate_duplicates: + self.data = ( + self.data[self.known_columns] + .groupby(known_columns_without_int) + .sum() + .reset_index() + ) + else: + warnings.warn( + "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." + ) + elif isinstance(self.data, pl.DataFrame): + if self.data[known_columns_without_int].is_duplicated().any(): + if self.aggregate_duplicates: + self.data = ( + self.data.groupby(known_columns_without_int) + .agg(pl.sum(pl.col(self.known_columns))) + ) + else: + warnings.warn( + "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." 
+ ) def _verify_column(self, colname: str | int, name: str) -> str: """fetch data from column name @@ -564,7 +587,7 @@ def __init__( super().__init__(data, x, y, **kwargs) if annotation_data is not None: - self.annotation_data = annotation_data.copy() + self.annotation_data = self._copy_or_clone(annotation_data) else: self.annotation_data = None self.label_suffix = self.x # set label suffix for bounding box @@ -918,9 +941,9 @@ def _prepare_data( """Prepares data for plotting based on configuration (copy, relative intensity, bin peaks).""" # copy spectrum data to not modify the original - spectrum = spectrum.copy() + spectrum = self._copy_or_clone(spectrum) reference_spectrum = ( - self.reference_spectrum.copy() if reference_spectrum is not None else None + self._copy_or_clone(self.reference_spectrum) if reference_spectrum is not None else None ) # Convert to relative intensity if required @@ -1141,7 +1164,7 @@ def __init__( self.fill_by_z = fill_by_z if annotation_data is not None: - self.annotation_data = annotation_data.copy() + self.annotation_data = self._copy_or_clone(annotation_data) else: self.annotation_data = None self.annotation_x_lb = annotation_x_lb @@ -1439,7 +1462,7 @@ def _get_call_args(backend_name: str, data: DataFrame, args, kwargs): dict The arguments to pass to the plotting backend. """ - if isinstance(data, ABCDataFrame): + if isinstance(data, ABCDataFrame) or isinstance(data, pl.DataFrame): arg_def = [ ("x", None), ("y", None), @@ -1508,7 +1531,7 @@ def _load_backend(backend: str) -> types.ModuleType: """ if backend == "bokeh": try: - module = importlib.import_module("pyopenms_viz.plotting._bokeh") + module = importlib.import_module("pyopenms_viz._bokeh") except ImportError: raise ImportError( "Bokeh is required for plotting when the 'bokeh' backend is selected." @@ -1517,7 +1540,7 @@ def _load_backend(backend: str) -> types.ModuleType: elif backend == "matplotlib": try: - module = importlib.import_module("pyopenms_viz.plotting._matplotlib") + module = importlib.import_module("pyopenms_viz._matplotlib") except ImportError: raise ImportError( "Matplotlib is required for plotting when the 'matplotlib' backend is selected." @@ -1526,7 +1549,7 @@ def _load_backend(backend: str) -> types.ModuleType: elif backend == "plotly": try: - module = importlib.import_module("pyopenms_viz.plotting._plotly") + module = importlib.import_module("pyopenms_viz._plotly") except ImportError: raise ImportError( "Plotly is required for plotting when the 'plotly' backend is selected." 
@@ -1548,3 +1571,43 @@ def _get_plot_backend(backend: str | None = None): module = _load_backend(backend_str) _backends[backend_str] = module return module + + + + +@pl.api.register_dataframe_namespace("mass") +class PolarsPyOpenMSViz: + + def __init__(self, df: pl.DataFrame) -> None: + + self._df = df + + def plot(self, x: str, y: str, kind: str = "line", **kwargs) -> Any: + + return PlotAccessor(self._df)(x, y, kind, **kwargs) + + + def by_first_letter_of_column_names(self) -> list[pl.DataFrame]: + + return [ + + self._df.select([col for col in self._df.columns if col[0] == f]) + + for f in dict.fromkeys(col[0] for col in self._df.columns) + + ] + + + def by_first_letter_of_column_values(self, col: str) -> list[pl.DataFrame]: + + return [ + + self._df.filter(pl.col(col).str.starts_with(c)) + + for c in sorted( + + set(self._df.select(pl.col(col).str.slice(0, 1)).to_series()) + + ) + + ] \ No newline at end of file From 28292cdf05f9561613c7cee9b30c04233245e3a4 Mon Sep 17 00:00:00 2001 From: singjc Date: Thu, 28 Nov 2024 23:58:21 -0500 Subject: [PATCH 02/14] add: unified DataFrame class for pandas and polars --- pyopenms_viz/_core.py | 60 ++++++----------- pyopenms_viz/_dataframe.py | 131 +++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+), 39 deletions(-) create mode 100644 pyopenms_viz/_dataframe.py diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index c52995a5..dc6e67f1 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -15,6 +15,7 @@ from numpy import ceil, log1p, log2, nan, mean, repeat, concatenate +from ._dataframe import UnifiedDataFrame from ._config import LegendConfig, FeatureConfig, _BasePlotConfig from ._misc import ( ColorGenerator, @@ -152,7 +153,7 @@ def __init__( ) -> None: # Data attributes - self.data = self._copy_or_clone(data) + self.data = UnifiedDataFrame(data) self.kind = kind self.by = by self.plot_3d = plot_3d @@ -224,15 +225,6 @@ def __init__( self._load_extension() self._create_figure() - def _copy_or_clone(self, data): - """Return a copy or clone of the provided DataFrame based on its type.""" - if isinstance(data, ABCDataFrame): - return data.copy() - elif isinstance(data, pl.DataFrame): - return data.clone() - else: - raise TypeError("Unsupported data type. Must be either pandas DataFrame or Polars DataFrame.") - def _check_and_aggregate_duplicates(self): """ Check if duplicate data is present and aggregate if specified. @@ -249,30 +241,18 @@ def _check_and_aggregate_duplicates(self): col for col in self.known_columns if col != self.y ] - if isinstance(self.data, ABCDataFrame): - if self.data[known_columns_without_int].duplicated().any(): - if self.aggregate_duplicates: - self.data = ( - self.data[self.known_columns] - .groupby(known_columns_without_int) - .sum() - .reset_index() - ) - else: - warnings.warn( - "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." - ) - elif isinstance(self.data, pl.DataFrame): - if self.data[known_columns_without_int].is_duplicated().any(): - if self.aggregate_duplicates: - self.data = ( - self.data.groupby(known_columns_without_int) - .agg(pl.sum(pl.col(self.known_columns))) - ) - else: - warnings.warn( - "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." 
- ) + if self.data[known_columns_without_int].duplicated().any(): + if self.aggregate_duplicates: + self.data = ( + self.data[self.known_columns] + .groupby(known_columns_without_int) + .sum() + .reset_index() + ) + else: + warnings.warn( + "Duplicate data detected, data will not be aggregated which may lead to unexpected plots. To enable aggregation set `aggregate_duplicates=True`." + ) def _verify_column(self, colname: str | int, name: str) -> str: """fetch data from column name @@ -587,7 +567,8 @@ def __init__( super().__init__(data, x, y, **kwargs) if annotation_data is not None: - self.annotation_data = self._copy_or_clone(annotation_data) + annotation_data = UnifiedDataFrame(annotation_data) + self.annotation_data = annotation_data.copy() else: self.annotation_data = None self.label_suffix = self.x # set label suffix for bounding box @@ -748,7 +729,7 @@ def __init__( super().__init__(data, x, y, **kwargs) - self.reference_spectrum = reference_spectrum + self.reference_spectrum = UnifiedDataFrame(reference_spectrum) if reference_spectrum is not None else None self.mirror_spectrum = mirror_spectrum self.relative_intensity = relative_intensity self.bin_peaks = bin_peaks @@ -941,9 +922,9 @@ def _prepare_data( """Prepares data for plotting based on configuration (copy, relative intensity, bin peaks).""" # copy spectrum data to not modify the original - spectrum = self._copy_or_clone(spectrum) + spectrum = spectrum.copy() reference_spectrum = ( - self._copy_or_clone(self.reference_spectrum) if reference_spectrum is not None else None + self.reference_spectrum if reference_spectrum is not None else None ) # Convert to relative intensity if required @@ -1164,7 +1145,8 @@ def __init__( self.fill_by_z = fill_by_z if annotation_data is not None: - self.annotation_data = self._copy_or_clone(annotation_data) + annotation_data = UnifiedDataFrame(annotation_data) + self.annotation_data = annotation_data.copy() else: self.annotation_data = None self.annotation_x_lb = annotation_x_lb diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py new file mode 100644 index 00000000..1df3e1fa --- /dev/null +++ b/pyopenms_viz/_dataframe.py @@ -0,0 +1,131 @@ +from pandas.core.dtypes.generic import ABCDataFrame as PandasDataFrame +from pandas.core.groupby.generic import DataFrameGroupBy as PandasGroupBy +from polars.dataframe.frame import DataFrame as PolarsDataFrame +from polars.series import Series as PolarsSeries +from polars.dataframe.group_by import GroupBy as PolarsGroupBy + +import polars as pl + + +class PandasColumnWrapper: + """Wrapper for Pandas Series to add custom methods.""" + def __init__(self, series): + self.series = series + + def __getattr__(self, name): + """Delegate attribute access to the underlying Pandas Series.""" + return getattr(self.series, name) + + +class PolarsColumnWrapper: + """Wrapper for Polars Series to add custom methods.""" + def __init__(self, series): + self.series = series + + def __getattr__(self, name): + """Delegate attribute access to the underlying Polars Series.""" + return getattr(self.series, name) + + def duplicated(self): + """Return a boolean Series indicating duplicate values.""" + return self.series.is_duplicated() + + def tolist(self): + """Return the Series as a list.""" + return self.series.to_list() + + +class UnifiedDataFrame: + """ + Wrapper class for Pandas and Polars DataFrames to provide a unified interface. 
+ """ + def __init__(self, data): + if isinstance(data, (PandasDataFrame, PolarsDataFrame)): + self.data = data + else: + raise TypeError("Unsupported data type. Must be either pandas DataFrame or Polars DataFrame.") + + def __getitem__(self, key): + """Allow access to columns using bracket notation.""" + if isinstance(self.data, PandasDataFrame): + return PandasColumnWrapper(self.data[key]) + elif isinstance(self.data, PolarsDataFrame): + return PolarsColumnWrapper(self.data[key]) + else: + raise KeyError(f"Column '{key}' not found in DataFrame.") + + def __setitem__(self, key, value): + """Allow assignment to columns using bracket notation.""" + if isinstance(self.data, PandasDataFrame): + self.data[key] = value + elif isinstance(self.data, PolarsDataFrame): + self.data = self.data.with_columns( + PolarsSeries(key, value) + ) + + def __len__(self): + """Return the number of rows in the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return len(self.data) + elif isinstance(self.data, PolarsDataFrame): + return self.data.height + + @property + def columns(self): + """Return a list of column names.""" + if isinstance(self.data, PandasDataFrame): + return self.data.columns.tolist() + elif isinstance(self.data, PolarsDataFrame): + return self.data.columns + + def copy(self): + """Return a copy of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return UnifiedDataFrame(self.data.copy()) + elif isinstance(self.data, PolarsDataFrame): + return UnifiedDataFrame(self.data.clone()) + + def sort_values(self, by, ascending=True): + """Sort the DataFrame by the specified column(s).""" + if isinstance(self.data, PandasDataFrame): + return UnifiedDataFrame(self.data.sort_values(by=by, ascending=ascending).reset_index(drop=True)) + elif isinstance(self.data, PolarsDataFrame): + return UnifiedDataFrame( + self.data.sort(by=by, descending=not ascending).with_row_count().rename({"row_nr": "index"}) + ) + + def reset_index(self, drop=False): + """Reset the index of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return UnifiedDataFrame(self.data.reset_index(drop=drop)) + elif isinstance(self.data, PolarsDataFrame): + # For Polars we can just return the same DataFrame since it doesn't have an index like Pandas. 
+ return UnifiedDataFrame(self.data) + + def iterrows(self): + """Return an iterator for rows of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return self.data.iterrows() + elif isinstance(self.data, PolarsDataFrame): + return enumerate(self.data.iter_rows(named=True)) + + def groupby(self, by): + """Group by specified columns.""" + if isinstance(self.data, PolarsDataFrame): + return UnifiedDataFrame(self.data.groupby(by)) + elif isinstance(self.data, PolarsDataFrame): + return UnifiedDataFrame(self.data.groupby(by)) + + def sum(self): + """Sum the grouped data.""" + if isinstance(self.data, PandasGroupBy): + return UnifiedDataFrame(self.data.sum().reset_index()) + elif isinstance(self.data, PolarsGroupBy): + return UnifiedDataFrame(self.data.agg(pl.sum(pl.col("*")))) + + def tolist(self, column_name): + """Return a list of values from a specified column.""" + if isinstance(self.data, PandasDataFrame): + return self.data[column_name].tolist() + elif isinstance(self.data, PolarsDataFrame): + return self.data[column_name].to_list() \ No newline at end of file From d3b861329246148856a1d0e302c69147bf589bdb Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 01:36:54 -0500 Subject: [PATCH 03/14] add: groupby unification --- pyopenms_viz/_core.py | 4 +- pyopenms_viz/_dataframe.py | 84 ++++++++++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 23 deletions(-) diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index dc6e67f1..6661d771 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -151,9 +151,9 @@ def __init__( _config: _BasePlotConfig | None = None, **kwargs, ) -> None: - + # Data attributes - self.data = UnifiedDataFrame(data) + self.data = UnifiedDataFrame(data) if not isinstance(data, UnifiedDataFrame) else data self.kind = kind self.by = by self.plot_3d = plot_3d diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index 1df3e1fa..d479bd00 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -26,6 +26,10 @@ def __getattr__(self, name): """Delegate attribute access to the underlying Polars Series.""" return getattr(self.series, name) + def astype(self, dtype): + """Cast the Series to the specified dtype.""" + return self.series.cast(dtype) + def duplicated(self): """Return a boolean Series indicating duplicate values.""" return self.series.is_duplicated() @@ -34,6 +38,29 @@ def tolist(self): """Return the Series as a list.""" return self.series.to_list() +class GroupedDataFrame: + """Class to handle grouped DataFrames for both Pandas and Polars.""" + def __init__(self, grouped_data, is_pandas=True): + self.grouped_data = grouped_data + self.is_pandas = is_pandas + + def __iter__(self): + """Allow iteration over groups.""" + if self.is_pandas: + for group_name, group_df in self.grouped_data: + yield group_name, UnifiedDataFrame(group_df) + else: + for group_name, group_df in self.grouped_data: + yield group_name, UnifiedDataFrame(group_df) + + def sum(self): + """Sum the grouped data.""" + if self.is_pandas: + summed_data = self.grouped_data.sum().reset_index() + return UnifiedDataFrame(summed_data) + else: + summed_data = self.grouped_data.agg(pl.all().sum()) + return UnifiedDataFrame(summed_data) class UnifiedDataFrame: """ @@ -43,14 +70,20 @@ def __init__(self, data): if isinstance(data, (PandasDataFrame, PolarsDataFrame)): self.data = data else: - raise TypeError("Unsupported data type. 
Must be either pandas DataFrame or Polars DataFrame.") + raise TypeError(f"Unsupported data type {type(data)}. Must be either pandas DataFrame or Polars DataFrame.") def __getitem__(self, key): """Allow access to columns using bracket notation.""" if isinstance(self.data, PandasDataFrame): - return PandasColumnWrapper(self.data[key]) + if isinstance(key, list): + return UnifiedDataFrame(self.data[key]) + else: + return PandasColumnWrapper(self.data[key]) elif isinstance(self.data, PolarsDataFrame): - return PolarsColumnWrapper(self.data[key]) + if isinstance(key, list): + return UnifiedDataFrame(self.data.select(key)) + else: + return PolarsColumnWrapper(self.data[key]) else: raise KeyError(f"Column '{key}' not found in DataFrame.") @@ -85,14 +118,21 @@ def copy(self): elif isinstance(self.data, PolarsDataFrame): return UnifiedDataFrame(self.data.clone()) - def sort_values(self, by, ascending=True): + def sort_values(self, by, ascending=True, inplace=False): """Sort the DataFrame by the specified column(s).""" - if isinstance(self.data, PandasDataFrame): - return UnifiedDataFrame(self.data.sort_values(by=by, ascending=ascending).reset_index(drop=True)) - elif isinstance(self.data, PolarsDataFrame): - return UnifiedDataFrame( - self.data.sort(by=by, descending=not ascending).with_row_count().rename({"row_nr": "index"}) - ) + if isinstance(self.data, PandasDataFrame): + if inplace: + self.data.sort_values(by=by, ascending=ascending, inplace=True) + else: + sorted_data = self.data.sort_values(by=by, ascending=ascending) + return UnifiedDataFrame(sorted_data) + + elif isinstance(self.data, PolarsDataFrame): + sorted_data = self.data.sort(by=by, descending=not ascending) + if inplace: + self.data = sorted_data + else: + return UnifiedDataFrame(sorted_data) def reset_index(self, drop=False): """Reset the index of the DataFrame.""" @@ -101,6 +141,13 @@ def reset_index(self, drop=False): elif isinstance(self.data, PolarsDataFrame): # For Polars we can just return the same DataFrame since it doesn't have an index like Pandas. 
return UnifiedDataFrame(self.data) + + def duplicated(self): + """Return a boolean Series indicating duplicate rows.""" + if isinstance(self.data, PandasDataFrame): + return self.data.duplicated() + elif isinstance(self.data, PolarsDataFrame): + return self.data.is_duplicated() def iterrows(self): """Return an iterator for rows of the DataFrame.""" @@ -109,19 +156,14 @@ def iterrows(self): elif isinstance(self.data, PolarsDataFrame): return enumerate(self.data.iter_rows(named=True)) - def groupby(self, by): + def groupby(self, by, sort=True): """Group by specified columns.""" - if isinstance(self.data, PolarsDataFrame): - return UnifiedDataFrame(self.data.groupby(by)) + if isinstance(self.data, PandasDataFrame): + grouped = self.data.groupby(by, sort=sort) + return GroupedDataFrame(grouped) elif isinstance(self.data, PolarsDataFrame): - return UnifiedDataFrame(self.data.groupby(by)) - - def sum(self): - """Sum the grouped data.""" - if isinstance(self.data, PandasGroupBy): - return UnifiedDataFrame(self.data.sum().reset_index()) - elif isinstance(self.data, PolarsGroupBy): - return UnifiedDataFrame(self.data.agg(pl.sum(pl.col("*")))) + grouped = self.data.group_by(by) + return GroupedDataFrame(grouped, is_pandas=False) def tolist(self, column_name): """Return a list of values from a specified column.""" From c5952e92738f2366424d276270b1a3e2fcea75ad Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 09:39:00 -0500 Subject: [PATCH 04/14] move: pl namespace register to unified data frame class --- pyopenms_viz/_core.py | 39 -------------------------------------- pyopenms_viz/_dataframe.py | 11 +++++++++-- 2 files changed, 9 insertions(+), 41 deletions(-) diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 6661d771..800573d3 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -1554,42 +1554,3 @@ def _get_plot_backend(backend: str | None = None): _backends[backend_str] = module return module - - - -@pl.api.register_dataframe_namespace("mass") -class PolarsPyOpenMSViz: - - def __init__(self, df: pl.DataFrame) -> None: - - self._df = df - - def plot(self, x: str, y: str, kind: str = "line", **kwargs) -> Any: - - return PlotAccessor(self._df)(x, y, kind, **kwargs) - - - def by_first_letter_of_column_names(self) -> list[pl.DataFrame]: - - return [ - - self._df.select([col for col in self._df.columns if col[0] == f]) - - for f in dict.fromkeys(col[0] for col in self._df.columns) - - ] - - - def by_first_letter_of_column_values(self, col: str) -> list[pl.DataFrame]: - - return [ - - self._df.filter(pl.col(col).str.starts_with(c)) - - for c in sorted( - - set(self._df.select(pl.col(col).str.slice(0, 1)).to_series()) - - ) - - ] \ No newline at end of file diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index d479bd00..06fc1a65 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -7,6 +7,8 @@ import polars as pl + + class PandasColumnWrapper: """Wrapper for Pandas Series to add custom methods.""" def __init__(self, series): @@ -61,7 +63,8 @@ def sum(self): else: summed_data = self.grouped_data.agg(pl.all().sum()) return UnifiedDataFrame(summed_data) - + +@pl.api.register_dataframe_namespace("mass") class UnifiedDataFrame: """ Wrapper class for Pandas and Polars DataFrames to provide a unified interface. 
@@ -164,7 +167,11 @@ def groupby(self, by, sort=True): elif isinstance(self.data, PolarsDataFrame): grouped = self.data.group_by(by) return GroupedDataFrame(grouped, is_pandas=False) - + + def plot(self, x: str, y: str, kind: str = "line", **kwargs): + from ._core import PlotAccessor + return PlotAccessor(self.data)(x, y, kind, **kwargs) + def tolist(self, column_name): """Return a list of values from a specified column.""" if isinstance(self.data, PandasDataFrame): From 384c0f51ed725ab5940c6c7896a7295a476e9d79 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 11:03:25 -0500 Subject: [PATCH 05/14] Add: import _dataframe to register namespace for polars --- pyopenms_viz/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py index 078c1245..9f606a33 100644 --- a/pyopenms_viz/__init__.py +++ b/pyopenms_viz/__init__.py @@ -10,6 +10,8 @@ import types from pathlib import Path +import pyopenms_viz._dataframe # noqa: F401 + __version__ = "0.1.5" From 05ed61da453291ca88db8c9e0538d729a9e654b1 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 11:24:30 -0500 Subject: [PATCH 06/14] fix: Groupby wrapper, return first entry for polars, as it returns a single tuple everytime --- pyopenms_viz/_dataframe.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index 06fc1a65..30e327ca 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -53,7 +53,7 @@ def __iter__(self): yield group_name, UnifiedDataFrame(group_df) else: for group_name, group_df in self.grouped_data: - yield group_name, UnifiedDataFrame(group_df) + yield group_name[0], UnifiedDataFrame(group_df) def sum(self): """Sum the grouped data.""" @@ -106,6 +106,14 @@ def __len__(self): elif isinstance(self.data, PolarsDataFrame): return self.data.height + @property + def index(self): + """Return the index of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return self.data.index + elif isinstance(self.data, PolarsDataFrame): + return list(range(self.data.height)) + @property def columns(self): """Return a list of column names.""" From c7f3f4cca8d9901fa8601e719fdb9c1c0eda1c7e Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 11:37:44 -0500 Subject: [PATCH 07/14] add: __repr__ to unified dataframe class --- pyopenms_viz/_dataframe.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index 30e327ca..f5cf7cf4 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -75,6 +75,13 @@ def __init__(self, data): else: raise TypeError(f"Unsupported data type {type(data)}. 
Must be either pandas DataFrame or Polars DataFrame.") + def __repr__(self): + """Return a string representation of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return str(self.data) + elif isinstance(self.data, PolarsDataFrame): + return self.data.__str__() + def __getitem__(self, key): """Allow access to columns using bracket notation.""" if isinstance(self.data, PandasDataFrame): From d82f6f84a91eed8af3eb16b0afd840557b276322 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 11:44:03 -0500 Subject: [PATCH 08/14] add: to_dict method for unified dataframe --- pyopenms_viz/_dataframe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index f5cf7cf4..bdccacf0 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -192,4 +192,11 @@ def tolist(self, column_name): if isinstance(self.data, PandasDataFrame): return self.data[column_name].tolist() elif isinstance(self.data, PolarsDataFrame): - return self.data[column_name].to_list() \ No newline at end of file + return self.data[column_name].to_list() + + def to_dict(self): + """Return the DataFrame as a dictionary.""" + if isinstance(self.data, PandasDataFrame): + return self.data.to_dict() + elif isinstance(self.data, PolarsDataFrame): + return self.data.to_dict() \ No newline at end of file From 4e0db07edb6e0d69dc67849d96c3a146e79d169a Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 11:44:33 -0500 Subject: [PATCH 09/14] fix: convert all df's to_dict for use with both pandas or polars df through the unified dataframe interface --- pyopenms_viz/_bokeh/core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyopenms_viz/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py index 999a9569..914b9406 100644 --- a/pyopenms_viz/_bokeh/core.py +++ b/pyopenms_viz/_bokeh/core.py @@ -258,7 +258,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs["line_width"] = 2.5 if by is None: - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) if color_gen is not None: kwargs["line_color"] = ( color_gen if isinstance(color_gen, str) else next(color_gen) @@ -270,7 +270,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): legend_items = [] for group, df in data.groupby(by, sort=False): - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) if color_gen is not None: kwargs["line_color"] = ( color_gen if isinstance(color_gen, str) else next(color_gen) @@ -303,7 +303,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): if not plot_3d: direction = kwargs.pop("direction", "vertical") if by is None: - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) if direction == "horizontal": x0_data_var = 0 x1_data_var = x @@ -326,7 +326,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): else: legend_items = [] for group, df in data.groupby(by): - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) if direction == "horizontal": x0_data_var = 0 x1_data_var = x @@ -422,7 +422,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs[k] = v if by is None: kwargs["marker"] = next(marker_gen) - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) line = fig.scatter(x=x, y=y, source=source, **kwargs) return fig, None else: @@ -431,7 +431,7 @@ def 
plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs["marker"] = next(marker_gen) if z is None: kwargs["fill_color"] = next(color_gen) - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) line = fig.scatter(x=x, y=y, source=source, **kwargs) legend_items.append((group, [line])) legend = Legend(items=legend_items) From 1f5141b0bcc234f0a9578847a25d5e4f02ee9ae0 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 12:48:36 -0500 Subject: [PATCH 10/14] add: keep arg to duplicated method for polars column wrapper --- pyopenms_viz/_dataframe.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index bdccacf0..f1e9e364 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -7,8 +7,6 @@ import polars as pl - - class PandasColumnWrapper: """Wrapper for Pandas Series to add custom methods.""" def __init__(self, series): @@ -32,9 +30,22 @@ def astype(self, dtype): """Cast the Series to the specified dtype.""" return self.series.cast(dtype) - def duplicated(self): + def duplicated(self, keep='first'): """Return a boolean Series indicating duplicate values.""" - return self.series.is_duplicated() + duplicated_mask = self.series.is_duplicated() + if keep == 'first': + first_occurrences = self.series.is_first_distinct() + return duplicated_mask & ~first_occurrences + + elif keep == 'last': + last_occurrences = self.series.is_last_distinct() + return duplicated_mask & ~last_occurrences + + elif keep is False: + return duplicated_mask.cast(pl.Boolean) + + else: + raise ValueError("keep must be 'first', 'last', or False") def tolist(self): """Return the Series as a list.""" From 60fe1037f9deb86bba18d3e54dd6801febbd6947 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 13:26:41 -0500 Subject: [PATCH 11/14] update: unified dataframe to dict method --- pyopenms_viz/_dataframe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index f1e9e364..5dd746d7 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -205,9 +205,9 @@ def tolist(self, column_name): elif isinstance(self.data, PolarsDataFrame): return self.data[column_name].to_list() - def to_dict(self): + def to_dict(self, orient='list'): """Return the DataFrame as a dictionary.""" if isinstance(self.data, PandasDataFrame): - return self.data.to_dict() + return self.data.to_dict(orient=orient) elif isinstance(self.data, PolarsDataFrame): - return self.data.to_dict() \ No newline at end of file + return self.data.to_dict(as_series=False) \ No newline at end of file From 37c7eef06db3c3bf23297e21f998bd15b772fefa Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 13:27:10 -0500 Subject: [PATCH 12/14] add: tests for _dataframe mod --- test/test_dataframe.py | 307 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 307 insertions(+) create mode 100644 test/test_dataframe.py diff --git a/test/test_dataframe.py b/test/test_dataframe.py new file mode 100644 index 00000000..7b772743 --- /dev/null +++ b/test/test_dataframe.py @@ -0,0 +1,307 @@ +""" +tes/test_dataframe +~~~~~~~~~~~~~~~~~~ +""" + +import pytest +import pandas as pd +import polars as pl +from pyopenms_viz._dataframe import * + +#################### +# PandasColumnWrapper + + +def test_pandas_column_wrapper_getattr(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a 
PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test accessing attributes of the underlying Pandas Series + assert column_wrapper.sum() == 15 + assert column_wrapper.mean() == 3.0 + assert column_wrapper.max() == 5 + assert column_wrapper.min() == 1 + +def test_pandas_column_wrapper_getattr_invalid(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test accessing invalid attributes of the underlying Pandas Series + with pytest.raises(AttributeError): + column_wrapper.invalid_attribute + +def test_pandas_column_wrapper_cast(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test casting the Series to a different dtype + casted_series = column_wrapper.astype(float) + assert casted_series.dtype == float + +def test_pandas_column_wrapper_is_duplicated(): + # Create a Pandas Series with duplicate values + series = pd.Series([1, 2, 3, 4, 5, 1, 2]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test checking for duplicate values + duplicated_series = column_wrapper.duplicated() + assert duplicated_series.tolist() == [False, False, False, False, False, True, True] + +def test_pandas_column_wrapper_to_list(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test converting the Series to a list + series_list = column_wrapper.tolist() + assert series_list == [1, 2, 3, 4, 5] + + +#################### +# PolarsColumnWrapper + + +def test_polars_column_wrapper_getattr(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test accessing attributes of the underlying Polars Series + assert column_wrapper.sum() == 15 + assert column_wrapper.mean() == 3.0 + assert column_wrapper.max() == 5 + assert column_wrapper.min() == 1 + +def test_polars_column_wrapper_cast(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test casting the Series to a different dtype + casted_series = column_wrapper.astype(pl.Float64) + assert casted_series.dtype == pl.Float64 + +def test_polars_column_wrapper_is_duplicated_first(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep='first') + expected_result = [False, False, False, True, True] # First occurrence is kept + + assert duplicated_series.to_list() == expected_result + +def test_polars_column_wrapper_is_duplicated_last(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep='last') + expected_result = [True, True, False, False, False] # Last occurrence is kept + + assert duplicated_series.to_list() == expected_result + +def test_polars_column_wrapper_is_duplicated_all(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep=False) + expected_result = [True, True, False, True, True] # All duplicates marked + + assert duplicated_series.to_list() == 
expected_result + +def test_polars_column_wrapper_to_list(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test converting the Series to a list + series_list = column_wrapper.tolist() + assert series_list == [1, 2, 3, 4, 5] + + +#################### +# GroupedDataFrame + +def test_grouped_dataframe_init_with_pandas(): + # Create a sample Pandas DataFrame and group it + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + assert gdf.is_pandas is True + assert len(list(gdf)) == 2 # Two groups: A and B + +def test_grouped_dataframe_iterate_with_pandas(): + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + groups = list(gdf) + + assert len(groups) == 2 # Should have two groups + assert groups[0][0] == 'A' # First group name should be 'A' + assert isinstance(groups[0][1], UnifiedDataFrame) # Should return UnifiedDataFrame + +def test_grouped_dataframe_sum_with_pandas(): + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + summed_df = gdf.sum() + + assert isinstance(summed_df, UnifiedDataFrame) # Should return a UnifiedDataFrame + assert summed_df.data.equals(pd.DataFrame({'key': ['A', 'B'], 'value': [4, 6]})) # Check summed values + +def test_grouped_dataframe_init_with_polars(): + # Create a sample Polars DataFrame and group it + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key') + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + assert gdf.is_pandas is False + assert len(list(gdf)) == 2 # Two groups: A and B + +def test_grouped_dataframe_iterate_with_polars(): + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key', maintain_order=True) + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + groups = list(gdf) + + assert len(groups) == 2 # Should have two groups + assert groups[0][0] == 'A' # First group name should be 'A' + assert isinstance(groups[0][1], UnifiedDataFrame) # Should return UnifiedDataFrame + +def test_grouped_dataframe_sum_with_polars(): + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key', maintain_order=True) + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + summed_df = gdf.sum() + + assert isinstance(summed_df, UnifiedDataFrame) # Should return a UnifiedDataFrame + expected_result = pl.DataFrame({'key': ['A', 'B'], 'value': [4, 6]}) + print(f"summed_df.data: {summed_df.data}") + print(f"expected_result: {expected_result}") + + assert summed_df.data.equals(expected_result) # Check summed values + + +#################### +# UnifiedDataFrame + + +def test_unified_dataframe(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert isinstance(pandas_df, UnifiedDataFrame) + assert isinstance(polars_df, UnifiedDataFrame) + + assert len(pandas_df) == 3 + assert len(polars_df) == 3 + + assert pandas_df.columns == ["A", "B"] + assert polars_df.columns == ["A", "B"] + + 
assert pandas_df["A"].tolist() == [1, 2, 3] + assert polars_df["A"].tolist() == [1, 2, 3] + + pandas_df["C"] = [7, 8, 9] + polars_df["C"] = [7, 8, 9] + + assert pandas_df.columns == ["A", "B", "C"] + assert polars_df.columns == ["A", "B", "C"] + + sorted_pandas_df = pandas_df.sort_values("A") + sorted_polars_df = polars_df.sort_values("A") + + assert sorted_pandas_df["A"].tolist() == [1, 2, 3] + assert sorted_polars_df["A"].tolist() == [1, 2, 3] + + assert sorted_pandas_df.reset_index(drop=True).index.tolist() == [0, 1, 2] + assert sorted_polars_df.reset_index(drop=True).index == [0, 1, 2] + + assert sorted_pandas_df.duplicated().tolist() == [False, False, False] + assert sorted_polars_df.duplicated().to_list() == [False, False, False] + +# Test UnifiedDataFrame.plot() +def test_unified_dataframe_plot(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + pandas_plot = pandas_df.plot(x="A", y="B", kind="line") + polars_plot = polars_df.plot(x="A", y="B", kind="line") + + assert isinstance(pandas_plot, object) # Replace with the actual type of the plot object + assert isinstance(polars_plot, object) # Replace with the actual type of the plot object + +# Test UnifiedDataFrame.tolist() +def test_unified_dataframe_tolist(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert pandas_df.tolist("A") == [1, 2, 3] + assert polars_df.tolist("A") == [1, 2, 3] + +# Test UnifiedDataFrame.to_dict() +def test_unified_dataframe_to_dict(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert pandas_df.to_dict(orient="list") == {"A": [1, 2, 3], "B": [4, 5, 6]} + assert polars_df.to_dict() == {"A": [1, 2, 3], "B": [4, 5, 6]} \ No newline at end of file From 27594a27b6e14f7678bad68668e328729731be77 Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 14:17:13 -0500 Subject: [PATCH 13/14] add: unified concat method, and add apply method to polars column wrapper --- pyopenms_viz/_core.py | 15 ++++++------ pyopenms_viz/_dataframe.py | 50 +++++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 800573d3..af5c414e 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -7,7 +7,7 @@ import types import re -from pandas import cut, merge, Interval, concat +from pandas import cut, Interval from pandas.core.frame import DataFrame from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.dtypes.common import is_integer @@ -15,7 +15,7 @@ from numpy import ceil, log1p, log2, nan, mean, repeat, concatenate -from ._dataframe import UnifiedDataFrame +from ._dataframe import UnifiedDataFrame, concat from ._config import LegendConfig, FeatureConfig, _BasePlotConfig from ._misc import ( ColorGenerator, @@ -964,6 +964,7 @@ def _get_colors( colors = [next(color_gen) for _ in range(len(uniques))] # Create a mapping of unique values to their corresponding colors color_map = {uniques[i]: colors[i] for i in range(len(colors))} + # Apply the color mapping to the specified column in the data and 
turn it into a ColorGenerator return ColorGenerator(data[self.by].apply(lambda x: color_map[x])) # Fallback ColorGenerator with one color @@ -1403,7 +1404,7 @@ def __init__(self, data: DataFrame) -> None: def __call__(self, *args: Any, **kwargs: Any) -> Any: backend_name = kwargs.get("backend", None) if backend_name is None: - backend_name = "matplotlib" + backend_name = "ms_matplotlib" plot_backend = _get_plot_backend(backend_name) @@ -1511,7 +1512,7 @@ def _load_backend(backend: str) -> types.ModuleType: types.ModuleType The imported backend. """ - if backend == "bokeh": + if backend == "ms_bokeh": try: module = importlib.import_module("pyopenms_viz._bokeh") except ImportError: @@ -1520,7 +1521,7 @@ def _load_backend(backend: str) -> types.ModuleType: ) from None return module - elif backend == "matplotlib": + elif backend == "ms_matplotlib": try: module = importlib.import_module("pyopenms_viz._matplotlib") except ImportError: @@ -1529,7 +1530,7 @@ def _load_backend(backend: str) -> types.ModuleType: ) from None return module - elif backend == "plotly": + elif backend == "ms_plotly": try: module = importlib.import_module("pyopenms_viz._plotly") except ImportError: @@ -1539,7 +1540,7 @@ def _load_backend(backend: str) -> types.ModuleType: return module raise ValueError( - f"Could not find plotting backend '{backend}'. Needs to be one of 'bokeh', 'matplotlib', or 'plotly'." + f"Could not find plotting backend '{backend}'. Needs to be one of 'ms_bokeh', 'ms_matplotlib', or 'ms_plotly'." ) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index 5dd746d7..62a2c18b 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -4,6 +4,7 @@ from polars.series import Series as PolarsSeries from polars.dataframe.group_by import GroupBy as PolarsGroupBy +import pandas as pd import polars as pl @@ -26,10 +27,22 @@ def __getattr__(self, name): """Delegate attribute access to the underlying Polars Series.""" return getattr(self.series, name) + def __getitem__(self, key): + """Allow access to elements using bracket notation.""" + return PolarsColumnWrapper(self.series[key]) + + def __len__(self): + """Return the length of the Series.""" + return self.series.len() + def astype(self, dtype): """Cast the Series to the specified dtype.""" return self.series.cast(dtype) + def unique(self): + """Return unique values in the Series.""" + return self.series.unique().to_numpy() + def duplicated(self, keep='first'): """Return a boolean Series indicating duplicate values.""" duplicated_mask = self.series.is_duplicated() @@ -46,6 +59,10 @@ def duplicated(self, keep='first'): else: raise ValueError("keep must be 'first', 'last', or False") + + def apply(self, func): + """Apply a function to each element of the Series.""" + return self.series.map_elements(func, skip_nulls=True) def tolist(self): """Return the Series as a list.""" @@ -210,4 +227,35 @@ def to_dict(self, orient='list'): if isinstance(self.data, PandasDataFrame): return self.data.to_dict(orient=orient) elif isinstance(self.data, PolarsDataFrame): - return self.data.to_dict(as_series=False) \ No newline at end of file + return self.data.to_dict(as_series=False) + + +def concat(udfs, ignore_index=True): + """ + Concatenate multiple UnifiedDataFrames into one. + + Parameters: + - udfs: A list of UnifiedDataFrames to concatenate. + - ignore_index: Whether to ignore the index during concatenation. + + Returns: + A new UnifiedDataFrame containing concatenated data. 
+ """ + if not udfs: + raise ValueError("No UnifiedDataFrames provided for concatenation.") + + if hasattr(udfs[0], 'data'): # If udfs is a list of UnifiedDataFrames + dataframes = [udf.data for udf in udfs] + else: + dataframes = udfs + + if isinstance(dataframes[0], PandasDataFrame): + concatenated_df = pd.concat(dataframes, ignore_index=ignore_index) + return UnifiedDataFrame(concatenated_df) + + elif isinstance(dataframes[0], PolarsDataFrame): + concatenated_df = pl.concat(dataframes) + return UnifiedDataFrame(concatenated_df) + + else: + raise TypeError("Unsupported data type in UnifiedDataFrames.") \ No newline at end of file From 4676573376cd7f054b78518091cb03bf146162cf Mon Sep 17 00:00:00 2001 From: singjc Date: Fri, 29 Nov 2024 14:35:38 -0500 Subject: [PATCH 14/14] add: agg for polars column wrapper --- pyopenms_viz/_dataframe.py | 87 +++++++++++++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 6 deletions(-) diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py index 62a2c18b..a8ea81c8 100644 --- a/pyopenms_viz/_dataframe.py +++ b/pyopenms_viz/_dataframe.py @@ -91,6 +91,36 @@ def sum(self): else: summed_data = self.grouped_data.agg(pl.all().sum()) return UnifiedDataFrame(summed_data) + + def agg(self, agg_dict): + """ + Aggregate the grouped data using the specified aggregation methods. + + Parameters: + - agg_dict: A dictionary where keys are column names and values are aggregation functions. + + Returns: + A UnifiedDataFrame containing the aggregated results. + """ + if self.is_pandas: + aggregated_data = self.grouped_data.agg(agg_dict).reset_index() + return UnifiedDataFrame(aggregated_data) + + else: + # For Polars, we need to construct the aggregation expression + agg_expressions = [] + for col, func in agg_dict.items(): + if func == 'mean': + agg_expressions.append(pl.col(col).mean().alias(col)) + elif func == 'sum': + agg_expressions.append(pl.col(col).sum().alias(col)) + elif func == 'max': + agg_expressions.append(pl.col(col).max().alias(col)) + else: + raise ValueError(f"Unsupported aggregation function: {func}") + + aggregated_data = self.grouped_data.agg(agg_expressions) + return UnifiedDataFrame(aggregated_data) @pl.api.register_dataframe_namespace("mass") class UnifiedDataFrame: @@ -110,6 +140,13 @@ def __repr__(self): elif isinstance(self.data, PolarsDataFrame): return self.data.__str__() + def __getattribute__(self, name): + """Delegate attribute access to the underlying DataFrame.""" + try: + return object.__getattribute__(self, name) + except AttributeError: + return getattr(self.data, name) + def __getitem__(self, key): """Allow access to columns using bracket notation.""" if isinstance(self.data, PandasDataFrame): @@ -127,12 +164,15 @@ def __getitem__(self, key): def __setitem__(self, key, value): """Allow assignment to columns using bracket notation.""" - if isinstance(self.data, PandasDataFrame): + if isinstance(self.data, pd.DataFrame): self.data[key] = value - elif isinstance(self.data, PolarsDataFrame): - self.data = self.data.with_columns( - PolarsSeries(key, value) - ) + elif isinstance(self.data, pl.DataFrame): + # Ensure value is of the correct type before assignment + if isinstance(value, PolarsColumnWrapper): + self.data = self.data.with_columns(value.series.alias(key)) + else: + # If value is not wrapped, convert it to a Polars Series + self.data = self.data.with_columns(pl.Series(name=key, values=value)) def __len__(self): """Return the number of rows in the DataFrame.""" @@ -258,4 +298,39 @@ def concat(udfs, 
ignore_index=True): return UnifiedDataFrame(concatenated_df) else: - raise TypeError("Unsupported data type in UnifiedDataFrames.") \ No newline at end of file + raise TypeError("Unsupported data type in UnifiedDataFrames.") + + +def cut(series, bins, right=True, labels=None): + """ + Bin values into discrete intervals for a Series (Pandas or Polars). + + Parameters: + - series: A PolarsColumnWrapper or Pandas Series instance. + - bins: The criteria to bin by (can be an integer or a sequence of scalars). + - right: Indicates whether intervals include the rightmost edge. + - labels: Specifies the labels for the returned bins. + + Returns: + A new Series with binned data. + """ + if isinstance(series.series, pd.Series): + # Use Pandas cut + binned = pd.cut(series.series, bins=bins, right=right, labels=labels) + return PandasColumnWrapper(binned) # Assuming you have a PandasColumnWrapper + + elif isinstance(series.series, pl.Series): + if isinstance(bins, int): + # Create equal-width bins if bins is an integer + min_val = series.series.min() + max_val = series.series.max() + bin_edges = [min_val + i * (max_val - min_val) / bins for i in range(bins + 1)] + else: + bin_edges = bins + + # Use Polars cut method with generated bin edges + binned_series = series.series.cut(breaks=bin_edges, labels=labels, left_closed=not right) + return PolarsColumnWrapper(binned_series) # Return wrapped Polars Series + + else: + raise TypeError("Unsupported data type in the provided series.") \ No newline at end of file
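
Usage sketch (illustrative, not part of the patch series): a minimal example of how the Polars accessor introduced in these commits is expected to be exercised once they are applied. The pieces taken from the patches are the "mass" namespace registered on UnifiedDataFrame, the UnifiedDataFrame.plot(x, y, kind, **kwargs) signature that forwards to PlotAccessor, and the renamed "ms_matplotlib"/"ms_bokeh"/"ms_plotly" backends; the column names, the chromatogram kind, and the choice of backend below are assumptions for illustration and require the corresponding plotting library to be installed.

    import polars as pl
    import pyopenms_viz  # noqa: F401  - importing the package runs pyopenms_viz._dataframe,
                         # which registers the "mass" DataFrame namespace for Polars

    # Illustrative chromatogram-like data; real column names depend on the plot kind used.
    df = pl.DataFrame(
        {
            "rt": [0.5, 1.0, 1.5, 2.0, 2.5],            # retention time (assumed name)
            "int": [10.0, 250.0, 1300.0, 800.0, 30.0],   # intensity (assumed name)
        }
    )

    # df.mass wraps the Polars frame in UnifiedDataFrame; .plot() forwards x, y, kind and
    # the remaining kwargs to PlotAccessor, so pandas and Polars input share one plotting path.
    fig = df.mass.plot(x="rt", y="int", kind="chromatogram", backend="ms_plotly")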