diff --git a/imblearn/base.py b/imblearn/base.py
index 0b2d94e84..ba7ce3ec7 100644
--- a/imblearn/base.py
+++ b/imblearn/base.py
@@ -75,8 +75,8 @@ def fit(self, X, y):
             Return the instance itself.
         """
         X, y, _ = self._check_X_y(X, y)
-        self.sampling_strategy_ = check_sampling_strategy(
-            self.sampling_strategy, y, self._sampling_type
+        self.sampling_strategy_, self._original_class_counts = check_sampling_strategy(
+            self.sampling_strategy, y, self._sampling_type, return_original_counts=True
         )
         return self
 
@@ -105,8 +105,8 @@ def fit_resample(self, X, y):
         arrays_transformer = ArraysTransformer(X, y)
         X, y, binarize_y = self._check_X_y(X, y)
 
-        self.sampling_strategy_ = check_sampling_strategy(
-            self.sampling_strategy, y, self._sampling_type
+        self.sampling_strategy_, self._original_class_counts = check_sampling_strategy(
+            self.sampling_strategy, y, self._sampling_type, return_original_counts=True
         )
 
         output = self._fit_resample(X, y)
@@ -363,8 +363,8 @@ def fit(self, X, y):
         check_classification_targets(y)
         X, y, _ = self._check_X_y(X, y, accept_sparse=self.accept_sparse)
-        self.sampling_strategy_ = check_sampling_strategy(
-            self.sampling_strategy, y, self._sampling_type
+        self.sampling_strategy_, self._original_class_counts = check_sampling_strategy(
+            self.sampling_strategy, y, self._sampling_type, return_original_counts=True
         )
         return self
 
@@ -396,8 +396,8 @@ def fit_resample(self, X, y):
         check_classification_targets(y)
         X, y, binarize_y = self._check_X_y(X, y, accept_sparse=self.accept_sparse)
-        self.sampling_strategy_ = check_sampling_strategy(
-            self.sampling_strategy, y, self._sampling_type
+        self.sampling_strategy_, self._original_class_counts = check_sampling_strategy(
+            self.sampling_strategy, y, self._sampling_type, return_original_counts=True
        )
 
        output = self._fit_resample(X, y)
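The net effect of the `base.py` change is that every fitted sampler now keeps the pre-resampling class counts on a private `_original_class_counts` attribute next to the public `sampling_strategy_`; the ensemble code added below relies on it. A minimal sketch of the contract, assuming this patch is installed (printed values are illustrative):

```python
from sklearn.datasets import make_classification

from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(
    n_samples=1_000, weights=[0.8, 0.2], flip_y=0, random_state=0
)
sampler = RandomUnderSampler(random_state=0)
sampler.fit_resample(X, y)

# Public attribute: the number of samples targeted per class, as before.
print(sampler.sampling_strategy_)      # e.g. OrderedDict([(0, 200)])
# New private attribute: the class counts observed before resampling.
print(sampler._original_class_counts)  # e.g. {0: 800, 1: 200}
```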
diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py
index acb0c70fa..5089cba10 100644
--- a/imblearn/ensemble/_bagging.py
+++ b/imblearn/ensemble/_bagging.py
@@ -17,6 +17,7 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils.fixes import parse_version
+from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import check_is_fitted
 
 try:
@@ -35,7 +36,11 @@
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ..utils._param_validation import HasMethods, Interval, StrOptions
 from ..utils.fixes import _fit_context
-from ._common import _bagging_parameter_constraints, _estimator_has
+from ._common import (
+    _bagging_parameter_constraints,
+    _estimate_reweighting,
+    _estimator_has,
+)
 
 sklearn_version = parse_version(sklearn.__version__)
@@ -121,6 +126,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
 
         .. versionadded:: 0.8
 
+    recalibrate : bool, default=False
+        Whether to recalibrate the output of `predict_proba` and `predict_log_proba`
+        using the sampling ratio of the different bootstrap samples. Note that the
+        correction only works for binary classification.
+
+        .. versionadded:: 0.13
+
     Attributes
     ----------
     estimator_ : estimator
@@ -264,6 +276,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
             ],
             "replacement": ["boolean"],
             "sampler": [HasMethods(["fit_resample"]), None],
+            "recalibrate": ["boolean"],
         }
     )
     # TODO: remove when minimum supported version of scikit-learn is 1.4
@@ -287,6 +300,7 @@ def __init__(
         random_state=None,
         verbose=0,
         sampler=None,
+        recalibrate=False,
     ):
         super().__init__(
             n_estimators=n_estimators,
@@ -304,6 +318,7 @@ def __init__(
         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
         self.sampler = sampler
+        self.recalibrate = recalibrate
 
     def _validate_y(self, y):
         y_encoded = super()._validate_y(y)
@@ -371,6 +386,15 @@ def fit(self, X, y):
         """
         # overwrite the base class method by disallowing `sample_weight`
         self._validate_params()
+        if self.recalibrate:
+            # compute the type of target only if we need to recalibrate since this is
+            # potentially costly
+            y_type = type_of_target(y)
+            if y_type != "binary":
+                raise ValueError(
+                    "Only possible to recalibrate the probabilities for binary "
+                    f"classification. Got {y_type} instead."
+                )
         return super().fit(X, y)
 
     def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
@@ -388,6 +412,60 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
         # None.
         return super()._fit(X, y, self.max_samples)
 
+    def predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        The predicted class probabilities of an input sample are computed as
+        the mean predicted class probabilities of the base estimators in the
+        ensemble. If base estimators do not implement a ``predict_proba``
+        method, then it resorts to voting and the predicted class probabilities
+        of an input sample represent the proportion of estimators predicting
+        each class.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        proba = super().predict_proba(X)
+
+        if self.recalibrate:
+            weight = _estimate_reweighting([est[0] for est in self.estimators_])
+            proba[:, 1] /= proba[:, 1] + (1 - proba[:, 1]) / weight
+            proba[:, 0] = 1 - proba[:, 1]
+
+        return proba
+
+    def predict_log_proba(self, X):
+        """Predict class log-probabilities for X.
+
+        The predicted class log-probabilities of an input sample are computed
+        as the log of the mean predicted class probabilities of the base
+        estimators in the ensemble.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class log-probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        # To take into account the calibration correction, we use our
+        # implementation of `predict_proba` and then apply the log.
+        return np.log(self.predict_proba(X))
+
     # TODO: remove when minimum supported version of scikit-learn is 1.1
     @available_if(_estimator_has("decision_function"))
     def decision_function(self, X):
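A hedged end-to-end sketch of the new `recalibrate` option on `BalancedBaggingClassifier` (the dataset and the expected ordering of the scores are illustrative; the same comparison is asserted by the new tests further below):

```python
from sklearn.datasets import make_classification
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

from imblearn.ensemble import BalancedBaggingClassifier

X, y = make_classification(n_samples=5_000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

for recalibrate in (False, True):
    clf = BalancedBaggingClassifier(
        n_estimators=20, recalibrate=recalibrate, random_state=0
    ).fit(X_train, y_train)
    # The recalibrated probabilities should score a lower log-loss because the
    # balancing inside each bootstrap inflates P(y=1) relative to the data.
    print(recalibrate, log_loss(y_test, clf.predict_proba(X_test)))
```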
diff --git a/imblearn/ensemble/_common.py b/imblearn/ensemble/_common.py
index 588fa5e2c..1195a39f7 100644
--- a/imblearn/ensemble/_common.py
+++ b/imblearn/ensemble/_common.py
@@ -1,5 +1,7 @@
+import copy
 from numbers import Integral, Real
 
+import numpy as np
 from sklearn.tree._criterion import Criterion
 
 from ..utils._param_validation import (
@@ -28,6 +30,39 @@ def check(self):
 
     return check
 
 
+def _estimate_reweighting(samplers):
+    """Estimate the reweighting factor to calibrate the probabilities.
+
+    The reweighting factor is the ratio of the odds of the positive class
+    before resampling to the odds after resampling, averaged over all
+    samplers.
+
+    Parameters
+    ----------
+    samplers : list of samplers
+        The list of samplers.
+
+    Returns
+    -------
+    weight : float
+        The reweighting factor.
+    """
+    weights = []
+    for sampler in samplers:
+        # Since the samplers are created internally, we know that the target is
+        # encoded with 0 and 1.
+        p_y_1_original = sampler._original_class_counts[1] / sum(
+            sampler._original_class_counts[k] for k in [0, 1]
+        )
+        resampled_counts = copy.copy(sampler._original_class_counts)
+        resampled_counts.update(sampler.sampling_strategy_)
+        p_y_1_resampled = resampled_counts[1] / sum(resampled_counts[k] for k in [0, 1])
+        weights.append(
+            (p_y_1_original / (1 - p_y_1_original))
+            * ((1 - p_y_1_resampled) / p_y_1_resampled)
+        )
+    return np.mean(weights)
+
+
 _bagging_parameter_constraints = {
     "estimator": [HasMethods(["fit", "predict"]), None],
     "n_estimators": [Interval(Integral, 1, None, closed="left")],
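As a sanity check of what `_estimate_reweighting` and the `predict_proba` correction compute together, here is the arithmetic on hypothetical counts (800 negatives and 200 positives under-sampled to 200/200):

```python
# Hypothetical counts for one sampler: 800/200 under-sampled to 200/200.
p1_original = 200 / (800 + 200)   # P(y=1) before resampling: 0.2
p1_resampled = 200 / (200 + 200)  # P(y=1) after resampling: 0.5

# weight = odds(original) / odds(resampled) = (0.2 / 0.8) / (0.5 / 0.5) = 0.25
weight = (p1_original / (1 - p1_original)) * ((1 - p1_resampled) / p1_resampled)

# The correction applied in `predict_proba` maps a resampled probability of
# 0.5 back onto the original prior of 0.2.
p = 0.5
p_corrected = p / (p + (1 - p) / weight)
assert round(p_corrected, 10) == 0.2
```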
diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py
index e3c85741c..97253047c 100644
--- a/imblearn/ensemble/_easy_ensemble.py
+++ b/imblearn/ensemble/_easy_ensemble.py
@@ -17,6 +17,7 @@
 from sklearn.exceptions import NotFittedError
 from sklearn.utils._tags import _safe_tags
 from sklearn.utils.fixes import parse_version
+from sklearn.utils.multiclass import type_of_target
 from sklearn.utils.validation import check_is_fitted
 
 try:
@@ -35,7 +36,11 @@
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
 from ..utils._param_validation import Interval, StrOptions
 from ..utils.fixes import _fit_context
-from ._common import _bagging_parameter_constraints, _estimator_has
+from ._common import (
+    _bagging_parameter_constraints,
+    _estimate_reweighting,
+    _estimator_has,
+)
 
 MAX_INT = np.iinfo(np.int32).max
 sklearn_version = parse_version(sklearn.__version__)
@@ -85,6 +90,13 @@ class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
     verbose : int, default=0
         Controls the verbosity of the building process.
 
+    recalibrate : bool, default=False
+        Whether to recalibrate the output of `predict_proba` and `predict_log_proba`
+        using the sampling ratio of the different bootstrap samples. Note that the
+        correction only works for binary classification.
+
+        .. versionadded:: 0.13
+
     Attributes
     ----------
     estimator_ : estimator
@@ -198,6 +210,7 @@ class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
             callable,
         ],
         "replacement": ["boolean"],
+        "recalibrate": ["boolean"],
     }
 )
 # TODO: remove when minimum supported version of scikit-learn is 1.4
@@ -215,6 +228,7 @@ def __init__(
         n_jobs=None,
         random_state=None,
         verbose=0,
+        recalibrate=False,
     ):
         super().__init__(
             n_estimators=n_estimators,
@@ -231,6 +245,7 @@ def __init__(
         self.estimator = estimator
         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
+        self.recalibrate = recalibrate
 
     def _validate_y(self, y):
         y_encoded = super()._validate_y(y)
@@ -294,6 +309,15 @@ def fit(self, X, y):
         """
         self._validate_params()
         # overwrite the base class method by disallowing `sample_weight`
+        if self.recalibrate:
+            # compute the type of target only if we need to recalibrate since this is
+            # potentially costly
+            y_type = type_of_target(y)
+            if y_type != "binary":
+                raise ValueError(
+                    "Only possible to recalibrate the probabilities for binary "
+                    f"classification. Got {y_type} instead."
+                )
         return super().fit(X, y)
 
     def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
@@ -302,6 +326,60 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
         # None.
         return super()._fit(X, y, self.max_samples)
 
+    def predict_proba(self, X):
+        """Predict class probabilities for X.
+
+        The predicted class probabilities of an input sample are computed as
+        the mean predicted class probabilities of the base estimators in the
+        ensemble. If base estimators do not implement a ``predict_proba``
+        method, then it resorts to voting and the predicted class probabilities
+        of an input sample represent the proportion of estimators predicting
+        each class.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        proba = super().predict_proba(X)
+
+        if self.recalibrate:
+            weight = _estimate_reweighting([est[0] for est in self.estimators_])
+            proba[:, 1] /= proba[:, 1] + (1 - proba[:, 1]) / weight
+            proba[:, 0] = 1 - proba[:, 1]
+
+        return proba
+
+    def predict_log_proba(self, X):
+        """Predict class log-probabilities for X.
+
+        The predicted class log-probabilities of an input sample are computed
+        as the log of the mean predicted class probabilities of the base
+        estimators in the ensemble.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The training input samples. Sparse matrices are accepted only if
+            they are supported by the base estimator.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes)
+            The class log-probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        # To take into account the calibration correction, we use our
+        # implementation of `predict_proba` and then apply the log.
+        return np.log(self.predict_proba(X))
+
     # TODO: remove when minimum supported version of scikit-learn is 1.1
     @available_if(_estimator_has("decision_function"))
     def decision_function(self, X):
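Both bagging-based ensembles fit one pipeline per bootstrap whose first step is the sampler, which is why `predict_proba` passes `[est[0] for est in self.estimators_]` to `_estimate_reweighting`. A sketch of the assumed fitted structure (printed values are illustrative):

```python
from sklearn.datasets import make_classification

from imblearn.ensemble import EasyEnsembleClassifier

X, y = make_classification(n_samples=1_000, weights=[0.7, 0.3], random_state=0)
clf = EasyEnsembleClassifier(n_estimators=3, recalibrate=True).fit(X, y)

# Each fitted member of `estimators_` is a pipeline whose first step is the
# RandomUnderSampler applied to that bootstrap, hence the `est[0]` indexing
# used above when calling `_estimate_reweighting`.
sampler = clf.estimators_[0][0]
print(sampler.sampling_strategy_)      # target counts after resampling
print(sampler._original_class_counts)  # counts before resampling (new, private)
```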
diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py
index 5f8d08e91..3b55b3039 100644
--- a/imblearn/ensemble/_forest.py
+++ b/imblearn/ensemble/_forest.py
@@ -42,7 +42,10 @@
 from ..utils._param_validation import Hidden, Interval, StrOptions
 from ..utils._validation import check_sampling_strategy
 from ..utils.fixes import _fit_context
-from ._common import _random_forest_classifier_parameter_constraints
+from ._common import (
+    _estimate_reweighting,
+    _random_forest_classifier_parameter_constraints,
+)
 
 MAX_INT = np.iinfo(np.int32).max
 sklearn_version = parse_version(sklearn.__version__)
@@ -327,6 +330,13 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
         Only supported when scikit-learn >= 1.4 is installed. Otherwise, a
         `ValueError` is raised.
 
+    recalibrate : bool, default=False
+        Whether to recalibrate the output of `predict_proba` and `predict_log_proba`
+        using the sampling ratio of the different bootstrap samples. Note that the
+        correction only works for binary classification.
+
+        .. versionadded:: 0.13
+
     Attributes
     ----------
     estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier` instance
@@ -445,6 +455,7 @@ class labels (multi-output problem).
             Hidden(StrOptions({"warn"})),
         ],
         "replacement": ["boolean", Hidden(StrOptions({"warn"}))],
+        "recalibrate": ["boolean"],
     }
 )
 
@@ -472,6 +483,7 @@ def __init__(
         ccp_alpha=0.0,
         max_samples=None,
         monotonic_cst=None,
+        recalibrate=False,
     ):
         params_random_forest = {
             "criterion": criterion,
@@ -510,6 +522,7 @@ def __init__(
 
         self.sampling_strategy = sampling_strategy
         self.replacement = replacement
+        self.recalibrate = recalibrate
 
     def _validate_estimator(self, default=DecisionTreeClassifier()):
         """Check the estimator and the n_estimator attribute, set the
@@ -572,6 +585,15 @@ def fit(self, X, y, sample_weight=None):
             The fitted instance.
         """
         self._validate_params()
+        if self.recalibrate:
+            # compute the type of target only if we need to recalibrate since this is
+            # potentially costly
+            y_type = type_of_target(y)
+            if y_type != "binary":
+                raise ValueError(
+                    "Only possible to recalibrate the probabilities for binary "
+                    f"classification. Got {y_type} instead."
+                )
         # TODO: remove in 0.13
         if self.sampling_strategy == "warn":
             warn(
@@ -895,6 +917,37 @@ def _compute_oob_predictions(self, X, y):
 
         return oob_pred
 
+    def predict_proba(self, X):
+        """
+        Predict class probabilities for X.
+
+        The predicted class probabilities of an input sample are computed as
+        the mean predicted class probabilities of the trees in the forest.
+        The class probability of a single tree is the fraction of samples of
+        the same class in a leaf.
+
+        Parameters
+        ----------
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            The input samples. Internally, its dtype will be converted to
+            ``dtype=np.float32``. If a sparse matrix is provided, it will be
+            converted into a sparse ``csr_matrix``.
+
+        Returns
+        -------
+        p : ndarray of shape (n_samples, n_classes), or a list of such arrays
+            The class probabilities of the input samples. The order of the
+            classes corresponds to that in the attribute :term:`classes_`.
+        """
+        proba = super().predict_proba(X)
+
+        if self.recalibrate:
+            weight = _estimate_reweighting(self.samplers_)
+            proba[:, 1] /= proba[:, 1] + (1 - proba[:, 1]) / weight
+            proba[:, 0] = 1 - proba[:, 1]
+
+        return proba
+
     # TODO: remove when supporting scikit-learn>=1.2
     @property
     def n_features_(self):
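`BalancedRandomForestClassifier` differs from the two bagging classes in that its per-tree samplers live in a dedicated `samplers_` attribute, so they are handed to `_estimate_reweighting` directly. A hedged usage sketch; the non-default `sampling_strategy`, `replacement`, and `bootstrap` values mirror the new test below and avoid the 0.13 deprecation warnings:

```python
from sklearn.datasets import make_classification

from imblearn.ensemble import BalancedRandomForestClassifier

X, y = make_classification(n_samples=2_000, weights=[0.8, 0.2], random_state=0)
brf = BalancedRandomForestClassifier(
    n_estimators=10,
    sampling_strategy="all",
    replacement=True,
    bootstrap=False,
    recalibrate=True,
    random_state=0,
).fit(X, y)

# One fitted RandomUnderSampler per tree; their stored class counts drive the
# reweighting factor that corrects `predict_proba`.
print(len(brf.samplers_))  # 10
print(brf.predict_proba(X[:2]))
```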
+ """ + proba = super().predict_proba(X) + + if self.recalibrate: + weight = _estimate_reweighting(self.samplers_) + proba[:, 1] /= proba[:, 1] + (1 - proba[:, 1]) / weight + proba[:, 0] = 1 - proba[:, 1] + + return proba + # TODO: remove when supporting scikit-learn>=1.2 @property def n_features_(self): diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py index 5705de553..64b126054 100644 --- a/imblearn/ensemble/tests/test_bagging.py +++ b/imblearn/ensemble/tests/test_bagging.py @@ -1,4 +1,5 @@ """Test the module ensemble classifiers.""" + # Authors: Guillaume Lemaitre # Christos Aridas # License: MIT @@ -13,6 +14,7 @@ from sklearn.dummy import DummyClassifier from sklearn.feature_selection import SelectKBest from sklearn.linear_model import LogisticRegression, Perceptron +from sklearn.metrics import log_loss from sklearn.model_selection import GridSearchCV, ParameterGrid, train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC @@ -592,3 +594,45 @@ def test_balanced_bagging_classifier_n_features(): estimator = BalancedBaggingClassifier().fit(X, y) with pytest.warns(FutureWarning, match="`n_features_` was deprecated"): estimator.n_features_ + + +def test_balanced_bagging_classifier_recalibrate_error(): + """Check that we raise a ValueError when trying to recalibrate the + classifier when the problem is not a binary classification problem. + """ + X, y = load_iris(return_X_y=True) + err_msg = "Only possible to recalibrate the probabilities for binary classification" + with pytest.raises(ValueError, match=err_msg): + BalancedBaggingClassifier(recalibrate=True).fit(X, y) + + +def test_balanced_bagging_classifier_recalibrate(): + """Check the behaviour of the `recalibrate` parameter.""" + X, y = make_classification( + n_samples=10_000, + n_classes=2, + n_clusters_per_class=2, + weights=[0.2, 0.8], + class_sep=0.1, + random_state=0, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + bbc_uncalibrated = BalancedBaggingClassifier( + n_estimators=50, random_state=42, n_jobs=-1, recalibrate=False + ).fit(X_train, y_train) + bbc_recalibrated = BalancedBaggingClassifier( + n_estimators=50, random_state=42, n_jobs=-1, recalibrate=True + ).fit(X_train, y_train) + + # Since the resampling is breaking the calibration, we expected that a proper + # scoring rule error to be much lower on the recalibrate model + log_loss_uncalibrated = log_loss(y_test, bbc_uncalibrated.predict_proba(X_test)) + log_loss_recalibrated = log_loss(y_test, bbc_recalibrated.predict_proba(X_test)) + assert log_loss_uncalibrated > log_loss_recalibrated + log_loss_uncalibrated = log_loss( + y_test, np.exp(bbc_uncalibrated.predict_log_proba(X_test)) + ) + log_loss_recalibrated = log_loss( + y_test, np.exp(bbc_recalibrated.predict_log_proba(X_test)) + ) + assert log_loss_uncalibrated > log_loss_recalibrated diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py index 7dc04414a..b4ede73b7 100644 --- a/imblearn/ensemble/tests/test_easy_ensemble.py +++ b/imblearn/ensemble/tests/test_easy_ensemble.py @@ -1,4 +1,5 @@ """Test the module easy ensemble.""" + # Authors: Guillaume Lemaitre # Christos Aridas # License: MIT @@ -6,10 +7,12 @@ import numpy as np import pytest import sklearn -from sklearn.datasets import load_iris, make_hastie_10_2 +from sklearn.datasets import load_iris, make_classification, make_hastie_10_2 from sklearn.ensemble import AdaBoostClassifier from 
diff --git a/imblearn/ensemble/tests/test_easy_ensemble.py b/imblearn/ensemble/tests/test_easy_ensemble.py
index 7dc04414a..b4ede73b7 100644
--- a/imblearn/ensemble/tests/test_easy_ensemble.py
+++ b/imblearn/ensemble/tests/test_easy_ensemble.py
@@ -1,4 +1,5 @@
 """Test the module easy ensemble."""
+
 # Authors: Guillaume Lemaitre
 #          Christos Aridas
 # License: MIT
@@ -6,10 +7,12 @@
 import numpy as np
 import pytest
 import sklearn
-from sklearn.datasets import load_iris, make_hastie_10_2
+from sklearn.datasets import load_iris, make_classification, make_hastie_10_2
 from sklearn.ensemble import AdaBoostClassifier
 from sklearn.feature_selection import SelectKBest
+from sklearn.metrics import log_loss
 from sklearn.model_selection import GridSearchCV, train_test_split
+from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils._testing import assert_allclose, assert_array_equal
 from sklearn.utils.fixes import parse_version
@@ -233,3 +236,53 @@ def test_easy_ensemble_classifier_n_features():
     estimator = EasyEnsembleClassifier().fit(X, y)
     with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
         estimator.n_features_
+
+
+def test_easy_ensemble_classifier_recalibrate_error():
+    """Check that we raise a ValueError when trying to recalibrate the
+    classifier when the problem is not a binary classification problem.
+    """
+    X, y = load_iris(return_X_y=True)
+    err_msg = "Only possible to recalibrate the probabilities for binary classification"
+    with pytest.raises(ValueError, match=err_msg):
+        EasyEnsembleClassifier(recalibrate=True).fit(X, y)
+
+
+def test_easy_ensemble_classifier_recalibrate():
+    """Check the behaviour of the `recalibrate` parameter."""
+    X, y = make_classification(
+        n_samples=10_000,
+        n_classes=2,
+        n_clusters_per_class=2,
+        weights=[0.2, 0.8],
+        class_sep=0.1,
+        random_state=0,
+    )
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    ee_uncalibrated = EasyEnsembleClassifier(
+        n_estimators=50,
+        estimator=DecisionTreeClassifier(),
+        random_state=42,
+        n_jobs=-1,
+        recalibrate=False,
+    ).fit(X_train, y_train)
+    ee_recalibrated = EasyEnsembleClassifier(
+        n_estimators=50,
+        estimator=DecisionTreeClassifier(),
+        random_state=42,
+        n_jobs=-1,
+        recalibrate=True,
+    ).fit(X_train, y_train)
+
+    # Since the resampling breaks the calibration, we expect a proper scoring
+    # rule error to be much lower for the recalibrated model.
+    log_loss_uncalibrated = log_loss(y_test, ee_uncalibrated.predict_proba(X_test))
+    log_loss_recalibrated = log_loss(y_test, ee_recalibrated.predict_proba(X_test))
+    assert log_loss_uncalibrated > log_loss_recalibrated
+    log_loss_uncalibrated = log_loss(
+        y_test, np.exp(ee_uncalibrated.predict_log_proba(X_test))
+    )
+    log_loss_recalibrated = log_loss(
+        y_test, np.exp(ee_recalibrated.predict_log_proba(X_test))
+    )
+    assert log_loss_uncalibrated > log_loss_recalibrated
+ """ + X, y = load_iris(return_X_y=True) + err_msg = "Only possible to recalibrate the probabilities for binary classification" + with pytest.raises(ValueError, match=err_msg): + BalancedRandomForestClassifier(recalibrate=True).fit(X, y) + + +def test_balanced_random_forest_classifier_recalibrate(): + """Check the behaviour of the `recalibrate` parameter.""" + X, y = make_classification( + n_samples=10_000, + n_classes=2, + n_clusters_per_class=2, + weights=[0.2, 0.8], + class_sep=0.1, + random_state=0, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + bbc_uncalibrated = BalancedRandomForestClassifier( + n_estimators=50, + random_state=42, + n_jobs=-1, + sampling_strategy="all", + replacement=True, + bootstrap=False, + recalibrate=False, + ).fit(X_train, y_train) + bbc_recalibrated = BalancedRandomForestClassifier( + n_estimators=50, + random_state=42, + n_jobs=-1, + sampling_strategy="all", + replacement=True, + bootstrap=False, + recalibrate=True, + ).fit(X_train, y_train) + + # Since the resampling is breaking the calibration, we expected that a proper + # scoring rule error to be much lower on the recalibrate model + log_loss_uncalibrated = log_loss(y_test, bbc_uncalibrated.predict_proba(X_test)) + log_loss_recalibrated = log_loss(y_test, bbc_recalibrated.predict_proba(X_test)) + assert log_loss_uncalibrated > log_loss_recalibrated + log_loss_uncalibrated = log_loss( + y_test, np.exp(bbc_uncalibrated.predict_log_proba(X_test)) + ) + log_loss_recalibrated = log_loss( + y_test, np.exp(bbc_recalibrated.predict_log_proba(X_test)) + ) + assert log_loss_uncalibrated > log_loss_recalibrated diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index b21c15788..16d619e69 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -186,9 +186,8 @@ def check_target_type(y, indicate_one_vs_all=False): return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y -def _sampling_strategy_all(y, sampling_type): +def _sampling_strategy_all(target_stats, sampling_type): """Returns sampling target by targeting all classes.""" - target_stats = _count_class_sample(y) if sampling_type == "over-sampling": n_sample_majority = max(target_stats.values()) sampling_strategy = { @@ -203,14 +202,13 @@ def _sampling_strategy_all(y, sampling_type): return sampling_strategy -def _sampling_strategy_majority(y, sampling_type): +def _sampling_strategy_majority(target_stats, sampling_type): """Returns sampling target by targeting the majority class only.""" if sampling_type == "over-sampling": raise ValueError( "'sampling_strategy'='majority' cannot be used with over-sampler." 
diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py
index b21c15788..16d619e69 100644
--- a/imblearn/utils/_validation.py
+++ b/imblearn/utils/_validation.py
@@ -186,9 +186,8 @@ def check_target_type(y, indicate_one_vs_all=False):
     return (y, type_y == "multilabel-indicator") if indicate_one_vs_all else y
 
 
-def _sampling_strategy_all(y, sampling_type):
+def _sampling_strategy_all(target_stats, sampling_type):
     """Returns sampling target by targeting all classes."""
-    target_stats = _count_class_sample(y)
     if sampling_type == "over-sampling":
         n_sample_majority = max(target_stats.values())
         sampling_strategy = {
@@ -203,14 +202,13 @@ def _sampling_strategy_all(y, sampling_type):
     return sampling_strategy
 
 
-def _sampling_strategy_majority(y, sampling_type):
+def _sampling_strategy_majority(target_stats, sampling_type):
     """Returns sampling target by targeting the majority class only."""
     if sampling_type == "over-sampling":
         raise ValueError(
             "'sampling_strategy'='majority' cannot be used with over-sampler."
         )
     elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
-        target_stats = _count_class_sample(y)
         class_majority = max(target_stats, key=target_stats.get)
         n_sample_minority = min(target_stats.values())
         sampling_strategy = {
@@ -224,10 +222,9 @@ def _sampling_strategy_majority(y, sampling_type):
     return sampling_strategy
 
 
-def _sampling_strategy_not_majority(y, sampling_type):
+def _sampling_strategy_not_majority(target_stats, sampling_type):
     """Returns sampling target by targeting all classes but not the
     majority."""
-    target_stats = _count_class_sample(y)
     if sampling_type == "over-sampling":
         n_sample_majority = max(target_stats.values())
         class_majority = max(target_stats, key=target_stats.get)
@@ -250,10 +247,9 @@ def _sampling_strategy_not_majority(y, sampling_type):
     return sampling_strategy
 
 
-def _sampling_strategy_not_minority(y, sampling_type):
+def _sampling_strategy_not_minority(target_stats, sampling_type):
     """Returns sampling target by targeting all classes but not the
     minority."""
-    target_stats = _count_class_sample(y)
     if sampling_type == "over-sampling":
         n_sample_majority = max(target_stats.values())
         class_minority = min(target_stats, key=target_stats.get)
@@ -276,9 +272,8 @@ def _sampling_strategy_not_minority(y, sampling_type):
     return sampling_strategy
 
 
-def _sampling_strategy_minority(y, sampling_type):
+def _sampling_strategy_minority(target_stats, sampling_type):
     """Returns sampling target by targeting the minority class only."""
-    target_stats = _count_class_sample(y)
     if sampling_type == "over-sampling":
         n_sample_majority = max(target_stats.values())
         class_minority = min(target_stats, key=target_stats.get)
@@ -298,19 +293,18 @@ def _sampling_strategy_minority(y, sampling_type):
     return sampling_strategy
 
 
-def _sampling_strategy_auto(y, sampling_type):
+def _sampling_strategy_auto(target_stats, sampling_type):
     """Returns sampling target auto for over-sampling and
     not-minority for under-sampling."""
     if sampling_type == "over-sampling":
-        return _sampling_strategy_not_majority(y, sampling_type)
+        return _sampling_strategy_not_majority(target_stats, sampling_type)
     elif sampling_type == "under-sampling" or sampling_type == "clean-sampling":
-        return _sampling_strategy_not_minority(y, sampling_type)
+        return _sampling_strategy_not_minority(target_stats, sampling_type)
 
 
-def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
+def _sampling_strategy_dict(sampling_strategy, target_stats, sampling_type):
     """Returns sampling target by converting the dictionary depending of the
     sampling."""
-    target_stats = _count_class_sample(y)
     # check that all keys in sampling_strategy are also in y
     set_diff_sampling_strategy_target = set(sampling_strategy.keys()) - set(
         target_stats.keys()
@@ -363,7 +357,7 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
     return sampling_strategy_
 
 
-def _sampling_strategy_list(sampling_strategy, y, sampling_type):
+def _sampling_strategy_list(sampling_strategy, target_stats, sampling_type):
     """With cleaning methods, sampling_strategy can be a list to target the
 class of interest."""
     if sampling_type != "clean-sampling":
@@ -371,8 +365,6 @@ class of interest."""
             "'sampling_strategy' cannot be a list for samplers "
             "which are not cleaning methods."
         )
-
-    target_stats = _count_class_sample(y)
     # check that all keys in sampling_strategy are also in y
     set_diff_sampling_strategy_target = set(sampling_strategy) - set(
         target_stats.keys()
@@ -388,16 +380,14 @@ class of interest."""
     }
 
 
-def _sampling_strategy_float(sampling_strategy, y, sampling_type):
+def _sampling_strategy_float(sampling_strategy, target_stats, sampling_type):
     """Take a proportion of the majority (over-sampling) or minority
     (under-sampling) class in binary classification."""
-    type_y = type_of_target(y)
-    if type_y != "binary":
+    if len(target_stats) != 2:
         raise ValueError(
             '"sampling_strategy" can be a float only when the type '
             "of target is binary. For multi-class, use a dict."
         )
-    target_stats = _count_class_sample(y)
     if sampling_type == "over-sampling":
         n_sample_majority = max(target_stats.values())
         class_majority = max(target_stats, key=target_stats.get)
@@ -439,7 +429,9 @@ def _sampling_strategy_float(sampling_strategy, y, sampling_type):
     return sampling_strategy_
 
 
-def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
+def check_sampling_strategy(
+    sampling_strategy, y, sampling_type, return_original_counts=False, **kwargs
+):
     """Sampling target validation for samplers.
 
     Checks that ``sampling_strategy`` is of consistent type and return a
@@ -516,6 +508,11 @@ def check_sampling_strategy(
         The type of sampling. Can be either ``'over-sampling'``,
         ``'under-sampling'``, or ``'clean-sampling'``.
 
+    return_original_counts : bool, default=False
+        Whether to return the original class distribution.
+
+        .. versionadded:: 0.13
+
     **kwargs : dict
         Dictionary of additional keyword arguments to pass to
         ``sampling_strategy`` when this is a callable.
@@ -526,6 +523,12 @@ def check_sampling_strategy(
         The converted and validated sampling target. Returns a dictionary with
         the key being the class target and the value being the desired
         number of samples.
+
+    original_class_distribution : dict
+        The original class distribution. Only returned if
+        ``return_original_counts=True``.
+
+        .. versionadded:: 0.13
     """
     if sampling_type not in SAMPLING_KIND:
         raise ValueError(
@@ -539,7 +542,10 @@ def check_sampling_strategy(
             f"Got {np.unique(y).size} class instead"
         )
 
+    target_stats = _count_class_sample(y)
     if sampling_type in ("ensemble", "bypass"):
+        if return_original_counts:
+            return sampling_strategy, target_stats
         return sampling_strategy
 
     if isinstance(sampling_strategy, str):
@@ -549,16 +555,28 @@ def check_sampling_strategy(
             f" to be one of {SAMPLING_TARGET_KIND}. Got '{sampling_strategy}' "
) - return OrderedDict( - sorted(SAMPLING_TARGET_KIND[sampling_strategy](y, sampling_type).items()) + sampling_strategy_converted = OrderedDict( + sorted( + SAMPLING_TARGET_KIND[sampling_strategy]( + target_stats, sampling_type + ).items() + ) ) elif isinstance(sampling_strategy, dict): - return OrderedDict( - sorted(_sampling_strategy_dict(sampling_strategy, y, sampling_type).items()) + sampling_strategy_converted = OrderedDict( + sorted( + _sampling_strategy_dict( + sampling_strategy, target_stats, sampling_type + ).items() + ) ) elif isinstance(sampling_strategy, list): - return OrderedDict( - sorted(_sampling_strategy_list(sampling_strategy, y, sampling_type).items()) + sampling_strategy_converted = OrderedDict( + sorted( + _sampling_strategy_list( + sampling_strategy, target_stats, sampling_type + ).items() + ) ) elif isinstance(sampling_strategy, Real): if sampling_strategy <= 0 or sampling_strategy > 1: @@ -566,19 +584,27 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs): f"When 'sampling_strategy' is a float, it should be " f"in the range (0, 1]. Got {sampling_strategy} instead." ) - return OrderedDict( + sampling_strategy_converted = OrderedDict( sorted( - _sampling_strategy_float(sampling_strategy, y, sampling_type).items() + _sampling_strategy_float( + sampling_strategy, target_stats, sampling_type + ).items() ) ) elif callable(sampling_strategy): sampling_strategy_ = sampling_strategy(y, **kwargs) - return OrderedDict( + sampling_strategy_converted = OrderedDict( sorted( - _sampling_strategy_dict(sampling_strategy_, y, sampling_type).items() + _sampling_strategy_dict( + sampling_strategy_, target_stats, sampling_type + ).items() ) ) + if return_original_counts: + return sampling_strategy_converted, target_stats + return sampling_strategy_converted + SAMPLING_TARGET_KIND = { "minority": _sampling_strategy_minority,