Skip to content

Commit 3a15eee

Browse files
Update diagnosis tests (#382)
* update diagnosis tests * update diagnosis tests * remove unnecessary condition
1 parent 96ff0ef commit 3a15eee

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

validmind/tests/model_validation/sklearn/OverfitDiagnosis.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,16 @@ def OverfitDiagnosis(
220220
- May not capture more subtle forms of overfitting that do not exceed the threshold.
221221
- Assumes that the binning of features adequately represents the data segments.
222222
"""
223+
224+
numeric_and_categorical_feature_columns = (
225+
datasets[0].feature_columns_numeric + datasets[0].feature_columns_categorical
226+
)
227+
228+
if not numeric_and_categorical_feature_columns:
229+
raise ValueError(
230+
"No valid numeric or categorical columns found in features_columns"
231+
)
232+
223233
is_classification = bool(datasets[0].probability_column(model))
224234

225235
if not metric:
@@ -246,7 +256,7 @@ def OverfitDiagnosis(
246256
figures = []
247257
results_headers = ["slice", "shape", "feature", metric]
248258

249-
for feature_column in datasets[0].feature_columns:
259+
for feature_column in numeric_and_categorical_feature_columns:
250260
bins = 10
251261
if feature_column in datasets[0].feature_columns_categorical:
252262
bins = len(train_df[feature_column].unique())

validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,19 @@ def WeakspotsDiagnosis(
211211
improvement.
212212
"""
213213
feature_columns = features_columns or datasets[0].feature_columns
214+
numeric_and_categorical_columns = (
215+
datasets[0].feature_columns_numeric + datasets[0].feature_columns_categorical
216+
)
217+
218+
feature_columns = [
219+
col for col in feature_columns if col in numeric_and_categorical_columns
220+
]
221+
222+
if not feature_columns:
223+
raise ValueError(
224+
"No valid numeric or categorical columns found in features_columns"
225+
)
226+
214227
if not all(col in datasets[0].feature_columns for col in feature_columns):
215228
raise ValueError(
216229
"Column(s) provided in features_columns do not exist in the dataset"

0 commit comments

Comments
 (0)