Skip to content

Commit f1d855c

Browse files
authored
Add center (and average) (#246)
* Add average and center * Rework center with alias * Fixup * Fixup2 * Fixup3 * Address some comments * Remove impl and refactor * Remove attributes from center * Remove artifact * Fix mistake * Replace static immutable * Remove excess import * Fixup UT
1 parent d29b10b commit f1d855c

File tree

1 file changed

+150
-3
lines changed

1 file changed

+150
-3
lines changed

source/mir/math/stat.d

+150-3
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,7 @@ template hmean(string summation)
360360
}
361361

362362
/// Harmonic mean of vector
363+
version(mir_test)
363364
pure @safe nothrow @nogc
364365
unittest
365366
{
@@ -371,6 +372,7 @@ unittest
371372
}
372373

373374
/// Harmonic mean of matrix
375+
version(mir_test)
374376
pure @safe
375377
unittest
376378
{
@@ -383,6 +385,7 @@ unittest
383385
}
384386

385387
/// Column harmonic mean of matrix
388+
version(mir_test)
386389
pure @safe
387390
unittest
388391
{
@@ -404,14 +407,15 @@ unittest
404407
}
405408

406409
/// Can also pass arguments to hmean
407-
pure @safe
410+
version(mir_test)
411+
pure @safe nothrow @nogc
408412
unittest
409413
{
410-
import mir.ndslice.topology: map, repeat;
414+
import mir.ndslice.topology: repeat;
411415
import mir.math.common: approxEqual;
412416

413417
//Set sum algorithm or output type
414-
auto x = [1, 1e-100, 1, -1e-100];
418+
static immutable x = [1, 1e-100, 1, -1e-100];
415419

416420
assert(x.hmean!"kb2".approxEqual(2));
417421
assert(x.hmean!"precise".approxEqual(2));
@@ -420,6 +424,149 @@ unittest
420424
assert(float.max.repeat(3).hmean!(double, "fast").approxEqual(float.max));
421425
}
422426

427+
/++
428+
Centers `slice`, which must be a finite iterable.
429+
430+
By default, `slice` is centered by the mean. A custom function may also be provided
431+
using `centralTendency`.
432+
433+
Returns:
434+
The elements in the slice with the average subtracted from them.
435+
+/
436+
template center(alias centralTendency = mean!(Summation.appropriate))
437+
{
438+
import mir.ndslice.slice: Slice, SliceKind, sliced, hasAsSlice;
439+
/++
440+
Params:
441+
slice = slice
442+
+/
443+
auto center(Iterator, size_t N, SliceKind kind)(
444+
Slice!(Iterator, N, kind) slice)
445+
{
446+
import core.lifetime: move;
447+
import mir.ndslice.topology: vmap;
448+
import mir.ndslice.internal: LeftOp, ImplicitlyUnqual;
449+
450+
auto m = centralTendency(slice.lightScope);
451+
alias T = typeof(m);
452+
return slice.move.vmap(LeftOp!("-", ImplicitlyUnqual!T)(m));
453+
}
454+
455+
/// ditto
456+
auto center(T)(T[] array)
457+
{
458+
return center(array.sliced);
459+
}
460+
461+
/// ditto
462+
auto center(T)(T withAsSlice)
463+
if (hasAsSlice!T)
464+
{
465+
return center(withAsSlice.asSlice);
466+
}
467+
}
468+
469+
/// Center vector
470+
version(mir_test)
471+
@safe pure nothrow
472+
unittest
473+
{
474+
import mir.ndslice.slice: sliced;
475+
import mir.algorithm.iteration: all;
476+
import mir.math.common: approxEqual;
477+
478+
auto x = [1.0, 2, 3, 4, 5, 6].sliced;
479+
assert(x.center.all!approxEqual([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]));
480+
481+
// Can center using different functions
482+
assert(x.center!hmean.all!approxEqual([-1.44898, -0.44898, 0.55102, 1.55102, 2.55102, 3.55102]));
483+
}
484+
485+
/// Center dynamic array
486+
version(mir_test)
487+
@safe pure nothrow
488+
unittest
489+
{
490+
import mir.algorithm.iteration: all;
491+
import mir.math.common: approxEqual;
492+
493+
auto x = [1.0, 2, 3, 4, 5, 6];
494+
assert(x.center.all!approxEqual([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5]));
495+
}
496+
497+
/// Center matrix
498+
version(mir_test)
499+
@safe pure
500+
unittest
501+
{
502+
import mir.ndslice: fuse;
503+
import mir.algorithm.iteration: all;
504+
import mir.math.common: approxEqual;
505+
506+
auto x = [
507+
[0.0, 1, 2],
508+
[3.0, 4, 5]
509+
].fuse;
510+
511+
auto y = [
512+
[-2.5, -1.5, -0.5],
513+
[ 0.5, 1.5, 2.5]
514+
].fuse;
515+
516+
assert(x.center.all!approxEqual(y));
517+
}
518+
519+
/// Column center matrix
520+
version(mir_test)
521+
pure @safe
522+
unittest
523+
{
524+
import mir.algorithm.iteration: all, equal;
525+
import mir.math.common: approxEqual;
526+
import mir.ndslice: fuse;
527+
import mir.ndslice.topology: alongDim, byDim, map;
528+
529+
auto x = [
530+
[20.0, 100.0, 2000.0],
531+
[10.0, 5.0, 2.0]
532+
].fuse;
533+
534+
auto result = [
535+
[ 5.0, 47.5, 999],
536+
[-5.0, -47.5, -999]
537+
].fuse;
538+
539+
// Use byDim with map to compute average of row/column.
540+
auto xCenterByDim = x.byDim!1.map!center;
541+
auto resultByDim = result.byDim!1;
542+
assert(xCenterByDim.equal!(equal!approxEqual)(resultByDim));
543+
544+
auto xCenterAlongDim = x.alongDim!0.map!center;
545+
auto resultAlongDim = result.alongDim!0;
546+
assert(xCenterByDim.equal!(equal!approxEqual)(resultAlongDim));
547+
}
548+
549+
/// Can also pass arguments to average function used by center
550+
version(mir_test)
551+
pure @safe nothrow
552+
unittest
553+
{
554+
import mir.ndslice.slice: sliced;
555+
import mir.algorithm.iteration: all;
556+
import mir.ndslice.topology: repeat;
557+
import mir.math.common: approxEqual;
558+
559+
//Set sum algorithm or output type
560+
auto a = [1, 1e100, 1, -1e100];
561+
562+
auto x = a.sliced * 10_000;
563+
auto result = [5000, 1e104 - 5000, 5000, -1e104 - 5000].sliced;
564+
565+
assert(x.center!(mean!"kbn").all!approxEqual(result));
566+
assert(x.center!(mean!"kb2").all!approxEqual(result));
567+
assert(x.center!(mean!"precise").all!approxEqual(result));
568+
}
569+
423570
/++
424571
A linear regression model with a single explanatory variable.
425572
+/

0 commit comments

Comments
 (0)