Skip to content

Commit c22d84f

Browse files
committed
add error message if user passes decision tree
1 parent 8169f41 commit c22d84f

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

boruta/boruta_py.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def _fit(self, X, y):
326326

327327
# set n_estimators
328328
if self.n_estimators != 'auto':
329-
self.estimator.set_params(n_estimators=self.n_estimators)
329+
self._set_n_estimators(self.n_estimators)
330330

331331
# main feature selection loop
332332
while np.any(dec_reg == 0) and _iter < self.max_iter:
@@ -335,7 +335,7 @@ def _fit(self, X, y):
335335
# number of features that aren't rejected
336336
not_rejected = np.where(dec_reg >= 0)[0].shape[0]
337337
n_tree = self._get_tree_num(not_rejected)
338-
self.estimator.set_params(n_estimators=n_tree)
338+
self._set_n_estimators(n_estimators=n_tree)
339339

340340
# make sure we start with a new tree in each iteration
341341
if self._is_lightgbm:
@@ -452,6 +452,17 @@ def _transform(self, X, weak=False, return_df=False):
452452
X = X[:, indices]
453453
return X
454454

455+
def _set_n_estimators(self, n_estimators):
456+
try:
457+
self.estimator.set_params(n_estimators=n_estimators)
458+
except ValueError:
459+
raise ValueError(
460+
f"The estimator {self.estimator} does not take the parameter "
461+
"n_estimators. Use Random Forests or gradient boosting machines "
462+
"instead."
463+
)
464+
return self
465+
455466
def _get_tree_num(self, n_feat):
456467
depth = None
457468
try:

boruta/test/test_boruta.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pandas as pd
33
import pytest
44
from sklearn.ensemble import RandomForestClassifier
5+
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier
56

67
from boruta import BorutaPy
78

@@ -26,8 +27,8 @@ def Xy():
2627

2728
# 5 relevant features
2829
X[:, 0] = z
29-
X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000))
30-
+ np.random.normal(0, 0.1, 1000))
30+
X[:, 1] = (y * np.abs(np.random.normal(0, 1, 1000)) +
31+
np.random.normal(0, 0.1, 1000))
3132
X[:, 2] = y + np.random.normal(0, 1, 1000)
3233
X[:, 3] = y**2 + np.random.normal(0, 1, 1000)
3334
X[:, 4] = np.sqrt(y) + np.random.binomial(2, 0.1, 1000)
@@ -65,3 +66,18 @@ def test_dataframe_is_returned(Xy):
6566
bt = BorutaPy(rfc)
6667
bt.fit(X_df, y_df)
6768
assert isinstance(bt.transform(X_df, return_df=True), pd.DataFrame)
69+
70+
71+
@pytest.mark.parametrize("tree", [ExtraTreeClassifier(), DecisionTreeClassifier()])
72+
def test_boruta_with_decision_trees(tree, Xy):
73+
msg = (
74+
f"The estimator {tree} does not take the parameter "
75+
"n_estimators. Use Random Forests or gradient boosting machines "
76+
"instead."
77+
)
78+
X, y = Xy
79+
bt = BorutaPy(tree)
80+
with pytest.raises(ValueError) as record:
81+
bt.fit(X, y)
82+
83+
assert str(record.value) == msg

0 commit comments

Comments
 (0)