diff --git a/examples/combine/plot_illustration_spider.py b/examples/combine/plot_illustration_spider.py
new file mode 100644
index 000000000..3d9e698e6
--- /dev/null
+++ b/examples/combine/plot_illustration_spider.py
@@ -0,0 +1,256 @@
+"""
+==========================================================================
+Illustration of the sample selection for the different SPIDER algorithms
+==========================================================================
+
+This example illustrates the different ways of resampling with SPIDER.
+
+"""
+
+# Authors: Matthew Eding
+# License: MIT
+
+from collections import namedtuple
+from functools import partial
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from imblearn.combine import SPIDER
+from matplotlib.patches import Circle
+from scipy.stats import mode
+
+print(__doc__)
+
+###############################################################################
+# These are helper functions for plotting aspects of the algorithm.
+
+Neighborhood = namedtuple('Neighborhood', 'radius, neighbors')
+
+
+def plot_X(X, ax, **kwargs):
+    ax.scatter(X[:, 0], X[:, 1], **kwargs)
+
+
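+# ``correct`` below mirrors the KNN check SPIDER performs internally: a
+# sample is flagged 'safe' when the majority vote of its nearest neighbors
+# (excluding itself) matches its own label, and 'noisy' otherwise.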
+def correct(nn, y_fit, X, y, additional=False):
+    n_neighbors = nn.n_neighbors
+    if additional:
+        n_neighbors += 2
+    nn_idxs = nn.kneighbors(X, n_neighbors, return_distance=False)[:, 1:]
+    y_pred, _ = mode(y_fit[nn_idxs], axis=1)
+    return (y == y_pred.ravel())
+
+
+def get_neighborhoods(spider, X_fit, y_fit, X_flagged, y_flagged, idx):
+    point = X_flagged[idx]
+
+    additional = (spider.kind_sel == 'strong')
+    if correct(spider.nn_, y_fit, point[np.newaxis],
+               y_flagged[idx][np.newaxis], additional=additional):
+        additional = False
+
+    idxs_k = spider._locate_neighbors(point[np.newaxis])
+    neighbors_k = X_fit[idxs_k].squeeze()
+    farthest_k = neighbors_k[-1]
+    radius_k = np.linalg.norm(point - farthest_k)
+    neighborhood_k = Neighborhood(radius_k, neighbors_k)
+
+    idxs_k2 = spider._locate_neighbors(point[np.newaxis], additional=True)
+    neighbors_k2 = X_fit[idxs_k2].squeeze()
+    farthest_k2 = neighbors_k2[-1]
+    radius_k2 = np.linalg.norm(point - farthest_k2)
+    neighborhood_k2 = Neighborhood(radius_k2, neighbors_k2)
+
+    return neighborhood_k, neighborhood_k2, point, additional
+
+
+def draw_neighborhoods(spider, neighborhood_k, neighborhood_k2, point,
+                       additional, ax, outer=True, alpha=0.5):
+    PartialCircle = partial(Circle, facecolor='none', edgecolor='black',
+                            alpha=alpha)
+
+    circle_k = PartialCircle(point, neighborhood_k.radius, linestyle='-')
+
+    circle_k2 = PartialCircle(point, neighborhood_k2.radius,
+                              linestyle=('-' if additional else '--'))
+
+    if not additional:
+        ax.add_patch(circle_k)
+
+    if (spider.kind_sel == 'strong') and outer:
+        ax.add_patch(circle_k2)
+
+
+def draw_amplification(X_flagged, point, neighbors, ax):
+    for neigh in neighbors:
+        arr = np.vstack([point, neigh])
+        xs, ys = np.split(arr, 2, axis=1)
+        linestyle = 'solid' if neigh in X_flagged else 'dotted'
+        ax.plot(xs, ys, color='black', linestyle=linestyle)
+
+
+def plot_spider(kind_sel, X, y):
+    if kind_sel == 'strong':
+        _, axes = plt.subplots(2, 1, figsize=(12, 16))
+    else:
+        _, axes = plt.subplots(1, 1, figsize=(12, 8))
+    axes = np.atleast_1d(axes)
+
+    spider = SPIDER(kind_sel=kind_sel)
+    spider.fit_resample(X, y)
+
+    is_safe = correct(spider.nn_, y, X, y)
+    is_minor = (y == 1)
+
+    X_major = X[~is_minor]
+    X_minor = X[is_minor]
+    X_noise = X[~is_safe]
+
+    X_minor_noise = X[is_minor & ~is_safe]
+    y_minor_noise = y[is_minor & ~is_safe]
+    X_major_safe = X[~is_minor & is_safe]
+    X_minor_safe = X[is_minor & is_safe]
+    y_minor_safe = y[is_minor & is_safe]
+
+    partial_neighborhoods = partial(get_neighborhoods, spider, X, y)
+    partial_amplification = partial(draw_amplification, X_major_safe)
+    partial_draw_neighborhoods = partial(draw_neighborhoods, spider)
+
+    size = 500
+    for axis in axes:
+        plot_X(X_minor, ax=axis, label='Minority class', s=size, marker='_')
+        plot_X(X_major, ax=axis, label='Majority class', s=size, marker='+')
+
+        #: Overlay ring around noisy samples for both classes
+        plot_X(X_noise, ax=axis, label='Noisy Sample', s=size, marker='o',
+               facecolors='none', edgecolors='black')
+
+    #: Neighborhoods for Noisy Minority Samples
+    for idx in range(len(X_minor_noise)):
+        neighborhoods = partial_neighborhoods(X_minor_noise, y_minor_noise,
+                                              idx=idx)
+        partial_draw_neighborhoods(*neighborhoods, ax=axes[0],
+                                   outer=(spider.kind_sel == 'strong'))
+        neigh_k, neigh_k2, point, additional = neighborhoods
+        neighbors = neigh_k2.neighbors if additional else neigh_k.neighbors
+        partial_amplification(point, neighbors, ax=axes[0])
+
+    axes[0].axis('equal')
+    axes[0].legend(markerscale=0.5)
+    axes[0].set_title(f'SPIDER-{spider.kind_sel.title()}')
+
+    #: Neighborhoods for Safe Minority Samples (kind_sel='strong' only)
+    if spider.kind_sel == 'strong':
+        for idx in range(len(X_minor_safe)):
+            neighborhoods = partial_neighborhoods(X_minor_safe, y_minor_safe,
+                                                  idx=idx)
+            neigh_k, _, point, additional = neighborhoods
+            neighbors = neigh_k.neighbors
+            draw_flag = np.any(np.isin(neighbors, X_major_safe))
+
+            alpha = 0.5 if draw_flag else 0.1
+            partial_draw_neighborhoods(*neighborhoods[:-1], additional=False,
+                                       ax=axes[1], outer=False, alpha=alpha)
+
+            if draw_flag:
+                partial_amplification(point, neighbors, ax=axes[1])
+
+        axes[1].axis('equal')
+        axes[1].legend(markerscale=0.5)
+        axes[1].set_title(f'SPIDER-{spider.kind_sel.title()}')
+
+
+###############################################################################
+# We can start by generating some data to illustrate the principle of each
+# SPIDER heuristic rule.
+
+X = np.array([
+    [-11.83, -6.81],
+    [-11.72, -2.34],
+    [-11.43, -5.85],
+    [-10.66, -4.33],
+    [-9.64, -7.05],
+    [-8.39, -4.41],
+    [-8.07, -5.66],
+    [-7.28, 0.91],
+    [-7.24, -2.41],
+    [-6.13, -4.81],
+    [-5.92, -6.81],
+    [-4., -1.81],
+    [-3.96, 2.67],
+    [-3.74, -7.31],
+    [-2.96, 4.69],
+    [-1.56, -2.33],
+    [-1.02, -4.57],
+    [0.46, 4.07],
+    [1.2, -1.53],
+    [1.32, 0.41],
+    [1.56, -5.19],
+    [2.52, 5.89],
+    [3.03, -4.15],
+    [4., -0.59],
+    [4.4, 2.07],
+    [4.41, -7.45],
+    [4.45, -4.12],
+    [5.13, -6.28],
+    [5.4, -5],
+    [6.26, 4.65],
+    [7.02, -6.22],
+    [7.5, -0.11],
+    [8.1, -2.05],
+    [8.42, 2.47],
+    [9.62, 3.87],
+    [10.54, -4.47],
+    [11.42, 0.01]
+])
+
+y = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0,
+              0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0])
+
+
+###############################################################################
+# SPIDER-Weak / SPIDER-Relabel
+###############################################################################
+
+###############################################################################
+# Both SPIDER-Weak and SPIDER-Relabel start by labeling samples as 'safe' or
+# 'noisy' by looking at each point's 3-NN and checking whether it would be
+# classified correctly using KNN classification. Each minority-noisy sample
+# is then amplified by the number of majority-safe samples in its 3-NN. In
+# the diagram below, the amplification amount is indicated by the number of
+# solid lines for a given minority-noisy sample's neighborhood.
+#
+# We can observe that the leftmost minority-noisy sample will be duplicated 3
+# times, the middle one 1 time, and the rightmost one will not be amplified.
+#
+# With SPIDER-Weak, every majority-noisy sample is then removed from the
+# dataset. With SPIDER-Relabel, these samples are instead relabeled as the
+# minority class. They are the samples marked with a circled plus-sign.
+
+plot_spider('weak', X, y)
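+
+# For a numeric rather than visual check, one could compare the class counts
+# before and after resampling (a quick sketch; the exact counts depend on the
+# data above):
+#
+#     from collections import Counter
+#     _, y_res = SPIDER(kind_sel='weak').fit_resample(X, y)
+#     print(Counter(y), Counter(y_res))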
+
+###############################################################################
+# SPIDER-Strong
+###############################################################################
+
+###############################################################################
+# SPIDER-Strong still uses 3-NN to classify samples as 'safe' or 'noisy' as
+# the first step. However, for the amplification step, each minority-noisy
+# sample looks at its 5-NN, and if the larger neighborhood still
+# misclassifies the sample, the 5-NN is used to amplify. Otherwise, if the
+# sample is correctly classified with 5-NN, the regular 3-NN is used to
+# amplify.
+#
+# In the diagram below, we can see that the leftmost and rightmost
+# minority-noisy samples are misclassified using 5-NN and will be amplified
+# by 5 and 1 respectively. The middle minority-noisy sample is classified
+# correctly using 5-NN, so amplification will be done using 3-NN.
+#
+# Next, for each minority-safe sample, the amplification process is applied
+# using 3-NN. In the lower subplot, all but one of these samples are left
+# unamplified since they have no majority-safe samples in their
+# neighborhoods. The one minority-safe sample to be amplified is indicated
+# by a darker neighborhood with lines.
+
+plot_spider('strong', X, y)
+
+plt.show()
diff --git a/imblearn/combine/__init__.py b/imblearn/combine/__init__.py
index a0833f996..b7dcb3ba7 100644
--- a/imblearn/combine/__init__.py
+++ b/imblearn/combine/__init__.py
@@ -4,5 +4,10 @@
 
 from ._smote_enn import SMOTEENN
 from ._smote_tomek import SMOTETomek
+from ._preprocess import SPIDER
 
-__all__ = ["SMOTEENN", "SMOTETomek"]
+__all__ = [
+    "SMOTEENN",
+    "SMOTETomek",
+    "SPIDER",
+]
diff --git a/imblearn/combine/_preprocess/__init__.py b/imblearn/combine/_preprocess/__init__.py
new file mode 100644
index 000000000..31b8b6d52
--- /dev/null
+++ b/imblearn/combine/_preprocess/__init__.py
@@ -0,0 +1,3 @@
+from ._spider import SPIDER
+
+__all__ = ["SPIDER"]
diff --git a/imblearn/combine/_preprocess/_spider.py b/imblearn/combine/_preprocess/_spider.py
new file mode 100644
index 000000000..92f6fe5a5
--- /dev/null
+++ b/imblearn/combine/_preprocess/_spider.py
@@ -0,0 +1,310 @@
+"""Class to perform cleaning and selective pre-processing using SPIDER"""
+
+# Authors: Matthew Eding
+# License: MIT
+
+
+from numbers import Integral
+
+import numpy as np
+from scipy import sparse
+from scipy import stats
+
+from sklearn.utils import safe_mask
+from sklearn.utils import _safe_indexing
+
+from .base import BasePreprocessSampler
+from ...utils import check_neighbors_object
+from ...utils import Substitution
+from ...utils._docstring import _n_jobs_docstring
+
+SEL_KIND = ("weak", "relabel", "strong")
+
+
+@Substitution(
+    sampling_strategy=BasePreprocessSampler._sampling_strategy_docstring,
+    n_jobs=_n_jobs_docstring,
+)
+class SPIDER(BasePreprocessSampler):
+    """Perform filtering and over-sampling using the SPIDER algorithm.
+
+    This object is an implementation of SPIDER - Selective Pre-processing of
+    Imbalanced Data - as presented in [1]_ and [2]_.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    {sampling_strategy}
+
+    kind_sel : {{"weak", "relabel", "strong"}}, default='weak'
+        Strategy to use in order to preprocess samples in the SPIDER sampling.
+
+        - If ``'weak'``, amplify noisy minority class samples based on the
+          number of safe majority neighbors.
+        - If ``'relabel'``, perform ``'weak'`` amplification and then relabel
+          noisy majority neighbors for each noisy minority class sample.
+        - If ``'strong'``, amplify all minority class samples by an extra
+          ``additional_neighbors`` if the sample is classified incorrectly
+          by its neighbors. Otherwise each minority sample is amplified in a
+          manner akin to ``'weak'`` amplification.
+
+    n_neighbors : int or object, optional (default=3)
+        If ``int``, number of nearest neighbours to use to flag samples as
+        safe or noisy and to amplify samples. If object, an estimator that
+        inherits from :class:`sklearn.neighbors.base.KNeighborsMixin` that
+        will be used to find the nearest-neighbors.
+
+    additional_neighbors : int, optional (default=2)
+        The number added to ``n_neighbors`` during amplification if
+        ``kind_sel`` is ``'strong'``. This has no effect otherwise.
+
+    {n_jobs}
+
+    See Also
+    --------
+    NeighbourhoodCleaningRule : Undersample by editing noisy samples.
+
+    RandomOverSampler : Randomly over-sample the dataset.
+
+    Notes
+    -----
+    The implementation is based on [1]_ and [2]_.
+
+    Supports multi-class resampling. A one-vs.-rest scheme is used.
+
+    References
+    ----------
+    .. [1] Stefanowski, J., & Wilk, S, "Selective pre-processing of imbalanced
+       data for improving classification performance," In: Song, I.-Y., Eder,
+       J., Nguyen, T.M. (Eds.): DaWaK 2008, LNCS, vol. 5182, pp. 283–292.
+       Springer, Heidelberg, 2008.
+
+    .. [2] Błaszczyński, J., Deckert, M., Stefanowski, J., & Wilk, S,
+       "Integrating Selective Pre-processing of Imbalanced Data with Ivotes
+       Ensemble," In: M. Szczuka et al. (Eds.): RSCTC 2010, LNAI 6086, pp.
+       148–157, 2010.
+
+    Examples
+    --------
+
+    >>> from collections import Counter
+    >>> from sklearn.datasets import make_classification
+    >>> from imblearn.combine import \
+SPIDER # doctest: +NORMALIZE_WHITESPACE
+    >>> X, y = make_classification(n_classes=2, class_sep=2,
+    ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0,
+    ... n_features=20, n_clusters_per_class=1, n_samples=1000,
+    ... random_state=10)
+    >>> print('Original dataset shape %s' % Counter(y))
+    Original dataset shape Counter({{1: 900, 0: 100}})
+    >>> spider = SPIDER()
+    >>> X_res, y_res = spider.fit_resample(X, y)
+    >>> print('Resampled dataset shape %s' % Counter(y_res))
+    Resampled dataset shape Counter({{1: 897, 0: 115}})
+    """
+
+    def __init__(
+        self,
+        sampling_strategy="auto",
+        kind_sel="weak",
+        n_neighbors=3,
+        additional_neighbors=2,
+        n_jobs=None,
+    ):
+        super().__init__(sampling_strategy=sampling_strategy)
+        self.kind_sel = kind_sel
+        self.n_neighbors = n_neighbors
+        self.additional_neighbors = additional_neighbors
+        self.n_jobs = n_jobs
+
+    def _validate_estimator(self):
+        """Create the necessary objects for SPIDER"""
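+        # ``additional_neighbor=1`` pads the neighborhood size by one
+        # because each query point comes back as its own nearest neighbor
+        # and is dropped again in ``_locate_neighbors``.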
+        self.nn_ = check_neighbors_object(
+            "n_neighbors", self.n_neighbors, additional_neighbor=1)
+        self.nn_.set_params(**{"n_jobs": self.n_jobs})
+
+        if self.kind_sel not in SEL_KIND:
+            raise ValueError(
+                'The possible "kind" of algorithm are "weak", "relabel",'
+                ' and "strong". Got {} instead.'.format(self.kind_sel)
+            )
+
+        if not isinstance(self.additional_neighbors, Integral):
+            raise TypeError("additional_neighbors must be an integer.")
+
+        if self.additional_neighbors < 1:
+            raise ValueError("additional_neighbors must be at least 1.")
+
+    def _locate_neighbors(self, X, additional=False):
+        """Find nearest neighbors for samples.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            The feature samples to find neighbors for.
+
+        additional : bool, optional (default=False)
+            Flag to indicate whether to increase ``n_neighbors`` by
+            ``additional_neighbors``.
+
+        Returns
+        -------
+        nn_indices : ndarray, shape (n_samples, n_neighbors)
+            Indices of the nearest neighbors for the subset.
+        """
+        n_neighbors = self.nn_.n_neighbors
+        if additional:
+            n_neighbors += self.additional_neighbors
+
+        nn_indices = self.nn_.kneighbors(
+            X, n_neighbors, return_distance=False)[:, 1:]
+        return nn_indices
+
+    def _knn_correct(self, X, y, additional=False):
+        """Apply KNN to classify samples.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            The feature samples to classify.
+
+        y : ndarray, shape (n_samples,)
+            The label samples to classify.
+
+        additional : bool, optional (default=False)
+            Flag to indicate whether to increase ``n_neighbors`` by
+            ``additional_neighbors``.
+
+        Returns
+        -------
+        is_correct : ndarray[bool], shape (n_samples,)
+            Mask that indicates if KNN classified samples correctly.
+        """
+        if not X.size:
+            return np.empty(0, dtype=bool)
+
+        nn_indices = self._locate_neighbors(X, additional)
+        mode, _ = stats.mode(self._y[nn_indices], axis=1)
+        is_correct = (y == mode.ravel())
+        return is_correct
+
+    def _amplify(self, X, y, additional=False):
+        """In-place amplification of samples based on their neighborhood
+        counts of samples that are safe and belong to the other class(es).
+        Returns ``nn_indices`` for relabel usage.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            The feature samples to amplify.
+
+        y : ndarray, shape (n_samples,)
+            The label samples to amplify.
+
+        additional : bool, optional (default=False)
+            Flag to indicate whether to amplify with ``additional_neighbors``.
+
+        Returns
+        -------
+        nn_indices : ndarray, shape (n_samples, n_neighbors)
+            Indices of the nearest neighbors for the subset.
+        """
+        if not X.size:
+            return np.empty(0, dtype=int)
+
+        nn_indices = self._locate_neighbors(X, additional)
+        amplify_amounts = np.isin(
+            nn_indices, self._amplify_indices).sum(axis=1)
+
+        X_parts = []
+        y_parts = []
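+        # A sample is duplicated once per safe other-class sample in its
+        # neighborhood: group the samples by amplification amount and stack
+        # each group that many extra times.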
+ """ + if not X.size: + return np.empty(0, dtype=int) + + nn_indices = self._locate_neighbors(X, additional) + amplify_amounts = np.isin( + nn_indices, self._amplify_indices).sum(axis=1) + + X_parts = [] + y_parts = [] + for amount in filter(bool, np.unique(amplify_amounts)): + mask = safe_mask(X, amplify_amounts == amount) + X_part = X[mask] + y_part = y[mask] + X_parts.extend([X_part] * amount) + y_parts.extend([y_part] * amount) + + if sparse.issparse(X): + X_new = sparse.vstack(X_parts) + else: + X_new = np.vstack(X_parts) + y_new = np.hstack(y_parts) + + self._X_resampled.append(X_new) + self._y_resampled.append(y_new) + return nn_indices + + def _fit_resample(self, X, y): + self._validate_estimator() + + self._X_resampled = [] + self._y_resampled = [] + self._y = y.copy() + + self.nn_.fit(X) + is_safe = self._knn_correct(X, y) + + strategy = self.sampling_strategy_ + for class_sample in filter(strategy.get, strategy): + is_class = (y == class_sample) + self._amplify_indices = np.flatnonzero(~is_class & is_safe) + discard_indices = np.flatnonzero(~is_class & ~is_safe) + + class_noisy_indices = np.flatnonzero(is_class & ~is_safe) + X_class_noisy = _safe_indexing(X, class_noisy_indices) + y_class_noisy = y[class_noisy_indices] + + if self.kind_sel in ("weak", "relabel"): + nn_indices = self._amplify(X_class_noisy, y_class_noisy) + + if self.kind_sel == "relabel": + relabel_mask = np.isin(nn_indices, discard_indices) + relabel_indices = np.unique(nn_indices[relabel_mask]) + self._y[relabel_indices] = class_sample + discard_indices = np.setdiff1d( + discard_indices, relabel_indices) + + elif self.kind_sel == "strong": + class_safe_indices = np.flatnonzero(is_class & is_safe) + X_class_safe = _safe_indexing(X, class_safe_indices) + y_class_safe = y[class_safe_indices] + self._amplify(X_class_safe, y_class_safe) + + is_correct = self._knn_correct( + X_class_noisy, y_class_noisy, additional=True) + + X_correct = X_class_noisy[ + safe_mask(X_class_noisy, is_correct)] + y_correct = y_class_noisy[is_correct] + self._amplify(X_correct, y_correct) + + X_incorrect = X_class_noisy[ + safe_mask(X_class_noisy, ~is_correct)] + y_incorrect = y_class_noisy[~is_correct] + self._amplify(X_incorrect, y_incorrect, additional=True) + else: + raise NotImplementedError(self.kind_sel) + + discard_mask = np.ones_like(y, dtype=bool) + try: + discard_mask[discard_indices] = False + except UnboundLocalError: + pass + + X_resampled = self._X_resampled + y_resampled = self._y_resampled + + X_resampled.append(X[safe_mask(X, discard_mask)]) + y_resampled.append(self._y[discard_mask]) + + if sparse.issparse(X): + X_resampled = sparse.vstack(X_resampled, format=X.format) + else: + X_resampled = np.vstack(X_resampled) + y_resampled = np.hstack(y_resampled) + + del self._X_resampled, self._y_resampled, self._y + self._amplify_indices = None + return X_resampled, y_resampled diff --git a/imblearn/combine/_preprocess/base.py b/imblearn/combine/_preprocess/base.py new file mode 100644 index 000000000..fa0bfb92f --- /dev/null +++ b/imblearn/combine/_preprocess/base.py @@ -0,0 +1,41 @@ +"""Base class for the preprocess-sampling method.""" + +# Author: Matthew Eding +# License: MIT + +from ...base import BaseSampler + + +class BasePreprocessSampler(BaseSampler): + """Base class for preprocess-sampling algorithms. + + Warning: This class should not be used directly. Use the derive classes + instead. 
+ """ + _sampling_type = 'preprocess-sampling' + + _sampling_strategy_docstring = \ + """sampling_strategy : str, list or callable + Sampling information to sample the data set. + + - When ``str``, specify the class targeted by the resampling. Note the + the number of samples will not be equal in each. Possible choices + are: + + ``'minority'``: resample only the minority class; + + ``'not minority'``: resample all classes but the minority class; + + ``'not majority'``: resample all classes but the majority class; + + ``'all'``: resample all classes; + + ``'auto'``: equivalent to ``'not majority'``. + + - When ``list``, the list contains the classes targeted by the + resampling. + + - When callable, function taking ``y`` and returns a ``dict``. The keys + correspond to the targeted classes. The values correspond to the + desired number of samples for each class. + """.rstrip() diff --git a/imblearn/combine/tests/test_spider.py b/imblearn/combine/tests/test_spider.py new file mode 100644 index 000000000..0fcea0a6b --- /dev/null +++ b/imblearn/combine/tests/test_spider.py @@ -0,0 +1,328 @@ +"""Test the module SPIDER.""" +# Authors: Matthew Eding +# License: MIT + +import pytest +import numpy as np + +from sklearn.neighbors import NearestNeighbors +from sklearn.utils._testing import assert_allclose +from sklearn.utils._testing import assert_array_equal + +from imblearn.combine import SPIDER + + +RND_SEED = 0 +X = np.array( + [ + [-11.83, -6.81], + [-11.72, -2.34], + [-11.43, -5.85], + [-10.66, -4.33], + [-9.64, -7.05], + [-8.39, -4.41], + [-8.07, -5.66], + [-7.28, 0.91], + [-7.24, -2.41], + [-6.13, -4.81], + [-5.92, -6.81], + [-4., -1.81], + [-3.96, 2.67], + [-3.74, -7.31], + [-2.96, 4.69], + [-1.56, -2.33], + [-1.02, -4.57], + [0.46, 4.07], + [1.2, -1.53], + [1.32, 0.41], + [1.56, -5.19], + [2.52, 5.89], + [3.03, -4.15], + [4., -0.59], + [4.4, 2.07], + [4.41, -7.45], + [4.45, -4.12], + [5.13, -6.28], + [5.4, -5], + [6.26, 4.65], + [7.02, -6.22], + [7.5, -0.11], + [8.1, -2.05], + [8.42, 2.47], + [9.62, 3.87], + [10.54, -4.47], + [11.42, 0.01] + ] +) +y = np.array( + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, + 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0 + ] +) +R_TOL = 1e-4 + + +def test_spider_init(): + spider = SPIDER() + assert spider.n_neighbors == 3 + assert spider.additional_neighbors == 2 + assert spider.kind_sel == "weak" + assert spider.n_jobs is None + + +def test_spider_weak(): + weak = SPIDER(kind_sel="weak") + X_resampled, y_resampled = weak.fit_resample(X, y) + X_gt = np.array( + [ + [3.03, -4.15], + [-3.96, 2.67], + [-3.96, 2.67], + [-3.96, 2.67], + [-11.83, -6.81], + [-11.72, -2.34], + [-11.43, -5.85], + [-10.66, -4.33], + [-9.64, -7.05], + [-8.39, -4.41], + [-8.07, -5.66], + [-7.28, 0.91], + [-7.24, -2.41], + [-6.13, -4.81], + [-5.92, -6.81], + [-4., -1.81], + [-3.96, 2.67], + [-3.74, -7.31], + [-2.96, 4.69], + [-1.56, -2.33], + [-1.02, -4.57], + [0.46, 4.07], + [1.2, -1.53], + [1.32, 0.41], + [1.56, -5.19], + [3.03, -4.15], + [4., -0.59], + [4.4, 2.07], + [4.41, -7.45], + [5.13, -6.28], + [5.4, -5.], + [6.26, 4.65], + [7.02, -6.22], + [8.1, -2.05], + [8.42, 2.47], + [10.54, -4.47], + [11.42, 0.01] + ] + ) + y_gt = np.array( + [ + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 0, 0 + ] + ) + assert_allclose(X_resampled, X_gt, rtol=R_TOL) + assert_array_equal(y_resampled, y_gt) + + +def test_spider_relabel(): + relabel = SPIDER(kind_sel="relabel") + X_resampled, y_resampled = 
+    X_gt = np.array(
+        [
+            [3.03, -4.15],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-11.83, -6.81],
+            [-11.72, -2.34],
+            [-11.43, -5.85],
+            [-10.66, -4.33],
+            [-9.64, -7.05],
+            [-8.39, -4.41],
+            [-8.07, -5.66],
+            [-7.28, 0.91],
+            [-7.24, -2.41],
+            [-6.13, -4.81],
+            [-5.92, -6.81],
+            [-4., -1.81],
+            [-3.96, 2.67],
+            [-3.74, -7.31],
+            [-2.96, 4.69],
+            [-1.56, -2.33],
+            [-1.02, -4.57],
+            [0.46, 4.07],
+            [1.2, -1.53],
+            [1.32, 0.41],
+            [1.56, -5.19],
+            [3.03, -4.15],
+            [4., -0.59],
+            [4.4, 2.07],
+            [4.41, -7.45],
+            [5.13, -6.28],
+            [5.4, -5.],
+            [6.26, 4.65],
+            [7.02, -6.22],
+            [8.1, -2.05],
+            [8.42, 2.47],
+            [10.54, -4.47],
+            [11.42, 0.01]
+        ]
+    )
+    y_gt = np.array(
+        [
+            1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+            0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+            1, 0, 0
+        ]
+    )
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_spider_relabel():
+    relabel = SPIDER(kind_sel="relabel")
+    X_resampled, y_resampled = relabel.fit_resample(X, y)
+    X_gt = np.array(
+        [
+            [3.03, -4.15],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-11.83, -6.81],
+            [-11.72, -2.34],
+            [-11.43, -5.85],
+            [-10.66, -4.33],
+            [-9.64, -7.05],
+            [-8.39, -4.41],
+            [-8.07, -5.66],
+            [-7.28, 0.91],
+            [-7.24, -2.41],
+            [-6.13, -4.81],
+            [-5.92, -6.81],
+            [-4., -1.81],
+            [-3.96, 2.67],
+            [-3.74, -7.31],
+            [-2.96, 4.69],
+            [-1.56, -2.33],
+            [-1.02, -4.57],
+            [0.46, 4.07],
+            [1.2, -1.53],
+            [1.32, 0.41],
+            [1.56, -5.19],
+            [3.03, -4.15],
+            [4., -0.59],
+            [4.4, 2.07],
+            [4.41, -7.45],
+            [4.45, -4.12],
+            [5.13, -6.28],
+            [5.4, -5.],
+            [6.26, 4.65],
+            [7.02, -6.22],
+            [7.5, -0.11],
+            [8.1, -2.05],
+            [8.42, 2.47],
+            [9.62, 3.87],
+            [10.54, -4.47],
+            [11.42, 0.01]
+        ]
+    )
+    y_gt = np.array(
+        [
+            1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+            0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+            1, 0, 1, 1, 0, 0
+        ]
+    )
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_spider_strong():
+    strong = SPIDER(kind_sel="strong")
+    X_resampled, y_resampled = strong.fit_resample(X, y)
+    X_gt = np.array(
+        [
+            [1.2, -1.53],
+            [3.03, -4.15],
+            [8.42, 2.47],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-11.83, -6.81],
+            [-11.72, -2.34],
+            [-11.43, -5.85],
+            [-10.66, -4.33],
+            [-9.64, -7.05],
+            [-8.39, -4.41],
+            [-8.07, -5.66],
+            [-7.28, 0.91],
+            [-7.24, -2.41],
+            [-6.13, -4.81],
+            [-5.92, -6.81],
+            [-4., -1.81],
+            [-3.96, 2.67],
+            [-3.74, -7.31],
+            [-2.96, 4.69],
+            [-1.56, -2.33],
+            [-1.02, -4.57],
+            [0.46, 4.07],
+            [1.2, -1.53],
+            [1.32, 0.41],
+            [1.56, -5.19],
+            [3.03, -4.15],
+            [4., -0.59],
+            [4.4, 2.07],
+            [4.41, -7.45],
+            [5.13, -6.28],
+            [5.4, -5.],
+            [6.26, 4.65],
+            [7.02, -6.22],
+            [8.1, -2.05],
+            [8.42, 2.47],
+            [10.54, -4.47],
+            [11.42, 0.01]
+        ]
+    )
+    y_gt = np.array(
+        [
+            1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+            0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1,
+            1, 1, 1, 0, 1, 0, 0
+        ]
+    )
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_spider_wrong_kind_sel():
+    spider = SPIDER(kind_sel="rand")
+    with pytest.raises(ValueError, match='The possible "kind" of algorithm'):
+        spider.fit_resample(X, y)
+
+
+def test_spider_fit_resample_with_nn_object():
+    nn = NearestNeighbors(n_neighbors=4)
+    spider = SPIDER(n_neighbors=nn)
+    X_resampled, y_resampled = spider.fit_resample(X, y)
+    X_gt = np.array(
+        [
+            [3.03, -4.15],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-3.96, 2.67],
+            [-11.83, -6.81],
+            [-11.72, -2.34],
+            [-11.43, -5.85],
+            [-10.66, -4.33],
+            [-9.64, -7.05],
+            [-8.39, -4.41],
+            [-8.07, -5.66],
+            [-7.28, 0.91],
+            [-7.24, -2.41],
+            [-6.13, -4.81],
+            [-5.92, -6.81],
+            [-4., -1.81],
+            [-3.96, 2.67],
+            [-3.74, -7.31],
+            [-2.96, 4.69],
+            [-1.56, -2.33],
+            [-1.02, -4.57],
+            [0.46, 4.07],
+            [1.2, -1.53],
+            [1.32, 0.41],
+            [1.56, -5.19],
+            [3.03, -4.15],
+            [4., -0.59],
+            [4.4, 2.07],
+            [4.41, -7.45],
+            [5.13, -6.28],
+            [5.4, -5.],
+            [6.26, 4.65],
+            [7.02, -6.22],
+            [8.1, -2.05],
+            [8.42, 2.47],
+            [10.54, -4.47],
+            [11.42, 0.01]
+        ]
+    )
+    y_gt = np.array(
+        [
+            1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
+            0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0,
+            1, 0, 0
+        ]
+    )
+    assert_allclose(X_resampled, X_gt, rtol=R_TOL)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_spider_not_good_object():
+    nn = "rand"
+    spider = SPIDER(n_neighbors=nn)
+    with pytest.raises(ValueError, match="has to be one of"):
+        spider.fit_resample(X, y)
+
+
+@pytest.mark.parametrize(
"add_neigh, err_type, err_msg", + [ + (0, ValueError, "additional_neighbors must be at least 1"), + (0.0, TypeError, "additional_neighbors must be an integer"), + (2.0, TypeError, "additional_neighbors must be an integer"), + ("2", TypeError, "additional_neighbors must be an integer"), + (2 + 0j, TypeError, "additional_neighbors must be an integer"), + ], +) +def test_spider_invalid_additional_neighbors(add_neigh, err_type, err_msg): + spider = SPIDER(additional_neighbors=add_neigh) + with pytest.raises(err_type, match=err_msg): + spider.fit_resample(X, y) diff --git a/imblearn/utils/_validation.py b/imblearn/utils/_validation.py index 8cb505f50..22749c1c0 100644 --- a/imblearn/utils/_validation.py +++ b/imblearn/utils/_validation.py @@ -22,6 +22,7 @@ "clean-sampling", "ensemble", "bypass", + "preprocess-sampling", ) TARGET_KIND = ("binary", "multiclass", "multilabel-indicator") @@ -103,15 +104,13 @@ def check_target_type(y, indicate_one_vs_all=False): def _sampling_strategy_all(y, sampling_type): """Returns sampling target by targeting all classes.""" target_stats = _count_class_sample(y) - if sampling_type == "over-sampling": + if sampling_type in ("over-sampling", "preprocess-sampling"): n_sample_majority = max(target_stats.values()) sampling_strategy = { key: n_sample_majority - value for (key, value) in target_stats.items() } - elif ( - sampling_type == "under-sampling" or sampling_type == "clean-sampling" - ): + elif sampling_type in ("under-sampling", "clean-sampling"): n_sample_minority = min(target_stats.values()) sampling_strategy = { key: n_sample_minority for key in target_stats.keys() @@ -124,14 +123,12 @@ def _sampling_strategy_all(y, sampling_type): def _sampling_strategy_majority(y, sampling_type): """Returns sampling target by targeting the majority class only.""" - if sampling_type == "over-sampling": + if sampling_type in ("over-sampling", "preprocess-sampling"): raise ValueError( "'sampling_strategy'='majority' cannot be used with" - " over-sampler." + " over-sampler or preprocess-sampler." 
         )
-    elif (
-        sampling_type == "under-sampling" or sampling_type == "clean-sampling"
-    ):
+    elif sampling_type in ("under-sampling", "clean-sampling"):
         target_stats = _count_class_sample(y)
         class_majority = max(target_stats, key=target_stats.get)
         n_sample_minority = min(target_stats.values())
@@ -150,7 +147,7 @@ def _sampling_strategy_not_majority(y, sampling_type):
     """Returns sampling target by targeting all classes but not the
     majority."""
     target_stats = _count_class_sample(y)
-    if sampling_type == "over-sampling":
+    if sampling_type in ("over-sampling", "preprocess-sampling"):
         n_sample_majority = max(target_stats.values())
         class_majority = max(target_stats, key=target_stats.get)
         sampling_strategy = {
@@ -158,9 +155,7 @@
             for (key, value) in target_stats.items()
             if key != class_majority
         }
-    elif (
-        sampling_type == "under-sampling" or sampling_type == "clean-sampling"
-    ):
+    elif sampling_type in ("under-sampling", "clean-sampling"):
         n_sample_minority = min(target_stats.values())
         class_majority = max(target_stats, key=target_stats.get)
         sampling_strategy = {
@@ -178,7 +173,7 @@ def _sampling_strategy_not_minority(y, sampling_type):
     """Returns sampling target by targeting all classes but not the
     minority."""
     target_stats = _count_class_sample(y)
-    if sampling_type == "over-sampling":
+    if sampling_type in ("over-sampling", "preprocess-sampling"):
         n_sample_majority = max(target_stats.values())
         class_minority = min(target_stats, key=target_stats.get)
         sampling_strategy = {
@@ -186,9 +181,7 @@
             for (key, value) in target_stats.items()
             if key != class_minority
         }
-    elif (
-        sampling_type == "under-sampling" or sampling_type == "clean-sampling"
-    ):
+    elif sampling_type in ("under-sampling", "clean-sampling"):
         n_sample_minority = min(target_stats.values())
         class_minority = min(target_stats, key=target_stats.get)
         sampling_strategy = {
@@ -205,7 +198,7 @@
 def _sampling_strategy_minority(y, sampling_type):
     """Returns sampling target by targeting the minority class only."""
     target_stats = _count_class_sample(y)
-    if sampling_type == "over-sampling":
+    if sampling_type in ("over-sampling", "preprocess-sampling"):
         n_sample_majority = max(target_stats.values())
         class_minority = min(target_stats, key=target_stats.get)
         sampling_strategy = {
@@ -213,9 +206,7 @@
             for (key, value) in target_stats.items()
             if key == class_minority
         }
-    elif (
-        sampling_type == "under-sampling" or sampling_type == "clean-sampling"
-    ):
+    elif sampling_type in ("under-sampling", "clean-sampling"):
         raise ValueError(
             "'sampling_strategy'='minority' cannot be used with"
             " under-sampler and clean-sampler."
@@ -229,11 +220,9 @@ def _sampling_strategy_minority(y, sampling_type):
 
 
 def _sampling_strategy_auto(y, sampling_type):
     """Returns sampling target auto for over-sampling and not-minority for
     under-sampling."""
-    if sampling_type == "over-sampling":
+    if sampling_type in ("over-sampling", "preprocess-sampling"):
         return _sampling_strategy_not_majority(y, sampling_type)
-    elif (
-        sampling_type == "under-sampling" or sampling_type == "clean-sampling"
-    ):
+    elif sampling_type in ("under-sampling", "clean-sampling"):
         return _sampling_strategy_not_minority(y, sampling_type)
@@ -301,11 +290,11 @@
                 )
             )
             sampling_strategy_[class_sample] = n_samples
-    elif sampling_type == "clean-sampling":
+    elif sampling_type in ("clean-sampling", "preprocess-sampling"):
         raise ValueError(
-            "'sampling_strategy' as a dict for cleaning methods is "
-            "not supported. Please give a list of the classes to be "
-            "targeted by the sampling."
+            "'sampling_strategy' as a dict for cleaning or preprocess "
+            "methods is not supported. Please give a list of the classes "
+            "to be targeted by the sampling."
         )
     else:
         raise NotImplementedError
@@ -316,10 +305,10 @@
 def _sampling_strategy_list(sampling_strategy, y, sampling_type):
     """With cleaning methods, sampling_strategy can be a list to target the
     class of interest."""
-    if sampling_type != "clean-sampling":
+    if sampling_type not in ("clean-sampling", "preprocess-sampling"):
         raise ValueError(
             "'sampling_strategy' cannot be a list for samplers "
-            "which are not cleaning methods."
+            "which are not cleaning or preprocess methods."
         )
 
     target_stats = _count_class_sample(y)
@@ -385,8 +374,8 @@
         )
     else:
         raise ValueError(
-            "'clean-sampling' methods do let the user "
-            "specify the sampling ratio."
+            "'clean-sampling' and 'preprocess-sampling' methods do not let "
+            "the user specify the sampling ratio."
         )
     return sampling_strategy_
@@ -420,12 +409,13 @@ def check_sampling_strategy(sampling_strategy, y, sampling_type, **kwargs):
 
         .. warning::
            ``float`` is only available for **binary** classification. An
           error is raised for multi-class classification and with cleaning
-           samplers.
+           or preprocessing samplers.
 
     - When ``str``, specify the class targeted by the resampling. For
       **under- and over-sampling methods**, the number of samples in the
-      different classes will be equalized. For **cleaning methods**, the
-      number of samples will not be equal. Possible choices are:
+      different classes will be equalized. For **cleaning and
+      preprocessing methods**, the number of samples will not be equal.
+      Possible choices are:
 
         ``'minority'``: resample only the minority class;
@@ -438,8 +428,8 @@
         ``'all'``: resample all classes;
 
         ``'auto'``: for under-sampling methods, equivalent to ``'not
-        minority'`` and for over-sampling methods, equivalent to ``'not
-        majority'``.
+        minority'`` and for preprocessing and over-sampling methods,
+        equivalent to ``'not majority'``.
 
     - When ``dict``, the keys correspond to the targeted classes. The
       values correspond to the desired number of samples for each targeted
@@ -448,14 +438,14 @@
 
         .. warning::
            ``dict`` is available for both **under- and over-sampling
           methods**. An error is raised with **cleaning methods**. Use a
-           ``list`` instead.
+           ``list`` instead. An error is raised with **preprocess methods**.
 
     - When ``list``, the list contains the targeted classes. It used only
       for **cleaning methods**.
 
        .. warning::
-           ``list`` is available for **cleaning methods**. An error is raised
-           with **under- and over-sampling methods**.
+           ``list`` is available for **cleaning and preprocess methods**. An
+           error is raised with **under- and over-sampling methods**.
 
     - When callable, function taking ``y`` and returns a ``dict``. The keys
       correspond to the targeted classes. The values correspond to the
@@ -466,7 +456,8 @@
 
     sampling_type : str,
         The type of sampling. Can be either ``'over-sampling'``,
-        ``'under-sampling'``, or ``'clean-sampling'``.
+        ``'under-sampling'``, ``'clean-sampling'``, or
+        ``'preprocess-sampling'``.
 
     kwargs : dict, optional
         Dictionary of additional keyword arguments to pass to
diff --git a/imblearn/utils/tests/test_validation.py b/imblearn/utils/tests/test_validation.py
index 634f502f0..7bc52304c 100644
--- a/imblearn/utils/tests/test_validation.py
+++ b/imblearn/utils/tests/test_validation.py
@@ -68,11 +68,14 @@ def test_check_target_type_ova(target, output_target, is_ova):
     assert binarize_target == is_ova
 
 
-def test_check_sampling_strategy_warning():
-    msg = "dict for cleaning methods is not supported"
+@pytest.mark.parametrize(
+    "sampling_method", ["clean-sampling", "preprocess-sampling"]
+)
+def test_check_sampling_strategy_warning(sampling_method):
+    msg = "dict for cleaning or preprocess methods is not supported"
     with pytest.raises(ValueError, match=msg):
         check_sampling_strategy(
-            {1: 0, 2: 0, 3: 0}, multiclass_target, "clean-sampling"
+            {1: 0, 2: 0, 3: 0}, multiclass_target, sampling_method
         )
 
 
@@ -83,7 +86,13 @@
         (
             0.5,
             binary_target,
             "clean-sampling",
-            "'clean-sampling' methods do let the user specify the sampling ratio",  # noqa
+            "sampling' methods do not let the user specify the sampling ratio",
+        ),
+        (
+            0.5,
+            binary_target,
+            "preprocess-sampling",
+            "sampling' methods do not let the user specify the sampling ratio",
         ),
         (
             0.1,
@@ -122,6 +131,7 @@
     [
         ("majority", "over-sampling", "over-sampler"),
         ("minority", "under-sampling", "under-sampler"),
+        ("minority", "clean-sampling", "under-sampler"),
     ],
 )
 def test_check_sampling_strategy_error_wrong_string(