|
2 | 2 | import pandas as pd
|
3 | 3 | from pandas.testing import assert_frame_equal
|
4 | 4 |
|
5 |
| -from delphi_nwss.run import ( |
6 |
| - add_needed_columns, |
7 |
| - generate_weights, |
8 |
| - sum_all_nan, |
9 |
| - weighted_state_sum, |
10 |
| - weighted_nation_sum, |
11 |
| -) |
12 |
| - |
13 |
| - |
14 |
def test_sum_all_nan():
    """Check that sum_all_nan returns NaN iff everything is a NaN."""
    # No NaNs: behaves like a plain sum.
    assert sum_all_nan(np.array([3, 5])) == 8
    # Partial NaNs are ignored in the sum (also accepts a plain list).
    assert np.isclose(sum_all_nan([np.nan, 3, 5]), 8)
    # BUG FIX: the original asserted np.isnan(np.array([nan, nan])).all(),
    # which is trivially True and never exercised sum_all_nan at all.
    # Actually call the function under test on an all-NaN input.
    assert np.isnan(sum_all_nan(np.array([np.nan, np.nan])))
19 |
| - |
20 |
| - |
21 |
def test_weight_generation():
    """generate_weights adds relevant-population and weighted columns, in place."""
    df = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, np.nan],
            "b": [5, 6, 7, 8, 9],
            "population_served": [10, 5, 8, 1, 3],
        }
    )
    result = generate_weights(df, column_aggregating="a")
    # Hand-computed expectation: a NaN value contributes zero population
    # and a NaN weighted value.
    expected = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, np.nan],
            "b": [5, 6, 7, 8, 9],
            "population_served": [10, 5, 8, 1, 3],
            "relevant_pop_a": [10, 5, 8, 1, 0],
            "weighted_a": [1 * 10.0, 2 * 5.0, 3 * 8.0, 4 * 1.0, np.nan],
        }
    )
    assert_frame_equal(result, expected)
    # generate_weights mutates its argument, so the input frame matches too.
    assert_frame_equal(result, df)
42 |
| - |
43 |
| - |
44 |
def test_weighted_state_sum():
    """Population-weighted state aggregation: NaNs drop out, all-NaN states give NaN."""
    df = pd.DataFrame(
        {
            "state": ["al", "al", "ca", "ca", "nd", "me", "me"],
            "timestamp": np.zeros(7),
            "a": [1, 2, 3, 4, 12, -2, 2],
            "b": [5, 6, 7, np.nan, np.nan, -1, -2],
            "population_served": [10, 5, 8, 1, 3, 1, 2],
        }
    )

    # Column "b": nd has only NaN values, so its weighted sum is NaN and its
    # relevant population is 0.
    with_weights_b = generate_weights(df, column_aggregating="b")
    agg_b = weighted_state_sum(with_weights_b, "state", "b")
    expected_b = pd.DataFrame(
        {
            "timestamp": np.zeros(4),
            "geo_id": ["al", "ca", "me", "nd"],
            "relevant_pop_b": [15, 8, 3, 0],
            "weighted_b": [80.0, 56.0, -5.0, np.nan],
            "val": [80 / 15, 56 / 8, -5 / 3, np.nan],
        }
    )
    assert_frame_equal(agg_b, expected_b)

    # Column "a": fully populated, so every state has a finite value.
    with_weights_a = generate_weights(df, column_aggregating="a")
    agg_a = weighted_state_sum(with_weights_a, "state", "a")
    expected_a = pd.DataFrame(
        {
            "timestamp": np.zeros(4),
            "geo_id": ["al", "ca", "me", "nd"],
            "relevant_pop_a": [15, 9, 3, 3],
            "weighted_a": [20, 28, 2, 36],
            "val": [20 / 15, 28 / 9, 2 / 3, 36 / 3],
        }
    )
    assert_frame_equal(agg_a, expected_a)
79 |
| - |
80 |
| - |
81 |
def test_weighted_nation_sum():
    """National aggregation groups by timestamp and labels every row 'us'."""
    df = pd.DataFrame(
        {
            "state": ["al", "al", "ca", "ca", "nd"],
            # First three rows at t=0, last two at t=1.
            "timestamp": np.array([0.0, 0.0, 0.0, 1.0, 1.0]),
            "a": [1, 2, 3, 4, 12],
            "b": [5, 6, 7, np.nan, np.nan],
            "population_served": [10, 5, 8, 1, 3],
        }
    )
    with_weights = generate_weights(df, column_aggregating="a")
    agg = weighted_nation_sum(with_weights, "a")
    expected = pd.DataFrame(
        {
            "timestamp": [0.0, 1],
            "relevant_pop_a": [23, 4],
            "weighted_a": [44, 40],
            "val": [44 / 23, 40 / 4],
            "geo_id": ["us", "us"],
        }
    )
    assert_frame_equal(agg, expected)
| 5 | +from delphi_nwss.run import add_needed_columns |
109 | 6 |
|
110 | 7 |
|
111 | 8 | def test_adding_cols():
|
|
0 commit comments