Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyopenms_viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import types
from pathlib import Path

import pyopenms_viz._dataframe # noqa: F401

__version__ = "0.1.5"


Expand Down
12 changes: 6 additions & 6 deletions pyopenms_viz/_bokeh/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
kwargs["line_width"] = 2.5

if by is None:
source = ColumnDataSource(data)
source = ColumnDataSource(data.to_dict())
if color_gen is not None:
kwargs["line_color"] = (
color_gen if isinstance(color_gen, str) else next(color_gen)
Expand All @@ -270,7 +270,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):

legend_items = []
for group, df in data.groupby(by, sort=False):
source = ColumnDataSource(df)
source = ColumnDataSource(df.to_dict())
if color_gen is not None:
kwargs["line_color"] = (
color_gen if isinstance(color_gen, str) else next(color_gen)
Expand Down Expand Up @@ -303,7 +303,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
if not plot_3d:
direction = kwargs.pop("direction", "vertical")
if by is None:
source = ColumnDataSource(data)
source = ColumnDataSource(data.to_dict())
if direction == "horizontal":
x0_data_var = 0
x1_data_var = x
Expand All @@ -326,7 +326,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
else:
legend_items = []
for group, df in data.groupby(by):
source = ColumnDataSource(df)
source = ColumnDataSource(df.to_dict())
if direction == "horizontal":
x0_data_var = 0
x1_data_var = x
Expand Down Expand Up @@ -422,7 +422,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
kwargs[k] = v
if by is None:
kwargs["marker"] = next(marker_gen)
source = ColumnDataSource(data)
source = ColumnDataSource(data.to_dict())
line = fig.scatter(x=x, y=y, source=source, **kwargs)
return fig, None
else:
Expand All @@ -431,7 +431,7 @@ def plot(cls, fig, data, x, y, by: str | None = None, plot_3d=False, **kwargs):
kwargs["marker"] = next(marker_gen)
if z is None:
kwargs["fill_color"] = next(color_gen)
source = ColumnDataSource(df)
source = ColumnDataSource(df.to_dict())
line = fig.scatter(x=x, y=y, source=source, **kwargs)
legend_items.append((group, [line]))
legend = Legend(items=legend_items)
Expand Down
35 changes: 21 additions & 14 deletions pyopenms_viz/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import types
import re

from pandas import cut, merge, Interval, concat
from pandas import cut, Interval
from pandas.core.frame import DataFrame
from pandas.core.dtypes.generic import ABCDataFrame
from pandas.core.dtypes.common import is_integer
from pandas.util._decorators import Appender

from numpy import ceil, log1p, log2, nan, mean, repeat, concatenate

from ._dataframe import UnifiedDataFrame, concat
from ._config import LegendConfig, FeatureConfig, _BasePlotConfig
from ._misc import (
ColorGenerator,
Expand All @@ -25,6 +26,8 @@
from .constants import IS_SPHINX_BUILD
import warnings

import polars as pl


_common_kinds = ("line", "vline", "scatter")
_msdata_kinds = ("chromatogram", "mobilogram", "spectrum", "peakmap")
Expand Down Expand Up @@ -148,9 +151,9 @@ def __init__(
_config: _BasePlotConfig | None = None,
**kwargs,
) -> None:

# Data attributes
self.data = data.copy()
self.data = UnifiedDataFrame(data) if not isinstance(data, UnifiedDataFrame) else data
self.kind = kind
self.by = by
self.plot_3d = plot_3d
Expand Down Expand Up @@ -564,6 +567,7 @@ def __init__(
super().__init__(data, x, y, **kwargs)

if annotation_data is not None:
annotation_data = UnifiedDataFrame(annotation_data)
self.annotation_data = annotation_data.copy()
else:
self.annotation_data = None
Expand Down Expand Up @@ -725,7 +729,7 @@ def __init__(

super().__init__(data, x, y, **kwargs)

self.reference_spectrum = reference_spectrum
self.reference_spectrum = UnifiedDataFrame(reference_spectrum) if reference_spectrum is not None else None
self.mirror_spectrum = mirror_spectrum
self.relative_intensity = relative_intensity
self.bin_peaks = bin_peaks
Expand Down Expand Up @@ -920,7 +924,7 @@ def _prepare_data(
# copy spectrum data to not modify the original
spectrum = spectrum.copy()
reference_spectrum = (
self.reference_spectrum.copy() if reference_spectrum is not None else None
self.reference_spectrum if reference_spectrum is not None else None
)

# Convert to relative intensity if required
Expand Down Expand Up @@ -960,6 +964,7 @@ def _get_colors(
colors = [next(color_gen) for _ in range(len(uniques))]
# Create a mapping of unique values to their corresponding colors
color_map = {uniques[i]: colors[i] for i in range(len(colors))}

# Apply the color mapping to the specified column in the data and turn it into a ColorGenerator
return ColorGenerator(data[self.by].apply(lambda x: color_map[x]))
# Fallback ColorGenerator with one color
Expand Down Expand Up @@ -1141,6 +1146,7 @@ def __init__(
self.fill_by_z = fill_by_z

if annotation_data is not None:
annotation_data = UnifiedDataFrame(annotation_data)
self.annotation_data = annotation_data.copy()
else:
self.annotation_data = None
Expand Down Expand Up @@ -1398,7 +1404,7 @@ def __init__(self, data: DataFrame) -> None:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
backend_name = kwargs.get("backend", None)
if backend_name is None:
backend_name = "matplotlib"
backend_name = "ms_matplotlib"

plot_backend = _get_plot_backend(backend_name)

Expand Down Expand Up @@ -1439,7 +1445,7 @@ def _get_call_args(backend_name: str, data: DataFrame, args, kwargs):
dict
The arguments to pass to the plotting backend.
"""
if isinstance(data, ABCDataFrame):
if isinstance(data, ABCDataFrame) or isinstance(data, pl.DataFrame):
arg_def = [
("x", None),
("y", None),
Expand Down Expand Up @@ -1506,35 +1512,35 @@ def _load_backend(backend: str) -> types.ModuleType:
types.ModuleType
The imported backend.
"""
if backend == "bokeh":
if backend == "ms_bokeh":
try:
module = importlib.import_module("pyopenms_viz.plotting._bokeh")
module = importlib.import_module("pyopenms_viz._bokeh")
except ImportError:
raise ImportError(
"Bokeh is required for plotting when the 'bokeh' backend is selected."
) from None
return module

elif backend == "matplotlib":
elif backend == "ms_matplotlib":
try:
module = importlib.import_module("pyopenms_viz.plotting._matplotlib")
module = importlib.import_module("pyopenms_viz._matplotlib")
except ImportError:
raise ImportError(
"Matplotlib is required for plotting when the 'matplotlib' backend is selected."
) from None
return module

elif backend == "plotly":
elif backend == "ms_plotly":
try:
module = importlib.import_module("pyopenms_viz.plotting._plotly")
module = importlib.import_module("pyopenms_viz._plotly")
except ImportError:
raise ImportError(
"Plotly is required for plotting when the 'plotly' backend is selected."
) from None
return module

raise ValueError(
f"Could not find plotting backend '{backend}'. Needs to be one of 'bokeh', 'matplotlib', or 'plotly'."
f"Could not find plotting backend '{backend}'. Needs to be one of 'ms_bokeh', 'ms_matplotlib', or 'ms_plotly'."
)


Expand All @@ -1548,3 +1554,4 @@ def _get_plot_backend(backend: str | None = None):
module = _load_backend(backend_str)
_backends[backend_str] = module
return module

Loading