diff --git a/docs/gallery_scripts_template/plot_investigate_spectrum_binning_ms_matplotlib.py b/docs/gallery_scripts_template/plot_investigate_spectrum_binning_ms_matplotlib.py index ed0c190d..1aac6254 100644 --- a/docs/gallery_scripts_template/plot_investigate_spectrum_binning_ms_matplotlib.py +++ b/docs/gallery_scripts_template/plot_investigate_spectrum_binning_ms_matplotlib.py @@ -1,26 +1,40 @@ """ -Investigate Spctrum Binning ms_matplotlib +Investigate Spectrum Binning ms_matplotlib ======================================= -Here we use a dummy spectrum example to investigate spectrum binning. +Here we use a dummy spectrum example to investigate spectrum binning. """ import pandas as pd import matplotlib.pyplot as plt import requests from io import StringIO +import sys +import os +# # Add parent directories to the path (adjust as necessary) +# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +# Set the plotting backend to ms_matplotlib pd.options.plotting.backend = "ms_matplotlib" -# download the file for example plotting -url = ( - "https://github.com/OpenMS/pyopenms_viz/releases/download/v0.1.5/TestSpectrumDf.tsv" -) +# Download the file for example plotting +url = "https://github.com/OpenMS/pyopenms_viz/releases/download/v0.1.5/TestSpectrumDf.tsv" response = requests.get(url) response.raise_for_status() # Check for any HTTP errors df = pd.read_csv(StringIO(response.text), sep="\t") -# Let's assess the peak binning and create a 4 by 2 subplot to visualize the different methods of binning +# Add a 'Run' column and duplicate entries for each run group. +# For example, here we create three run groups (1, 2, and 3). +runs = [1, 2, 3] +df_list = [] +for run in runs: + df_run = df.copy() + df_run["Run"] = run + df_list.append(df_run) +df = pd.concat(df_list, ignore_index=True) + +# Update the parameters for binning and visualization. params_list = [ {"title": "Spectrum (Raw)", "bin_peaks": False}, { @@ -72,18 +86,26 @@ }, ] -# Create a 3-row subplot +# Create a 4x2 subplot grid to visualize different binning methods. fig, axs = plt.subplots(4, 2, figsize=(14, 14)) i = j = 0 for params in params_list: - p = df.plot( - kind="spectrum", x="mz", y="intensity", canvas=axs[i][j], grid=False, **params + # Here we pass the "Run" column to group the spectrum by run. + df.plot( + kind="spectrum", + x="mz", + y="intensity", + canvas=axs[i][j], + grid=False, + show_plot=False, + by="Run", + **params ) j += 1 - if j >= 2: # If we've filled two columns, move to the next row + if j >= 2: # Move to next row when two columns are filled. j = 0 i += 1 fig.tight_layout() -fig.show() +plt.show() diff --git a/docs/gallery_scripts_template/plot_spyogenes_subplots_ms_matplotlib.py b/docs/gallery_scripts_template/plot_spyogenes_subplots_ms_matplotlib.py index ec662873..f183960c 100644 --- a/docs/gallery_scripts_template/plot_spyogenes_subplots_ms_matplotlib.py +++ b/docs/gallery_scripts_template/plot_spyogenes_subplots_ms_matplotlib.py @@ -1,6 +1,6 @@ """ Plot Spyogenes subplots ms_matplotlib -======================================= +==================================================== Here we show how we can plot multiple chromatograms across runs together """ @@ -8,9 +8,13 @@ import pandas as pd import requests import zipfile -import numpy as np import matplotlib.pyplot as plt +import sys +import os +# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) + +# Set the plotting backend pd.options.plotting.backend = "ms_matplotlib" ###### Load Data ####### @@ -21,76 +25,51 @@ # Download the zip file try: + # Save the zip file to the current directory print(f"Downloading {zip_filename}...") response = requests.get(url) response.raise_for_status() # Check for any HTTP errors - - # Save the zip file to the current directory with open(zip_filename, "wb") as out: out.write(response.content) print(f"Downloaded {zip_filename} successfully.") except requests.RequestException as e: print(f"Error downloading zip file: {e}") -except IOError as e: - print(f"Error writing zip file: {e}") -# Unzipping the file +# Unzip the file try: with zipfile.ZipFile(zip_filename, "r") as zip_ref: - # Extract all files to the current directory zip_ref.extractall() print("Unzipped files successfully.") except zipfile.BadZipFile as e: print(f"Error unzipping file: {e}") -annotation_bounds = pd.read_csv( - "spyogenes/AADGQTVSGGSILYR3_manual_annotations.tsv", sep="\t" -) # contain annotations across all runs -chrom_df = pd.read_csv( - "spyogenes/chroms_AADGQTVSGGSILYR3.tsv", sep="\t" -) # contains chromatogram for precursor across all runs - -##### Set Plotting Variables ##### -pd.options.plotting.backend = "ms_matplotlib" -RUN_NAMES = [ - "Run #0 Spyogenes 0% human plasma", - "Run #1 Spyogenes 0% human plasma", - "Run #2 Spyogenes 0% human plasma", - "Run #3 Spyogenes 10% human plasma", - "Run #4 Spyogenes 10% human plasma", - "Run #5 Spyogenes 10% human plasma", -] - -fig, axs = plt.subplots(len(np.unique(chrom_df["run"])), 1, figsize=(10, 15)) - -# plt.close ### required for running in jupyter notebook setting - -# For each run fill in the axs object with the corresponding chromatogram -plot_list = [] -for i, run in enumerate(RUN_NAMES): - run_df = chrom_df[chrom_df["run_name"] == run] - current_bounds = annotation_bounds[annotation_bounds["run"] == run] +# Load the data +annotation_bounds = pd.read_csv("spyogenes/AADGQTVSGGSILYR3_manual_annotations.tsv", sep="\t") +chrom_df = pd.read_csv("spyogenes/chroms_AADGQTVSGGSILYR3.tsv", sep="\t") - run_df.plot( - kind="chromatogram", - x="rt", - y="int", - grid=False, - by="ion_annotation", - title=run_df.iloc[0]["run_name"], - title_font_size=16, - xaxis_label_font_size=14, - yaxis_label_font_size=14, - xaxis_tick_font_size=12, - yaxis_tick_font_size=12, - canvas=axs[i], - relative_intensity=True, - annotation_data=current_bounds, - xlabel="Retention Time (sec)", - ylabel="Relative\nIntensity", - annotation_legend_config=dict(show=False), - legend_config={"show": False}, - ) +##### Plotting Using Tile By ##### +# Instead of pre-creating subplots and looping over RUN_NAMES, +# we call the plot method once and provide a facet_column parameter. +fig = chrom_df.plot( + kind="chromatogram", + x="rt", + y="int", + facet_column="run_name", # Automatically groups data by run_name and creates subplots + facet_col_wrap=1, # Layout: 1 column (one subplot per row) + grid=False, + by="ion_annotation", + title_font_size=16, + xaxis_label_font_size=14, + yaxis_label_font_size=14, + xaxis_tick_font_size=12, + yaxis_tick_font_size=12, + relative_intensity=True, + annotation_data=annotation_bounds, + xlabel="Retention Time (sec)", + ylabel="Relative\nIntensity", + annotation_legend_config={"show": False}, + legend_config={"show": False}, +) fig.tight_layout() fig diff --git a/pyopenms_viz/__init__.py b/pyopenms_viz/__init__.py index e822cc9f..e43e0ee2 100644 --- a/pyopenms_viz/__init__.py +++ b/pyopenms_viz/__init__.py @@ -119,6 +119,7 @@ def _get_call_args(backend_name: str, data: DataFrame, args, kwargs): ("feature_config", None), ("_config", None), ("backend", backend_name), + # ("tile_by", None), ] else: raise ValueError( diff --git a/pyopenms_viz/_config.py b/pyopenms_viz/_config.py index 0b79487c..341adc07 100644 --- a/pyopenms_viz/_config.py +++ b/pyopenms_viz/_config.py @@ -205,6 +205,9 @@ def default_legend_factory(): legend_config: LegendConfig | dict = field(default_factory=default_legend_factory) opacity: float = 1.0 + facet_column: str | None = None # Name of the column to tile the plot by. + facet_col_wrap: int = 1 # How many columns in the subplot grid. + def __post_init__(self): # if legend_config is a dictionary, update it to LegendConfig object if isinstance(self.legend_config, dict): diff --git a/pyopenms_viz/_core.py b/pyopenms_viz/_core.py index 62d51c02..962b5f4e 100644 --- a/pyopenms_viz/_core.py +++ b/pyopenms_viz/_core.py @@ -14,6 +14,7 @@ from pandas.util._decorators import Appender import re + from numpy import ceil, log1p, log2, nan, mean, repeat, concatenate from ._config import ( LegendConfig, @@ -539,7 +540,6 @@ def _create_tooltips(self, entries: dict, index: bool = True): class ChromatogramPlot(BaseMSPlot, ABC): - _config: ChromatogramConfig = None @property @@ -560,55 +560,112 @@ def load_config(self, **kwargs): def __init__(self, data, config: ChromatogramConfig = None, **kwargs) -> None: super().__init__(data, config, **kwargs) - self.label_suffix = self.x # set label suffix for bounding box - self._check_and_aggregate_duplicates() - # sort data by x so in order + # Sort data by x (and by self.by if provided) so the data is in order. if self.by is not None: self.data.sort_values(by=[self.by, self.x], inplace=True) else: self.data.sort_values(by=self.x, inplace=True) - # Convert to relative intensity if required + # Convert to relative intensity if required. if self.relative_intensity: self.data[self.y] = self.data[self.y] / self.data[self.y].max() * 100 + # Perform all validations for the plotting configuration. + self._validate_plot_config() + + # Proceed to generate the plot. self.plot() + def _validate_plot_config(self): + """ + Validate plot configuration options (e.g., tiling parameters) before plotting. + """ + # Validate the facet_column option: check if the specified column exists in the data. + if hasattr(self._config, "facet_column"): + facet_column = self._config.facet_column + if facet_column not in self.data.columns: + warnings.warn( + f"facet_column column '{facet_column}' not found in data. Plot will be generated without tiling." + ) + self._config.facet_column = None + + # Validate facet_col_wrap: ensure it is a positive integer. + if hasattr(self._config, "facet_col_wrap"): + if not isinstance(self._config.facet_col_wrap, int) or self._config.facet_col_wrap < 1: + warnings.warn("facet_col_wrap must be a positive integer. Defaulting to 1.") + self._config.facet_col_wrap = 1 + def plot(self): """ - Create the plot + Create the plot using the validated configuration. """ + facet_column = self._config.facet_column + + # Define tooltips based on the overall data columns. tooltip_entries = {"retention time": self.x, "intensity": self.y} if "Annotation" in self.data.columns: tooltip_entries["annotation"] = "Annotation" if "product_mz" in self.data.columns: tooltip_entries["product m/z"] = "product_mz" - tooltips, custom_hover_data = self._create_tooltips( - tooltip_entries, index=False - ) - - linePlot = self.get_line_renderer(data=self.data, config=self._config) - - self.canvas = linePlot.generate(tooltips, custom_hover_data) - self._modify_y_range((0, self.data[self.y].max()), (0, 0.1)) - - if self._interactive: - self.manual_boundary_renderer = self._add_bounding_vertical_drawer() - - if self.annotation_data is not None: - self._add_peak_boundaries(self.annotation_data) + tooltips, custom_hover_data = self._create_tooltips(tooltip_entries, index=False) + + if facet_column: + # Group the data by the facet_column column. + grouped = self.data.groupby(facet_column) + num_groups = len(grouped) + + # Use tiling options from the configuration and set instance properties. + self.facet_col_wrap = self._config.facet_col_wrap + self.tile_rows = int(ceil(num_groups / self.facet_col_wrap)) + + # Create a figure with a grid of subplots. + fig, axes = self._create_subplots(self.tile_rows, self.facet_col_wrap) + + # Loop through each group and generate the corresponding subplot. + for i, (group_val, group_df) in enumerate(grouped): + ax = axes[i] + # Construct the title for this subplot. + title = f"{facet_column}: {group_val}" + + # Get a line renderer instance and generate the plot for the current group, + # passing the current axis (canvas) and title directly. + linePlot = self.get_line_renderer(data=group_df, config=self._config, canvas=ax, title=title) + linePlot.generate(tooltips, custom_hover_data) + + # Use the abstracted function to modify the y-axis range. + self._modify_y_range((0, group_df[self.y].max()), (0, 0.1)) + + # Add annotations for the current group if available. + if self.annotation_data is not None and facet_column in self.annotation_data.columns: + group_annotations = self.annotation_data[self.annotation_data[facet_column] == group_val] + self._add_peak_boundaries(group_annotations) + + # Remove any extra axes if the grid size is larger than the number of groups. + self._delete_extra_axes(axes, start_index=i + 1) + + # fig.tight_layout() + self.canvas = fig + else: + # For the non-tiled case, create the plot for the entire dataset. + linePlot = self.get_line_renderer(data=self.data, config=self._config) + # Here, spectrumPlot.generate returns an Axes, not a Figure. + self.canvas = linePlot.generate(tooltips, custom_hover_data) + self._modify_y_range((0, self.data[self.y].max()), (0, 0.1)) + + if self._interactive: + self.manual_boundary_renderer = self._add_bounding_vertical_drawer() + if self.annotation_data is not None: + self._add_peak_boundaries(self.annotation_data) def _add_peak_boundaries(self, annotation_data): """ Prepare data for adding peak boundaries to the plot. This is not a complete method should be overridden by subclasses. - Args: annotation_data (DataFrame): The feature data containing the peak boundaries. - Returns: None """ @@ -617,7 +674,7 @@ def _add_peak_boundaries(self, annotation_data): def compute_apex_intensity(self, annotation_data): """ - Compute the apex intensity of the peak group based on the peak boundaries + Compute the apex intensity of the peak group based on the peak boundaries. """ for idx, feature in annotation_data.iterrows(): annotation_data.loc[idx, "apexIntensity"] = self.data.loc[ @@ -658,35 +715,48 @@ def _kind(self): @property def known_columns(self) -> List[str]: """ - List of known columns in the data, if there are duplicates outside of these columns they will be grouped in aggregation if specified + List of known columns in the data. Any duplicates outside these columns + will be grouped in aggregation if specified. """ known_columns = super().known_columns known_columns.extend([self.peak_color] if self.peak_color is not None else []) - known_columns.extend( - [self.ion_annotation] if self.ion_annotation is not None else [] - ) - known_columns.extend( - [self.sequence_annotation] if self.sequence_annotation is not None else [] - ) - known_columns.extend( - [self.custom_annotation] if self.custom_annotation is not None else [] - ) - known_columns.extend( - [self.annotation_color] if self.annotation_color is not None else [] - ) + known_columns.extend([self.ion_annotation] if self.ion_annotation is not None else []) + known_columns.extend([self.sequence_annotation] if self.sequence_annotation is not None else []) + known_columns.extend([self.custom_annotation] if self.custom_annotation is not None else []) + known_columns.extend([self.annotation_color] if self.annotation_color is not None else []) return known_columns @property def _configClass(self): return SpectrumConfig - def __init__( - self, - data, - **kwargs, - ) -> None: + def __init__(self, data, **kwargs) -> None: super().__init__(data, **kwargs) + + # (Other validations like _check_and_aggregate_duplicates, sorting, etc.) + self._check_and_aggregate_duplicates() + + # Sort data by x (and by self.by if provided) + if self.by is not None: + self.data.sort_values(by=[self.by, self.x], inplace=True) + else: + self.data.sort_values(by=self.x, inplace=True) + + # Convert to relative intensity if required. + if self.relative_intensity: + self.data[self.y] = self.data[self.y] / self.data[self.y].max() * 100 + + # Validate tiling configuration in the constructor + if hasattr(self._config, "facet_column"): + facet_column = self._config.facet_column + if facet_column not in self.data.columns: + warnings.warn( + f"facet_column column '{facet_column}' not found in data. Plot will be generated without tiling." + ) + self._config.facet_column = None + # (Other configuration validations can be added here as needed.) + # Proceed to generate the plot self.plot() def load_config(self, **kwargs): @@ -699,7 +769,6 @@ def load_config(self, **kwargs): def _check_and_aggregate_duplicates(self): super()._check_and_aggregate_duplicates() - if self.reference_spectrum is not None: if self.reference_spectrum[self.known_columns].duplicated().any(): if self.aggregate_duplicates: @@ -717,7 +786,8 @@ def _check_and_aggregate_duplicates(self): @property def _peak_bins(self): """ - Get a list of intervals to use in bins. Here bins are not evenly spaced. Currently this only occurs in mz-tol setting + Get a list of intervals to use in bins. + Currently only used for the mz-tol setting. """ if self.bin_method == "mz-tol-bin" and self.bin_peaks == "auto": return mz_tolerance_binning(self.data, self.x, self.mz_tol) @@ -728,7 +798,6 @@ def _peak_bins(self): def _computed_num_bins(self): """ Compute the number of bins based on the number of peaks in the data. - Returns: int: The number of bins. """ @@ -741,173 +810,155 @@ def _computed_num_bins(self): return None elif self.bin_method == "none": return self.num_x_bins - else: # throw error if bin_method is not recognized + else: raise ValueError(f"bin_method {self.bin_method} not recognized") else: return self.num_x_bins + def plot(self): """Standard spectrum plot with m/z on x-axis, intensity on y-axis and optional mirror spectrum.""" + + facet_column = self._config.facet_column + if facet_column: + # Group data by facet_column column and assign tiling properties + grouped = self.data.groupby(facet_column) + num_groups = len(grouped) + self.facet_col_wrap = self._config.facet_col_wrap + self.tile_rows = ceil(num_groups / self.facet_col_wrap) + + # Create a figure with a grid of subplots using the backend-specific helper + fig, axes = self._create_subplots(self.tile_rows, self.facet_col_wrap) + + for i, (group_val, group_df) in enumerate(grouped): + # Prepare group-specific spectrum and reference data + group_spectrum = self._prepare_data(group_df) + if self.reference_spectrum is not None and facet_column in self.reference_spectrum.columns: + group_reference = self._prepare_data( + self.reference_spectrum[self.reference_spectrum[facet_column] == group_val] + ) + else: + group_reference = None + + # Prepare tooltips for this group + entries = {"m/z": self.x, "intensity": self.y} + for optional in ("native_id", self.ion_annotation, self.sequence_annotation): + if optional in group_df.columns: + entries[optional.replace("_", " ")] = optional + tooltips, custom_hover_data = self._create_tooltips(entries=entries, index=False) + + # Determine grouping for color generation + if self.peak_color is not None and self.peak_color in group_df.columns: + self.by = self.peak_color + elif self.ion_annotation is not None and self.ion_annotation in group_df.columns: + self.by = self.ion_annotation + + # Get annotations and convert data for line plots + ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations(group_spectrum, self.x, self.y) + group_spectrum = self.convert_for_line_plots(group_spectrum, self.x, self.y) + self.color = self._get_colors(group_spectrum, kind="peak") + + # Pass title directly into the renderer for backend abstraction + title = f"{facet_column}: {group_val}" + spectrumPlot = self.get_line_renderer( + data=group_spectrum, by=self.by, color=self.color, config=self._config, title=title + ) + spectrumPlot.canvas = axes[i] + spectrumPlot.generate(tooltips, custom_hover_data) + spectrumPlot._add_annotations(axes[i], ann_texts, ann_xs, ann_ys, ann_colors) + + # Handle mirror spectrum if applicable + if self.mirror_spectrum and group_reference is not None: + group_reference[self.y] = group_reference[self.y] * -1 + color_mirror = self._get_colors(group_reference, kind="peak") + group_reference = self.convert_for_line_plots(group_reference, self.x, self.y) + _, reference_custom_hover_data = self.get_spectrum_tooltip_data(group_reference, self.x, self.y) + mirrorSpectrumPlot = self.get_line_renderer( + data=group_reference, color=color_mirror, config=self._config + ) + mirrorSpectrumPlot.canvas = axes[i] + mirrorSpectrumPlot.generate(None, None) + ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations(group_reference, self.x, self.y) + mirrorSpectrumPlot._add_annotations(axes[i], ann_texts, ann_xs, ann_ys, ann_colors) + + # Delete extra axes if present and finalize layout + self._delete_extra_axes(axes, start_index=i + 1) + # fig.tight_layout() + self.canvas = fig - # Prepare data - spectrum = self._prepare_data(self.data) - if self.reference_spectrum is not None: - reference_spectrum = self._prepare_data(self.reference_spectrum) else: - reference_spectrum = None - - entries = {"m/z": self.x, "intensity": self.y} - for optional in ( - "native_id", - self.ion_annotation, - self.sequence_annotation, - ): - if optional in self.data.columns: - entries[optional.replace("_", " ")] = optional - - tooltips, custom_hover_data = self._create_tooltips( - entries=entries, index=False - ) - - # color generation is more complex for spectrum plots, so it has its own methods - - # Peak colors are determined by peak_color column (highest priorty) or ion_annotation column (second priority) or "by" column (lowest priority) - if self.peak_color is not None and self.peak_color in self.data.columns: - self.by = self.peak_color - elif ( - self.ion_annotation is not None and self.ion_annotation in self.data.columns - ): - self.by = self.ion_annotation + # Single-plot behavior when tiling is not requested + spectrum = self._prepare_data(self.data) + reference_spectrum = self._prepare_data(self.reference_spectrum) if self.reference_spectrum is not None else None + entries = {"m/z": self.x, "intensity": self.y} + for optional in ("native_id", self.ion_annotation, self.sequence_annotation): + if optional in self.data.columns: + entries[optional.replace("_", " ")] = optional + tooltips, custom_hover_data = self._create_tooltips(entries=entries, index=False) + + if self.peak_color is not None and self.peak_color in self.data.columns: + self.by = self.peak_color + elif self.ion_annotation is not None and self.ion_annotation in self.data.columns: + self.by = self.ion_annotation + + ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations(spectrum, self.x, self.y) + spectrum = self.convert_for_line_plots(spectrum, self.x, self.y) + self.color = self._get_colors(spectrum, kind="peak") + spectrumPlot = self.get_line_renderer(data=spectrum, by=self.by, color=self.color, config=self._config) + self.canvas = spectrumPlot.generate(tooltips, custom_hover_data) + spectrumPlot._add_annotations(self.canvas, ann_texts, ann_xs, ann_ys, ann_colors) + + if self.mirror_spectrum and self.reference_spectrum is not None: + reference_spectrum[self.y] = reference_spectrum[self.y] * -1 + color_mirror = self._get_colors(reference_spectrum, kind="peak") + reference_spectrum = self.convert_for_line_plots(reference_spectrum, self.x, self.y) + _, reference_custom_hover_data = self.get_spectrum_tooltip_data(reference_spectrum, self.x, self.y) + mirrorSpectrumPlot = self.get_line_renderer(data=reference_spectrum, color=color_mirror, config=self._config) + mirrorSpectrumPlot.generate(None, None) + ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations(reference_spectrum, self.x, self.y) + mirrorSpectrumPlot._add_annotations(self.canvas, ann_texts, ann_xs, ann_ys, ann_colors) + + self.plot_x_axis_line(self.canvas, line_width=2) + min_values = [spectrum[self.x].min()] + max_values = [spectrum[self.x].max()] + if reference_spectrum is not None: + min_values.append(reference_spectrum[self.x].min()) + max_values.append(reference_spectrum[self.x].max()) + self._modify_x_range((min(min_values), max(max_values)), padding=(0.20, 0.20)) + + # Use the helper to compute y-range and padding, then apply + y_range, y_padding = self._compute_y_range_and_padding(spectrum, reference_spectrum) + self._modify_y_range(y_range, padding=y_padding) - # Annotations for spectrum - ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations( - spectrum, self.x, self.y - ) - - # Convert to line plot format - spectrum = self.convert_for_line_plots(spectrum, self.x, self.y) - - self.color = self._get_colors(spectrum, kind="peak") - spectrumPlot = self.get_line_renderer( - data=spectrum, by=self.by, color=self.color, config=self._config - ) - self.canvas = spectrumPlot.generate(tooltips, custom_hover_data) - spectrumPlot._add_annotations( - self.canvas, ann_texts, ann_xs, ann_ys, ann_colors - ) - - # Mirror spectrum - if self.mirror_spectrum and self.reference_spectrum is not None: - ## create a mirror spectrum - # Set intensity to negative values - reference_spectrum[self.y] = reference_spectrum[self.y] * -1 - - color_mirror = self._get_colors(reference_spectrum, kind="peak") - reference_spectrum = self.convert_for_line_plots( - reference_spectrum, self.x, self.y - ) - - _, reference_custom_hover_data = self.get_spectrum_tooltip_data( - reference_spectrum, self.x, self.y - ) - mirrorSpectrumPlot = self.get_line_renderer( - data=reference_spectrum, color=color_mirror, config=self._config - ) - - mirrorSpectrumPlot.generate(None, None) - - # Annotations for reference spectrum - ann_texts, ann_xs, ann_ys, ann_colors = self._get_annotations( - reference_spectrum, self.x, self.y - ) - mirrorSpectrumPlot._add_annotations( - self.canvas, ann_texts, ann_xs, ann_ys, ann_colors - ) - - # Plot horizontal line to hide connection between peaks - self.plot_x_axis_line(self.canvas, line_width=2) - - # Adjust x axis padding (Plotly cuts outermost peaks) - min_values = [spectrum[self.x].min()] - max_values = [spectrum[self.x].max()] - if reference_spectrum is not None: - min_values.append(reference_spectrum[self.x].min()) - max_values.append(reference_spectrum[self.x].max()) - self._modify_x_range((min(min_values), max(max_values)), padding=(0.20, 0.20)) - # Adjust y axis padding (annotations should stay inside plot) - max_value = spectrum[self.y].max() - min_value = 0 - min_padding = 0 - max_padding = 0.15 - if reference_spectrum is not None and self.mirror_spectrum: - min_value = reference_spectrum[self.y].min() - min_padding = -0.2 - max_padding = 0.4 - - self._modify_y_range((min_value, max_value), padding=(min_padding, max_padding)) def _bin_peaks(self, df: DataFrame) -> DataFrame: """ Bin peaks based on x-axis values. - - Args: - data (DataFrame): The data to bin. - x (str): The column name for the x-axis data. - y (str): The column name for the y-axis data. - - Returns: - DataFrame: The binned data. """ - - # if _peak_bins is set that they are used as the bins over the num_bins parameter if self._peak_bins is not None: - # Function to assign each value to a bin def assign_bin(value): for low, high in self._peak_bins: if low <= value <= high: return f"{low:.4f}-{high:.4f}" - return nan # For values that don't fall into any bin - - # Apply the binning + return nan df[self.x] = df[self.x].apply(assign_bin) - else: # use computed number of bins, bins evenly spaced + else: bins = np.histogram_bin_edges(df[self.x], self._computed_num_bins) - def assign_bin(value): for low_idx in range(len(bins) - 1): if bins[low_idx] <= value <= bins[low_idx + 1]: return f"{bins[low_idx]:.4f}-{bins[low_idx + 1]:.4f}" - return nan # For values that don't fall into any bin - - # Apply the binning + return nan df[self.x] = df[self.x].apply(assign_bin) - - # TODO I am not sure why "cut" method seems to be failing with plotly so created a workaround for now - # error is that object is not JSON serializable because of Interval type - # df[self.x] = cut(df[self.x], bins=self._computed_num_bins) - - # TODO: Find a better way to retain other columns + # Retain other columns cols = [self.x] - if self.by is not None: - cols.append(self.by) - if self.peak_color is not None: - cols.append(self.peak_color) - if self.ion_annotation is not None: - cols.append(self.ion_annotation) - if self.sequence_annotation is not None: - cols.append(self.sequence_annotation) - if self.custom_annotation is not None: - cols.append(self.custom_annotation) - if self.annotation_color is not None: - cols.append(self.annotation_color) - - # Group by x bins and calculate the sum intensity within each bin - df = ( - df.groupby(cols, observed=True) - .agg({self.y: self.aggregation_method}) - .reset_index() - ) - + if self.by is not None: cols.append(self.by) + if self.peak_color is not None: cols.append(self.peak_color) + if self.ion_annotation is not None: cols.append(self.ion_annotation) + if self.sequence_annotation is not None: cols.append(self.sequence_annotation) + if self.custom_annotation is not None: cols.append(self.custom_annotation) + if self.annotation_color is not None: cols.append(self.annotation_color) + df = df.groupby(cols, observed=True).agg({self.y: self.aggregation_method}).reset_index() def convert_to_numeric(value): if isinstance(value, Interval): return value.mid @@ -915,94 +966,52 @@ def convert_to_numeric(value): return mean([float(i) for i in value.split("-")]) else: return value - df[self.x] = df[self.x].apply(convert_to_numeric).astype(float) - - df = df.fillna(0) - return df + return df.fillna(0) def _prepare_data(self, df, label_suffix=""): """ - Prepare data for plotting based on configuration (relative intensity, bin peaks) - - Args: - df (DataFrame): The data to prepare. - label_suffix (str, optional): The suffix to add to the label. Defaults to "", Only for plotly backend - - Returns: - DataFrame: The prepared data. + Prepare data for plotting based on configuration (e.g. relative intensity, bin peaks) """ - - # Convert to relative intensity if required if self.relative_intensity or self.mirror_spectrum: df[self.y] = df[self.y] / df[self.y].max() * 100 - - # Bin peaks if required - if self.bin_peaks == True or (self.bin_peaks == "auto"): + if self.bin_peaks == True or self.bin_peaks == "auto": df = self._bin_peaks(df) - return df - def _get_colors( - self, data: DataFrame, kind: Literal["peak", "annotation"] | None = None - ): + def _get_colors(self, data: DataFrame, kind: Literal["peak", "annotation"] | None = None): """Get color generators for peaks or annotations based on config.""" if kind == "annotation": - # Custom annotating colors with top priority - if ( - self.annotation_color is not None - and self.annotation_color in data.columns - ): + if self.annotation_color is not None and self.annotation_color in data.columns: return ColorGenerator(data[self.annotation_color]) - # Ion annotation colors - elif ( - self.ion_annotation is not None and self.ion_annotation in data.columns - ): - # Generate colors based on ion annotations - return ColorGenerator( - self._get_ion_color_annotation(data[self.ion_annotation]) - ) - # Grouped by colors (from default color map) + elif self.ion_annotation is not None and self.ion_annotation in data.columns: + return ColorGenerator(self._get_ion_color_annotation(data[self.ion_annotation])) elif self.by is not None: - # Get unique values to determine number of distinct colors uniques = data[self.by].unique() color_gen = ColorGenerator() - # Generate a list of colors equal to the number of unique values colors = [next(color_gen) for _ in range(len(uniques))] - # Create a mapping of unique values to their corresponding colors color_map = {uniques[i]: colors[i] for i in range(len(colors))} - # Apply the color mapping to the specified column in the data and turn it into a ColorGenerator return ColorGenerator(data[self.by].apply(lambda x: color_map[x])) - # Fallback ColorGenerator with one color return ColorGenerator(n=1) - else: # Peaks + else: if self.by: uniques = data[self.by].unique().tolist() - # Custom colors with top priority if self.peak_color is not None: return ColorGenerator(uniques) - # Colors based on ion annotation for peaks and annotation text if self.ion_annotation is not None and self.peak_color is None: return ColorGenerator(self._get_ion_color_annotation(uniques)) - # Else just use default colors return ColorGenerator() def _get_annotations(self, data: DataFrame, x: str, y: str): - """Create annotations for each peak. Return lists of texts, x and y locations and colors.""" - + """Create annotations for each peak.""" data["color"] = ["black" for _ in range(len(data))] - ann_texts = [] top_n = self.annotate_top_n_peaks if top_n == "all": top_n = len(data) elif top_n is None: top_n = 0 - # sort values for top intensity peaks on top (ascending for reference spectra with negative values) - data = data.sort_values( - y, ascending=True if data[y].min() < 0 else False - ).reset_index() - + data = data.sort_values(y, ascending=True if data[y].min() < 0 else False).reset_index() for i, row in data.iterrows(): texts = [] if i < top_n: @@ -1010,10 +1019,7 @@ def _get_annotations(self, data: DataFrame, x: str, y: str): texts.append(str(round(row[x], 4))) if self.ion_annotation and self.ion_annotation in data.columns: texts.append(str(row[self.ion_annotation])) - if ( - self.sequence_annotation - and self.sequence_annotation in data.columns - ): + if self.sequence_annotation and self.sequence_annotation in data.columns: texts.append(str(row[self.sequence_annotation])) if self.custom_annotation and self.custom_annotation in data.columns: texts.append(str(row[self.custom_annotation])) @@ -1021,26 +1027,20 @@ def _get_annotations(self, data: DataFrame, x: str, y: str): return ann_texts, data[x].tolist(), data[y].tolist(), data["color"].tolist() def _get_ion_color_annotation(self, ion_annotations: str) -> str: - """Retrieve the color associated with a specific ion annotation from a predefined colormap.""" + """Retrieve the color associated with an ion annotation.""" colormap = { "a": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.PURPLE], "b": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.BLUE], - "c": ColorGenerator.color_blind_friendly_map[ - ColorGenerator.Colors.LIGHTBLUE - ], + "c": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.LIGHTBLUE], "x": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.YELLOW], "y": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.RED], "z": ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.ORANGE], } - def get_ion_color(ion): if isinstance(ion, str): for key in colormap.keys(): - # Exact matches if ion == key: return colormap[key] - # Fragment ions via regex - ## Check if ion format is a1+, a1-, etc. or if it's a1^1, a1^2, etc. if re.search(r"^[abcxyz]{1}[0-9]*[+-]$", ion): x = re.search(r"^[abcxyz]{1}[0-9]*[+-]$", ion) elif re.search(r"^[abcxyz]{1}[0-9]*\^[0-9]*$", ion): @@ -1049,10 +1049,7 @@ def get_ion_color(ion): x = None if x: return colormap[ion[0]] - return ColorGenerator.color_blind_friendly_map[ - ColorGenerator.Colors.DARKGRAY - ] - + return ColorGenerator.color_blind_friendly_map[ColorGenerator.Colors.DARKGRAY] return [get_ion_color(ion) for ion in ion_annotations] def to_line(self, x, y): @@ -1074,31 +1071,19 @@ def convert_for_line_plots(self, data: DataFrame, x: str, y: str) -> DataFrame: def get_spectrum_tooltip_data(self, spectrum: DataFrame, x: str, y: str): """Get tooltip data for a spectrum plot.""" - - # Need to group data in correct order for tooltips if self.by is not None: grouped = spectrum.groupby(self.by, sort=False) self.data = concat([group for _, group in grouped], ignore_index=True) - - # Hover tooltips with m/z, intensity and optional information entries = {"m/z": x, "intensity": y} - for optional in ( - "native_id", - self.ion_annotation, - self.sequence_annotation, - ): + for optional in ("native_id", self.ion_annotation, self.sequence_annotation): if optional in self.data.columns: entries[optional.replace("_", " ")] = optional - # Create tooltips and custom hover data with backend specific formatting - tooltips, custom_hover_data = self._create_tooltips( - entries=entries, index=False - ) - # Repeat data each time (since each peak is represented by three points in line plot) + tooltips, custom_hover_data = self._create_tooltips(entries=entries, index=False) custom_hover_data = repeat(custom_hover_data, 3, axis=0) - return tooltips, custom_hover_data + class PeakMapPlot(BaseMSPlot, ABC): # need to inherit from ChromatogramPlot and SpectrumPlot for get_line_renderer and get_vline_renderer methods respectively @property diff --git a/pyopenms_viz/_matplotlib/core.py b/pyopenms_viz/_matplotlib/core.py index 11e7c737..6c963d9b 100644 --- a/pyopenms_viz/_matplotlib/core.py +++ b/pyopenms_viz/_matplotlib/core.py @@ -4,6 +4,8 @@ from typing import Tuple import re from numpy import nan +import matplotlib +matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.lines import Line2D from matplotlib.patches import Rectangle @@ -35,6 +37,49 @@ class MATPLOTLIBPlot(BasePlot, ABC): data (DataFrame): The input data frame. """ + def _compute_y_range_and_padding(self, spectrum, reference_spectrum=None): + + max_value = spectrum[self.y].max() + if reference_spectrum is not None and self.mirror_spectrum: + min_value = reference_spectrum[self.y].min() + padding = (-0.2, 0.4) + else: + min_value = 0 + padding = (0, 0.15) # Ensure this is a tuple + return (min_value, max_value), padding + + + + def _create_subplots(self, rows: int, columns: int): + """ + Create a grid of subplots using matplotlib. + + Args: + rows (int): Number of subplot rows. + columns (int): Number of subplot columns. + figsize (Tuple[float, float]): Size of the figure in inches. + + Returns: + Tuple[Figure, List[Axes]]: The matplotlib Figure and a flattened list of Axes. + """ + fig, axes = plt.subplots(rows, columns, squeeze=False) + axes = axes.flatten() + return fig, axes + + + def _delete_extra_axes(self, axes, start_index: int): + """ + Remove any extra axes from a grid if the number of groups is smaller than + the number of subplot axes. + + Args: + axes (list): List of axes objects. + start_index (int): The index in the axes list from which to start deleting. + """ + for j in range(start_index, len(axes)): + self.fig.delaxes(axes[j]) + + # In matplotlib the canvas is referred to as a Axes, the figure object is the encompassing object @property def ax(self): diff --git a/test/conftest.py b/test/conftest.py index 92123f0e..0ce02391 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,5 +1,9 @@ import pytest import pandas as pd +import os +import sys +# sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + from pathlib import Path from pyopenms_viz.testing import ( MatplotlibSnapshotExtension, diff --git a/test/test_spectrum.py b/test/test_spectrum.py index 27d398bf..b49f9344 100644 --- a/test/test_spectrum.py +++ b/test/test_spectrum.py @@ -7,6 +7,7 @@ import pandas as pd + @pytest.mark.parametrize( "kwargs", [