Skip to content

Commit ed6e41c

Browse files
authored
Fixed issues with binning boolean columns and with constant (single-valued) datetime columns
1 parent fa6dfc9 commit ed6e41c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mostlyai/qa/accuracy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,7 @@ def _clip(col, bins):
10461046

10471047
if col.nunique() == 1:
10481048
# ensure 2 breaks for single-valued columns
1049-
val = col.iloc[0]
1049+
val = col.dropna().iloc[0]
10501050
upper_limit = [val + np.timedelta64(1, "D")] if not pd.isna(val) else []
10511051
breaks = [val] + upper_limit
10521052
else:
@@ -1115,6 +1115,7 @@ def bin_non_categorical(
11151115

11161116

11171117
def bin_categorical(col: pd.Series, bins: int | list[str]) -> tuple[pd.Categorical, list[str]]:
1118+
col = col.astype("string[pyarrow]")
11181119
col = col.fillna(NA_BIN)
11191120
col = col.replace("", EMPTY_BIN)
11201121
# determine top values, if not provided

tests/unit/test_accuracy.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
trim_labels,
3838
calculate_correlations,
3939
plot_store_correlation_matrices,
40+
bin_categorical,
4041
)
4142
from mostlyai.qa.sampling import pull_data_for_accuracy, sample_two_consecutive_rows
4243
from mostlyai.qa.common import (
@@ -496,6 +497,14 @@ def test_num_col_nans_only(self):
496497
df_counts = df["nans"].value_counts().to_dict()
497498
assert df_counts["(n/a)"] == 10
498499

500+
def test_bin_categorical(self):
501+
x = pd.Series(["a", "b"] * 50 + ["x"])
502+
col, _ = bin_categorical(x, 5)
503+
assert len(col) == 101
504+
x = pd.Series([True, False] * 50 + [np.nan] * 100, dtype="object")
505+
col, _ = bin_categorical(x, 5)
506+
assert len(col) == 200
507+
499508
def test_bin_numeric(self):
500509
# test several edge cases
501510
cases = [
@@ -534,6 +543,10 @@ def test_bin_datetime(self):
534543
),
535544
["⪰ 2023-01-30 13:00:00.333000"] * 20,
536545
), # two values
546+
(
547+
pd.Series([pd.NaT, "2024-11-20"], dtype="datetime64[ns]"),
548+
["(n/a)", "⪰ 2024-Nov-20"],
549+
), # single value with leading N/A
537550
]
538551

539552
for col, expected in cases:

0 commit comments

Comments
 (0)