forked from scikit-learn-contrib/imbalanced-learn
-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathtest_svm_smote.py
88 lines (74 loc) · 2.79 KB
/
test_svm_smote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import numpy as np
import pytest
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import NearestNeighbors
from sklearn.svm import SVC
from sklearn.utils._testing import assert_allclose, assert_array_equal
from imblearn.over_sampling import SVMSMOTE
@pytest.fixture
def data():
    """Provide a small, fixed two-class dataset shared by the tests.

    Returns
    -------
    tuple of (ndarray, ndarray)
        The feature matrix with 20 samples and 2 features, followed by the
        matching vector of 20 binary labels (8 samples labelled 0 and 12
        samples labelled 1).
    """
    # Hard-coded labels and coordinates keep the tests deterministic.
    labels = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0])
    features = np.array(
        [
            [0.11622591, -0.0317206],
            [0.77481731, 0.60935141],
            [1.25192108, -0.22367336],
            [0.53366841, -0.30312976],
            [1.52091956, -0.49283504],
            [-0.28162401, -2.10400981],
            [0.83680821, 1.72827342],
            [0.3084254, 0.33299982],
            [0.70472253, -0.73309052],
            [0.28893132, -0.38761769],
            [1.15514042, 0.0129463],
            [0.88407872, 0.35454207],
            [1.31301027, -0.92648734],
            [-1.11515198, -0.93689695],
            [-0.18410027, -0.45194484],
            [0.9281014, 0.53085498],
            [-0.14374509, 0.27370049],
            [-0.41635887, -0.38299653],
            [0.08711622, 0.93259929],
            [1.70580611, -0.11219234],
        ]
    )
    return features, labels
def test_svm_smote(data):
    """Check that a default sampler and an explicitly configured one agree.

    One `SVMSMOTE` is built with only `random_state` set; the other also
    receives explicit `NearestNeighbors` instances for `k_neighbors` and
    `m_neighbors` and an explicit `SVC` for `svm_estimator`. Both are expected
    to return the same resampled data.
    """
    X, y = data

    default_sampler = SVMSMOTE(random_state=42)
    explicit_sampler = SVMSMOTE(
        random_state=42,
        k_neighbors=NearestNeighbors(n_neighbors=6),
        m_neighbors=NearestNeighbors(n_neighbors=11),
        svm_estimator=SVC(gamma="scale", random_state=42),
    )

    X_default, y_default = default_sampler.fit_resample(X, y)
    X_explicit, y_explicit = explicit_sampler.fit_resample(X, y)

    # Features are compared with a float tolerance, labels exactly.
    assert_allclose(X_default, X_explicit)
    assert_array_equal(y_default, y_explicit)
def test_svm_smote_not_svm(data):
    """Check the error raised for an estimator without a `support_` attribute.

    A `LogisticRegression` is passed as `svm_estimator`; resampling is expected
    to fail with a `RuntimeError` carrying the message matched below.
    """
    X, y = data
    # The wording must stay in sync with the message raised by the sampler,
    # so it is kept verbatim.
    expected_message = (
        "`svm_estimator` is required to exposed a `support_` fitted attribute."
    )
    with pytest.raises(RuntimeError, match=expected_message):
        sampler = SVMSMOTE(svm_estimator=LogisticRegression())
        sampler.fit_resample(X, y)
def test_svm_smote_all_noise(data):
    """Check the error message raised when all support vectors are detected
    as noise, leaving the sampler with nothing it can do.

    Non-regression test for:
    https://github.com/scikit-learn-contrib/imbalanced-learn/issues/742
    """
    # NOTE: the `data` fixture is requested but not used here; this test
    # generates its own three-class dataset instead.
    dataset_params = dict(
        n_classes=3,
        class_sep=0.001,
        weights=[0.004, 0.451, 0.545],
        n_informative=3,
        n_redundant=0,
        flip_y=0,
        n_features=3,
        n_clusters_per_class=2,
        n_samples=1000,
        random_state=10,
    )
    features, labels = make_classification(**dataset_params)

    with pytest.raises(ValueError, match="SVM-SMOTE is not adapted to your dataset"):
        sampler = SVMSMOTE(k_neighbors=4, random_state=42)
        sampler.fit_resample(features, labels)