Skip to content

Commit 2a1df61

Browse files
authored
add hmean (#244)
1 parent a1a2c37 commit 2a1df61

File tree

1 file changed

+139
-18
lines changed

1 file changed

+139
-18
lines changed

source/mir/math/stat.d

+139-18
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import core.lifetime: move;
2222
import mir.math.common: fmamath;
2323
import mir.math.sum;
2424
import mir.primitives;
25-
import std.range.primitives: isInputRange;
2625
import std.traits: isArray, isFloatingPoint, isMutable, isIterable;
2726

2827
/++
@@ -73,7 +72,7 @@ struct MeanAccumulator(T, Summation summation)
7372
version(mir_test)
7473
@safe pure nothrow unittest
7574
{
76-
import mir.ndslice.slice : sliced;
75+
import mir.ndslice.slice: sliced;
7776

7877
MeanAccumulator!(double, Summation.pairwise) x;
7978
x.put([0.0, 1, 2, 3, 4].sliced);
@@ -85,7 +84,7 @@ version(mir_test)
8584
version(mir_test)
8685
@safe pure nothrow unittest
8786
{
88-
import mir.ndslice.slice : sliced;
87+
import mir.ndslice.slice: sliced;
8988

9089
MeanAccumulator!(float, Summation.pairwise) x;
9190
x.put([0, 1, 2, 3, 4].sliced);
@@ -106,10 +105,10 @@ template mean(F, Summation summation = Summation.appropriate)
106105
Params:
107106
r = range
108107
+/
109-
F mean(Range)(Range r)
108+
@fmamath F mean(Range)(Range r)
110109
if (isIterable!Range)
111110
{
112-
MeanAccumulator!(F, ResolveSummationType!(summation, Range, sumType!Range)) mean;
111+
MeanAccumulator!(F, ResolveSummationType!(summation, Range, F)) mean;
113112
mean.put(r.move);
114113
return mean.mean;
115114
}
@@ -122,7 +121,7 @@ template mean(Summation summation = Summation.appropriate)
122121
Params:
123122
r = range
124123
+/
125-
sumType!Range mean(Range)(Range r)
124+
@fmamath sumType!Range mean(Range)(Range r)
126125
if (isIterable!Range)
127126
{
128127
return .mean!(sumType!Range, summation)(r.move);
@@ -145,7 +144,7 @@ template mean(string summation)
145144
version(mir_test)
146145
@safe pure nothrow unittest
147146
{
148-
import mir.ndslice.slice : sliced;
147+
import mir.ndslice.slice: sliced;
149148

150149
assert(mean([1.0, 2, 3]) == 2);
151150
assert(mean([1.0 + 3i, 2, 3]) == 2 + 1i);
@@ -160,7 +159,7 @@ version(mir_test)
160159
@safe @nogc pure nothrow
161160
unittest
162161
{
163-
import mir.ndslice.slice : sliced;
162+
import mir.ndslice.slice: sliced;
164163

165164
static immutable x = [0.0, 1.0, 1.5, 2.0, 3.5, 4.25,
166165
2.0, 7.5, 5.0, 1.0, 1.5, 0.0];
@@ -172,7 +171,7 @@ version(mir_test)
172171
@safe @nogc pure nothrow
173172
unittest
174173
{
175-
import mir.ndslice.slice : sliced;
174+
import mir.ndslice.slice: sliced;
176175

177176
static immutable x = [0.0, 1.0, 1.5, 2.0, 3.5, 4.25,
178177
2.0, 7.5, 5.0, 1.0, 1.5, 0.0];
@@ -184,8 +183,8 @@ version(mir_test)
184183
@safe @nogc pure nothrow
185184
unittest
186185
{
187-
import mir.ndslice.slice : sliced;
188-
import mir.ndslice.topology : alongDim, byDim, map;
186+
import mir.ndslice.slice: sliced;
187+
import mir.ndslice.topology: alongDim, byDim, map;
189188

190189
static immutable x = [0.0, 1.0, 1.5, 2.0, 3.5, 4.25,
191190
2.0, 7.5, 5.0, 1.0, 1.5, 0.0];
@@ -204,6 +203,7 @@ unittest
204203
/// Can also set algorithm or output type
205204
version(mir_test)
206205
@safe @nogc pure nothrow
206+
207207
unittest
208208
{
209209
import mir.ndslice.slice: sliced;
@@ -231,8 +231,8 @@ version(mir_test)
231231
@safe @nogc pure nothrow
232232
unittest
233233
{
234-
import mir.ndslice.slice : sliced;
235-
import std.math : approxEqual;
234+
import mir.ndslice.slice: sliced;
235+
import mir.math.common: approxEqual;
236236

237237
static immutable x = [0, 1, 1, 2, 4, 4,
238238
2, 7, 5, 1, 2, 0];
@@ -244,7 +244,7 @@ version(mir_test)
244244
@safe @nogc pure nothrow
245245
unittest
246246
{
247-
import mir.ndslice.slice : sliced;
247+
import mir.ndslice.slice: sliced;
248248

249249
static immutable cdouble[] x = [1.0 + 2i, 2 + 3i, 3 + 4i, 4 + 5i];
250250
static immutable cdouble result = 2.5 + 3.5i;
@@ -256,7 +256,7 @@ version(mir_test)
256256
@safe @nogc pure nothrow
257257
unittest
258258
{
259-
import mir.ndslice : alongDim, iota, as, map;
259+
import mir.ndslice: alongDim, iota, as, map;
260260
/*
261261
[[0,1,2],
262262
[3,4,5]]
@@ -280,8 +280,8 @@ version(mir_test)
280280
@safe @nogc pure nothrow
281281
unittest
282282
{
283-
import mir.ndslice.slice : sliced;
284-
import mir.ndslice.topology : alongDim, byDim, map;
283+
import mir.ndslice.slice: sliced;
284+
import mir.ndslice.topology: alongDim, byDim, map;
285285

286286
static immutable x = [0.0, 1.00, 1.50, 2.0, 3.5, 4.250,
287287
2.0, 7.50, 5.00, 1.0, 1.5, 0.000];
@@ -298,12 +298,133 @@ version(mir_test)
298298
assert([1.0, 2, 3, 4].mean == 2.5);
299299
}
300300

301+
/++
302+
Computes the harmonic mean of a range.
303+
304+
See_also: $(SUBREF sum, sum)
305+
+/
306+
template hmean(F, Summation summation = Summation.appropriate)
307+
{
308+
/++
309+
Params:
310+
r = range
311+
Returns:
312+
harmonic mean of the range
313+
+/
314+
@fmamath F hmean(Range)(Range r)
315+
if (isIterable!Range)
316+
{
317+
import mir.ndslice.topology: map;
318+
static if (summation == Summation.fast && __traits(compiles, r.move.map!"1.0 / a"))
319+
{
320+
return 1.0 / r.move.map!"1.0 / a".mean!(F, summation);
321+
}
322+
else
323+
{
324+
MeanAccumulator!(F, ResolveSummationType!(summation, Range, F)) imean;
325+
foreach (e; r)
326+
imean.put(1.0 / e);
327+
return 1.0 / imean.mean;
328+
}
329+
}
330+
}
331+
332+
/// ditto
333+
template hmean(Summation summation = Summation.appropriate)
334+
{
335+
/++
336+
Params:
337+
r = range
338+
Returns:
339+
harmonic mean of the range
340+
+/
341+
@fmamath sumType!Range hmean(Range)(Range r)
342+
if (isIterable!Range)
343+
{
344+
return .hmean!(typeof(1.0 / sumType!Range.init), summation)(r.move);
345+
}
346+
}
347+
348+
/// ditto
349+
template hmean(F, string summation)
350+
{
351+
mixin("alias hmean = .hmean!(F, Summation." ~ summation ~ ");");
352+
}
353+
354+
/// ditto
355+
template hmean(string summation)
356+
{
357+
mixin("alias hmean = .hmean!(Summation." ~ summation ~ ");");
358+
}
359+
360+
/// Harmonic mean of vector
361+
pure @safe nothrow @nogc
362+
unittest
363+
{
364+
import mir.math.common: approxEqual;
365+
366+
static immutable x = [20.0, 100.0, 2000.0, 10.0, 5.0, 2.0];
367+
368+
assert(x.hmean.approxEqual(6.97269));
369+
}
370+
371+
/// Harmonic mean of matrix
372+
pure @safe
373+
unittest
374+
{
375+
import mir.math.common: approxEqual;
376+
import mir.ndslice.fuse: fuse;
377+
378+
auto x = [[20.0, 100.0, 2000.0], [10.0, 5.0, 2.0]].fuse;
379+
380+
assert(x.hmean.approxEqual(6.97269));
381+
}
382+
383+
/// Column harmonic mean of matrix
384+
pure @safe
385+
unittest
386+
{
387+
import mir.algorithm.iteration: all;
388+
import mir.math.common: approxEqual;
389+
import mir.ndslice: fuse;
390+
import mir.ndslice.topology: alongDim, byDim, map;
391+
392+
auto x = [
393+
[20.0, 100.0, 2000.0],
394+
[ 10.0, 5.0, 2.0]
395+
].fuse;
396+
397+
auto y = [13.33333, 9.52381, 3.996004];
398+
399+
// Use byDim or alongDim with map to compute mean of row/column.
400+
assert(x.byDim!1.map!hmean.all!approxEqual(y));
401+
assert(x.alongDim!0.map!hmean.all!approxEqual(y));
402+
}
403+
404+
/// Can also pass arguments to hmean
405+
pure @safe
406+
unittest
407+
{
408+
import mir.ndslice.topology: map, repeat;
409+
import mir.math.common: approxEqual;
410+
411+
//Set sum algorithm or output type
412+
auto x = [1, 1e-100, 1, -1e-100];
413+
414+
assert(x.hmean!"kb2".approxEqual(2));
415+
assert(x.hmean!"precise".approxEqual(2));
416+
417+
//Provide the summation type
418+
assert(float.max.repeat(3).hmean!(double, "fast").approxEqual(float.max));
419+
}
420+
301421
/++
302422
A linear regression model with a single explanatory variable.
303423
+/
304424
template simpleLinearRegression(Summation summation = Summation.kbn)
305425
{
306426
import mir.ndslice.slice;
427+
import std.range.primitives: isInputRange;
307428

308429
/++
309430
Params:
@@ -323,7 +444,7 @@ template simpleLinearRegression(Summation summation = Summation.kbn)
323444
do {
324445
alias X = typeof(sumType!XRange.init * sumType!XRange.init);
325446
alias Y = sumType!YRange;
326-
enum summationX = !__traits(isIntegral, X) ? summation : Summation.naive;
447+
enum summationX = !__traits(isIntegral, X) ? summation: Summation.naive;
327448
Summator!(X, summationX) xms = 0;
328449
Summator!(Y, summation) yms = 0;
329450
Summator!(X, summationX) xxms = 0;

0 commit comments

Comments
 (0)