diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py index 078c1245..9f606a33 100644 --- a/pyopenms_viz/__init__.py +++ b/pyopenms_viz/__init__.py @@ -10,6 +10,8 @@ import types from pathlib import Path +import pyopenms_viz._dataframe # noqa: F401 + __version__ = "0.1.5" diff --git a/pyopenms_viz/_bokeh/core.py b/pyopenms_viz/_bokeh/core.py index 999a9569..914b9406 100644 --- a/pyopenms_viz/_bokeh/core.py +++ b/pyopenms_viz/_bokeh/core.py @@ -258,7 +258,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs["line_width"] = 2.5 if by is None: - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) if color_gen is not None: kwargs["line_color"] = ( color_gen if isinstance(color_gen, str) else next(color_gen) @@ -270,7 +270,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): legend_items = [] for group, df in data.groupby(by, sort=False): - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) if color_gen is not None: kwargs["line_color"] = ( color_gen if isinstance(color_gen, str) else next(color_gen) @@ -303,7 +303,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): if not plot_3d: direction = kwargs.pop("direction", "vertical") if by is None: - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) if direction == "horizontal": x0_data_var = 0 x1_data_var = x @@ -326,7 +326,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): else: legend_items = [] for group, df in data.groupby(by): - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) if direction == "horizontal": x0_data_var = 0 x1_data_var = x @@ -422,7 +422,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs[k] = v if by is None: kwargs["marker"] = next(marker_gen) - source = ColumnDataSource(data) + source = ColumnDataSource(data.to_dict()) line = fig.scatter(x=x, y=y, source=source, **kwargs) return fig, None else: @@ -431,7 +431,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs): kwargs["marker"] = next(marker_gen) if z is None: kwargs["fill_color"] = next(color_gen) - source = ColumnDataSource(df) + source = ColumnDataSource(df.to_dict()) line = fig.scatter(x=x, y=y, source=source, **kwargs) legend_items.append((group, [line])) legend = Legend(items=legend_items) diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 50782f40..af5c414e 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -7,7 +7,7 @@ import types import re -from pandas import cut, merge, Interval, concat +from pandas import cut, Interval from pandas.core.frame import DataFrame from pandas.core.dtypes.generic import ABCDataFrame from pandas.core.dtypes.common import is_integer @@ -15,6 +15,7 @@ from numpy import ceil, log1p, log2, nan, mean, repeat, concatenate +from ._dataframe import UnifiedDataFrame, concat from ._config import LegendConfig, FeatureConfig, _BasePlotConfig from ._misc import ( ColorGenerator, @@ -25,6 +26,8 @@ from .constants import IS_SPHINX_BUILD import warnings +import polars as pl + _common_kinds = ("line", "vline", "scatter") _msdata_kinds = ("chromatogram", "mobilogram", "spectrum", "peakmap") @@ -148,9 +151,9 @@ def __init__( _config: _BasePlotConfig | None = None, **kwargs, ) -> None: - + # Data attributes - self.data = data.copy() + self.data = UnifiedDataFrame(data) if not isinstance(data, UnifiedDataFrame) else data self.kind = kind self.by = by self.plot_3d = plot_3d @@ -564,6 +567,7 @@ def __init__( super().__init__(data, x, y, **kwargs) if annotation_data is not None: + annotation_data = UnifiedDataFrame(annotation_data) self.annotation_data = annotation_data.copy() else: self.annotation_data = None @@ -725,7 +729,7 @@ def __init__( super().__init__(data, x, y, **kwargs) - self.reference_spectrum = reference_spectrum + self.reference_spectrum = UnifiedDataFrame(reference_spectrum) if reference_spectrum is not None else None self.mirror_spectrum = mirror_spectrum self.relative_intensity = relative_intensity self.bin_peaks = bin_peaks @@ -920,7 +924,7 @@ def _prepare_data( # copy spectrum data to not modify the original spectrum = spectrum.copy() reference_spectrum = ( - self.reference_spectrum.copy() if reference_spectrum is not None else None + self.reference_spectrum if reference_spectrum is not None else None ) # Convert to relative intensity if required @@ -960,6 +964,7 @@ def _get_colors( colors = [next(color_gen) for _ in range(len(uniques))] # Create a mapping of unique values to their corresponding colors color_map = {uniques[i]: colors[i] for i in range(len(colors))} + # Apply the color mapping to the specified column in the data and turn it into a ColorGenerator return ColorGenerator(data[self.by].apply(lambda x: color_map[x])) # Fallback ColorGenerator with one color @@ -1141,6 +1146,7 @@ def __init__( self.fill_by_z = fill_by_z if annotation_data is not None: + annotation_data = UnifiedDataFrame(annotation_data) self.annotation_data = annotation_data.copy() else: self.annotation_data = None @@ -1398,7 +1404,7 @@ def __init__(self, data: DataFrame) -> None: def __call__(self, *args: Any, **kwargs: Any) -> Any: backend_name = kwargs.get("backend", None) if backend_name is None: - backend_name = "matplotlib" + backend_name = "ms_matplotlib" plot_backend = _get_plot_backend(backend_name) @@ -1439,7 +1445,7 @@ def _get_call_args(backend_name: str, data: DataFrame, args, kwargs): dict The arguments to pass to the plotting backend. """ - if isinstance(data, ABCDataFrame): + if isinstance(data, ABCDataFrame) or isinstance(data, pl.DataFrame): arg_def = [ ("x", None), ("y", None), @@ -1506,27 +1512,27 @@ def _load_backend(backend: str) -> types.ModuleType: types.ModuleType The imported backend. """ - if backend == "bokeh": + if backend == "ms_bokeh": try: - module = importlib.import_module("pyopenms_viz.plotting._bokeh") + module = importlib.import_module("pyopenms_viz._bokeh") except ImportError: raise ImportError( "Bokeh is required for plotting when the 'bokeh' backend is selected." ) from None return module - elif backend == "matplotlib": + elif backend == "ms_matplotlib": try: - module = importlib.import_module("pyopenms_viz.plotting._matplotlib") + module = importlib.import_module("pyopenms_viz._matplotlib") except ImportError: raise ImportError( "Matplotlib is required for plotting when the 'matplotlib' backend is selected." ) from None return module - elif backend == "plotly": + elif backend == "ms_plotly": try: - module = importlib.import_module("pyopenms_viz.plotting._plotly") + module = importlib.import_module("pyopenms_viz._plotly") except ImportError: raise ImportError( "Plotly is required for plotting when the 'plotly' backend is selected." @@ -1534,7 +1540,7 @@ def _load_backend(backend: str) -> types.ModuleType: return module raise ValueError( - f"Could not find plotting backend '{backend}'. Needs to be one of 'bokeh', 'matplotlib', or 'plotly'." + f"Could not find plotting backend '{backend}'. Needs to be one of 'ms_bokeh', 'ms_matplotlib', or 'ms_plotly'." ) @@ -1548,3 +1554,4 @@ def _get_plot_backend(backend: str | None = None): module = _load_backend(backend_str) _backends[backend_str] = module return module + diff --git a/pyopenms_viz/_dataframe.py b/pyopenms_viz/_dataframe.py new file mode 100644 index 00000000..a8ea81c8 --- /dev/null +++ b/pyopenms_viz/_dataframe.py @@ -0,0 +1,336 @@ +from pandas.core.dtypes.generic import ABCDataFrame as PandasDataFrame +from pandas.core.groupby.generic import DataFrameGroupBy as PandasGroupBy +from polars.dataframe.frame import DataFrame as PolarsDataFrame +from polars.series import Series as PolarsSeries +from polars.dataframe.group_by import GroupBy as PolarsGroupBy + +import pandas as pd +import polars as pl + + +class PandasColumnWrapper: + """Wrapper for Pandas Series to add custom methods.""" + def __init__(self, series): + self.series = series + + def __getattr__(self, name): + """Delegate attribute access to the underlying Pandas Series.""" + return getattr(self.series, name) + + +class PolarsColumnWrapper: + """Wrapper for Polars Series to add custom methods.""" + def __init__(self, series): + self.series = series + + def __getattr__(self, name): + """Delegate attribute access to the underlying Polars Series.""" + return getattr(self.series, name) + + def __getitem__(self, key): + """Allow access to elements using bracket notation.""" + return PolarsColumnWrapper(self.series[key]) + + def __len__(self): + """Return the length of the Series.""" + return self.series.len() + + def astype(self, dtype): + """Cast the Series to the specified dtype.""" + return self.series.cast(dtype) + + def unique(self): + """Return unique values in the Series.""" + return self.series.unique().to_numpy() + + def duplicated(self, keep='first'): + """Return a boolean Series indicating duplicate values.""" + duplicated_mask = self.series.is_duplicated() + if keep == 'first': + first_occurrences = self.series.is_first_distinct() + return duplicated_mask & ~first_occurrences + + elif keep == 'last': + last_occurrences = self.series.is_last_distinct() + return duplicated_mask & ~last_occurrences + + elif keep is False: + return duplicated_mask.cast(pl.Boolean) + + else: + raise ValueError("keep must be 'first', 'last', or False") + + def apply(self, func): + """Apply a function to each element of the Series.""" + return self.series.map_elements(func, skip_nulls=True) + + def tolist(self): + """Return the Series as a list.""" + return self.series.to_list() + +class GroupedDataFrame: + """Class to handle grouped DataFrames for both Pandas and Polars.""" + def __init__(self, grouped_data, is_pandas=True): + self.grouped_data = grouped_data + self.is_pandas = is_pandas + + def __iter__(self): + """Allow iteration over groups.""" + if self.is_pandas: + for group_name, group_df in self.grouped_data: + yield group_name, UnifiedDataFrame(group_df) + else: + for group_name, group_df in self.grouped_data: + yield group_name[0], UnifiedDataFrame(group_df) + + def sum(self): + """Sum the grouped data.""" + if self.is_pandas: + summed_data = self.grouped_data.sum().reset_index() + return UnifiedDataFrame(summed_data) + else: + summed_data = self.grouped_data.agg(pl.all().sum()) + return UnifiedDataFrame(summed_data) + + def agg(self, agg_dict): + """ + Aggregate the grouped data using the specified aggregation methods. + + Parameters: + - agg_dict: A dictionary where keys are column names and values are aggregation functions. + + Returns: + A UnifiedDataFrame containing the aggregated results. + """ + if self.is_pandas: + aggregated_data = self.grouped_data.agg(agg_dict).reset_index() + return UnifiedDataFrame(aggregated_data) + + else: + # For Polars, we need to construct the aggregation expression + agg_expressions = [] + for col, func in agg_dict.items(): + if func == 'mean': + agg_expressions.append(pl.col(col).mean().alias(col)) + elif func == 'sum': + agg_expressions.append(pl.col(col).sum().alias(col)) + elif func == 'max': + agg_expressions.append(pl.col(col).max().alias(col)) + else: + raise ValueError(f"Unsupported aggregation function: {func}") + + aggregated_data = self.grouped_data.agg(agg_expressions) + return UnifiedDataFrame(aggregated_data) + +@pl.api.register_dataframe_namespace("mass") +class UnifiedDataFrame: + """ + Wrapper class for Pandas and Polars DataFrames to provide a unified interface. + """ + def __init__(self, data): + if isinstance(data, (PandasDataFrame, PolarsDataFrame)): + self.data = data + else: + raise TypeError(f"Unsupported data type {type(data)}. Must be either pandas DataFrame or Polars DataFrame.") + + def __repr__(self): + """Return a string representation of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return str(self.data) + elif isinstance(self.data, PolarsDataFrame): + return self.data.__str__() + + def __getattribute__(self, name): + """Delegate attribute access to the underlying DataFrame.""" + try: + return object.__getattribute__(self, name) + except AttributeError: + return getattr(self.data, name) + + def __getitem__(self, key): + """Allow access to columns using bracket notation.""" + if isinstance(self.data, PandasDataFrame): + if isinstance(key, list): + return UnifiedDataFrame(self.data[key]) + else: + return PandasColumnWrapper(self.data[key]) + elif isinstance(self.data, PolarsDataFrame): + if isinstance(key, list): + return UnifiedDataFrame(self.data.select(key)) + else: + return PolarsColumnWrapper(self.data[key]) + else: + raise KeyError(f"Column '{key}' not found in DataFrame.") + + def __setitem__(self, key, value): + """Allow assignment to columns using bracket notation.""" + if isinstance(self.data, pd.DataFrame): + self.data[key] = value + elif isinstance(self.data, pl.DataFrame): + # Ensure value is of the correct type before assignment + if isinstance(value, PolarsColumnWrapper): + self.data = self.data.with_columns(value.series.alias(key)) + else: + # If value is not wrapped, convert it to a Polars Series + self.data = self.data.with_columns(pl.Series(name=key, values=value)) + + def __len__(self): + """Return the number of rows in the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return len(self.data) + elif isinstance(self.data, PolarsDataFrame): + return self.data.height + + @property + def index(self): + """Return the index of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return self.data.index + elif isinstance(self.data, PolarsDataFrame): + return list(range(self.data.height)) + + @property + def columns(self): + """Return a list of column names.""" + if isinstance(self.data, PandasDataFrame): + return self.data.columns.tolist() + elif isinstance(self.data, PolarsDataFrame): + return self.data.columns + + def copy(self): + """Return a copy of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return UnifiedDataFrame(self.data.copy()) + elif isinstance(self.data, PolarsDataFrame): + return UnifiedDataFrame(self.data.clone()) + + def sort_values(self, by, ascending=True, inplace=False): + """Sort the DataFrame by the specified column(s).""" + if isinstance(self.data, PandasDataFrame): + if inplace: + self.data.sort_values(by=by, ascending=ascending, inplace=True) + else: + sorted_data = self.data.sort_values(by=by, ascending=ascending) + return UnifiedDataFrame(sorted_data) + + elif isinstance(self.data, PolarsDataFrame): + sorted_data = self.data.sort(by=by, descending=not ascending) + if inplace: + self.data = sorted_data + else: + return UnifiedDataFrame(sorted_data) + + def reset_index(self, drop=False): + """Reset the index of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return UnifiedDataFrame(self.data.reset_index(drop=drop)) + elif isinstance(self.data, PolarsDataFrame): + # For Polars we can just return the same DataFrame since it doesn't have an index like Pandas. + return UnifiedDataFrame(self.data) + + def duplicated(self): + """Return a boolean Series indicating duplicate rows.""" + if isinstance(self.data, PandasDataFrame): + return self.data.duplicated() + elif isinstance(self.data, PolarsDataFrame): + return self.data.is_duplicated() + + def iterrows(self): + """Return an iterator for rows of the DataFrame.""" + if isinstance(self.data, PandasDataFrame): + return self.data.iterrows() + elif isinstance(self.data, PolarsDataFrame): + return enumerate(self.data.iter_rows(named=True)) + + def groupby(self, by, sort=True): + """Group by specified columns.""" + if isinstance(self.data, PandasDataFrame): + grouped = self.data.groupby(by, sort=sort) + return GroupedDataFrame(grouped) + elif isinstance(self.data, PolarsDataFrame): + grouped = self.data.group_by(by) + return GroupedDataFrame(grouped, is_pandas=False) + + def plot(self, x: str, y: str, kind: str = "line", **kwargs): + from ._core import PlotAccessor + return PlotAccessor(self.data)(x, y, kind, **kwargs) + + def tolist(self, column_name): + """Return a list of values from a specified column.""" + if isinstance(self.data, PandasDataFrame): + return self.data[column_name].tolist() + elif isinstance(self.data, PolarsDataFrame): + return self.data[column_name].to_list() + + def to_dict(self, orient='list'): + """Return the DataFrame as a dictionary.""" + if isinstance(self.data, PandasDataFrame): + return self.data.to_dict(orient=orient) + elif isinstance(self.data, PolarsDataFrame): + return self.data.to_dict(as_series=False) + + +def concat(udfs, ignore_index=True): + """ + Concatenate multiple UnifiedDataFrames into one. + + Parameters: + - udfs: A list of UnifiedDataFrames to concatenate. + - ignore_index: Whether to ignore the index during concatenation. + + Returns: + A new UnifiedDataFrame containing concatenated data. + """ + if not udfs: + raise ValueError("No UnifiedDataFrames provided for concatenation.") + + if hasattr(udfs[0], 'data'): # If udfs is a list of UnifiedDataFrames + dataframes = [udf.data for udf in udfs] + else: + dataframes = udfs + + if isinstance(dataframes[0], PandasDataFrame): + concatenated_df = pd.concat(dataframes, ignore_index=ignore_index) + return UnifiedDataFrame(concatenated_df) + + elif isinstance(dataframes[0], PolarsDataFrame): + concatenated_df = pl.concat(dataframes) + return UnifiedDataFrame(concatenated_df) + + else: + raise TypeError("Unsupported data type in UnifiedDataFrames.") + + +def cut(series, bins, right=True, labels=None): + """ + Bin values into discrete intervals for a Series (Pandas or Polars). + + Parameters: + - series: A PolarsColumnWrapper or Pandas Series instance. + - bins: The criteria to bin by (can be an integer or a sequence of scalars). + - right: Indicates whether intervals include the rightmost edge. + - labels: Specifies the labels for the returned bins. + + Returns: + A new Series with binned data. + """ + if isinstance(series.series, pd.Series): + # Use Pandas cut + binned = pd.cut(series.series, bins=bins, right=right, labels=labels) + return PandasColumnWrapper(binned) # Assuming you have a PandasColumnWrapper + + elif isinstance(series.series, pl.Series): + if isinstance(bins, int): + # Create equal-width bins if bins is an integer + min_val = series.series.min() + max_val = series.series.max() + bin_edges = [min_val + i * (max_val - min_val) / bins for i in range(bins + 1)] + else: + bin_edges = bins + + # Use Polars cut method with generated bin edges + binned_series = series.series.cut(breaks=bin_edges, labels=labels, left_closed=not right) + return PolarsColumnWrapper(binned_series) # Return wrapped Polars Series + + else: + raise TypeError("Unsupported data type in the provided series.") \ No newline at end of file diff --git a/test/test_dataframe.py b/test/test_dataframe.py new file mode 100644 index 00000000..7b772743 --- /dev/null +++ b/test/test_dataframe.py @@ -0,0 +1,307 @@ +""" +tes/test_dataframe +~~~~~~~~~~~~~~~~~~ +""" + +import pytest +import pandas as pd +import polars as pl +from pyopenms_viz._dataframe import * + +#################### +# PandasColumnWrapper + + +def test_pandas_column_wrapper_getattr(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test accessing attributes of the underlying Pandas Series + assert column_wrapper.sum() == 15 + assert column_wrapper.mean() == 3.0 + assert column_wrapper.max() == 5 + assert column_wrapper.min() == 1 + +def test_pandas_column_wrapper_getattr_invalid(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test accessing invalid attributes of the underlying Pandas Series + with pytest.raises(AttributeError): + column_wrapper.invalid_attribute + +def test_pandas_column_wrapper_cast(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test casting the Series to a different dtype + casted_series = column_wrapper.astype(float) + assert casted_series.dtype == float + +def test_pandas_column_wrapper_is_duplicated(): + # Create a Pandas Series with duplicate values + series = pd.Series([1, 2, 3, 4, 5, 1, 2]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test checking for duplicate values + duplicated_series = column_wrapper.duplicated() + assert duplicated_series.tolist() == [False, False, False, False, False, True, True] + +def test_pandas_column_wrapper_to_list(): + # Create a Pandas Series + series = pd.Series([1, 2, 3, 4, 5]) + + # Create a PandasColumnWrapper instance + column_wrapper = PandasColumnWrapper(series) + + # Test converting the Series to a list + series_list = column_wrapper.tolist() + assert series_list == [1, 2, 3, 4, 5] + + +#################### +# PolarsColumnWrapper + + +def test_polars_column_wrapper_getattr(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test accessing attributes of the underlying Polars Series + assert column_wrapper.sum() == 15 + assert column_wrapper.mean() == 3.0 + assert column_wrapper.max() == 5 + assert column_wrapper.min() == 1 + +def test_polars_column_wrapper_cast(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test casting the Series to a different dtype + casted_series = column_wrapper.astype(pl.Float64) + assert casted_series.dtype == pl.Float64 + +def test_polars_column_wrapper_is_duplicated_first(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep='first') + expected_result = [False, False, False, True, True] # First occurrence is kept + + assert duplicated_series.to_list() == expected_result + +def test_polars_column_wrapper_is_duplicated_last(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep='last') + expected_result = [True, True, False, False, False] # Last occurrence is kept + + assert duplicated_series.to_list() == expected_result + +def test_polars_column_wrapper_is_duplicated_all(): + series = pl.Series("a", [1, 2, 3, 1, 2]) + column_wrapper = PolarsColumnWrapper(series) + + duplicated_series = column_wrapper.duplicated(keep=False) + expected_result = [True, True, False, True, True] # All duplicates marked + + assert duplicated_series.to_list() == expected_result + +def test_polars_column_wrapper_to_list(): + # Create a Polars Series + series = pl.Series("a", [1, 2, 3, 4, 5]) + + # Create a PolarsColumnWrapper instance + column_wrapper = PolarsColumnWrapper(series) + + # Test converting the Series to a list + series_list = column_wrapper.tolist() + assert series_list == [1, 2, 3, 4, 5] + + +#################### +# GroupedDataFrame + +def test_grouped_dataframe_init_with_pandas(): + # Create a sample Pandas DataFrame and group it + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + assert gdf.is_pandas is True + assert len(list(gdf)) == 2 # Two groups: A and B + +def test_grouped_dataframe_iterate_with_pandas(): + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + groups = list(gdf) + + assert len(groups) == 2 # Should have two groups + assert groups[0][0] == 'A' # First group name should be 'A' + assert isinstance(groups[0][1], UnifiedDataFrame) # Should return UnifiedDataFrame + +def test_grouped_dataframe_sum_with_pandas(): + df = pd.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.groupby('key') + + gdf = GroupedDataFrame(grouped, is_pandas=True) + + summed_df = gdf.sum() + + assert isinstance(summed_df, UnifiedDataFrame) # Should return a UnifiedDataFrame + assert summed_df.data.equals(pd.DataFrame({'key': ['A', 'B'], 'value': [4, 6]})) # Check summed values + +def test_grouped_dataframe_init_with_polars(): + # Create a sample Polars DataFrame and group it + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key') + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + assert gdf.is_pandas is False + assert len(list(gdf)) == 2 # Two groups: A and B + +def test_grouped_dataframe_iterate_with_polars(): + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key', maintain_order=True) + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + groups = list(gdf) + + assert len(groups) == 2 # Should have two groups + assert groups[0][0] == 'A' # First group name should be 'A' + assert isinstance(groups[0][1], UnifiedDataFrame) # Should return UnifiedDataFrame + +def test_grouped_dataframe_sum_with_polars(): + df = pl.DataFrame({ + 'key': ['A', 'B', 'A', 'B'], + 'value': [1, 2, 3, 4] + }) + grouped = df.group_by('key', maintain_order=True) + + gdf = GroupedDataFrame(grouped, is_pandas=False) + + summed_df = gdf.sum() + + assert isinstance(summed_df, UnifiedDataFrame) # Should return a UnifiedDataFrame + expected_result = pl.DataFrame({'key': ['A', 'B'], 'value': [4, 6]}) + print(f"summed_df.data: {summed_df.data}") + print(f"expected_result: {expected_result}") + + assert summed_df.data.equals(expected_result) # Check summed values + + +#################### +# UnifiedDataFrame + + +def test_unified_dataframe(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert isinstance(pandas_df, UnifiedDataFrame) + assert isinstance(polars_df, UnifiedDataFrame) + + assert len(pandas_df) == 3 + assert len(polars_df) == 3 + + assert pandas_df.columns == ["A", "B"] + assert polars_df.columns == ["A", "B"] + + assert pandas_df["A"].tolist() == [1, 2, 3] + assert polars_df["A"].tolist() == [1, 2, 3] + + pandas_df["C"] = [7, 8, 9] + polars_df["C"] = [7, 8, 9] + + assert pandas_df.columns == ["A", "B", "C"] + assert polars_df.columns == ["A", "B", "C"] + + sorted_pandas_df = pandas_df.sort_values("A") + sorted_polars_df = polars_df.sort_values("A") + + assert sorted_pandas_df["A"].tolist() == [1, 2, 3] + assert sorted_polars_df["A"].tolist() == [1, 2, 3] + + assert sorted_pandas_df.reset_index(drop=True).index.tolist() == [0, 1, 2] + assert sorted_polars_df.reset_index(drop=True).index == [0, 1, 2] + + assert sorted_pandas_df.duplicated().tolist() == [False, False, False] + assert sorted_polars_df.duplicated().to_list() == [False, False, False] + +# Test UnifiedDataFrame.plot() +def test_unified_dataframe_plot(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + pandas_plot = pandas_df.plot(x="A", y="B", kind="line") + polars_plot = polars_df.plot(x="A", y="B", kind="line") + + assert isinstance(pandas_plot, object) # Replace with the actual type of the plot object + assert isinstance(polars_plot, object) # Replace with the actual type of the plot object + +# Test UnifiedDataFrame.tolist() +def test_unified_dataframe_tolist(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert pandas_df.tolist("A") == [1, 2, 3] + assert polars_df.tolist("A") == [1, 2, 3] + +# Test UnifiedDataFrame.to_dict() +def test_unified_dataframe_to_dict(): + pandas_data = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + polars_data = PolarsDataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}) + + pandas_df = UnifiedDataFrame(pandas_data) + polars_df = UnifiedDataFrame(polars_data) + + assert pandas_df.to_dict(orient="list") == {"A": [1, 2, 3], "B": [4, 5, 6]} + assert polars_df.to_dict() == {"A": [1, 2, 3], "B": [4, 5, 6]} \ No newline at end of file