From d091a1b1c247cc221ffaa63b9e8d011d4623dabe Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 4 Nov 2025 16:24:19 +0000 Subject: [PATCH 1/2] use passed dataframe --- src/spikeinterface/curation/model_based_curation.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py index 93ad03734c..98f8b4073f 100644 --- a/src/spikeinterface/curation/model_based_curation.py +++ b/src/spikeinterface/curation/model_based_curation.py @@ -188,14 +188,9 @@ def _check_params_for_classification(self, enforce_metric_params=False, model_in else: warnings.warn(warning_message) - def _export_to_phy(self, classified_units): + def _export_to_phy(self, classified_df): """Export the classified units to Phy as cluster_prediction.tsv file""" - import pandas as pd - - # Create a new DataFrame with unit_id, prediction, and probability columns from dict {unit_id: (prediction, probability)} - classified_df = pd.DataFrame.from_dict(classified_units, orient="index", columns=["prediction", "probability"]) - # Export to Phy format try: sorting_path = self.sorting_analyzer.sorting.get_annotation("phy_folder") From dca8e6a6b67cc6ba934901c1c723433298ee7c1e Mon Sep 17 00:00:00 2001 From: chrishalcrow Date: Tue, 4 Nov 2025 16:40:47 +0000 Subject: [PATCH 2/2] fix tests to match new args --- .../curation/tests/test_model_based_curation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/tests/test_model_based_curation.py b/src/spikeinterface/curation/tests/test_model_based_curation.py index 3683b417df..e2452c1d54 100644 --- a/src/spikeinterface/curation/tests/test_model_based_curation.py +++ b/src/spikeinterface/curation/tests/test_model_based_curation.py @@ -99,7 +99,9 @@ def test_model_based_classification_get_metrics_for_classification( def test_model_based_classification_export_to_phy(sorting_analyzer_for_curation, model): # Test the _export_to_phy() method of ModelBasedClassification model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model[0]) - classified_units = {0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)} + import pandas as pd + + classified_units = pd.DataFrame.from_dict({0: (1, 0.5), 1: (0, 0.5), 2: (1, 0.5), 3: (0, 0.5), 4: (1, 0.5)}) # Function should fail here with pytest.raises(ValueError): model_based_classification._export_to_phy(classified_units)