Skip to content

Commit cdc31b8

Browse files
authored
add squeeze and unsqueeze (#235)
* add addInnerDim and addOuterDim * rework with numir like API
1 parent 60c330e commit cdc31b8

File tree

3 files changed

+172
-4
lines changed

3 files changed

+172
-4
lines changed

source/mir/algorithm/iteration.d

-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ import std.traits;
6060

6161
@optmath:
6262

63-
6463
/+
6564
Bitslice representation for accelerated bitwise algorithm.
6665
1-dimensional contiguousitslice can be split into three chunks: head bits, body chunks, and tail bits.

source/mir/ndslice/package.d

+2
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,13 @@ $(TR $(TDNW $(SUBMODULE topology) $(BR)
161161
$(SUBREF topology, ReshapeError)
162162
$(SUBREF topology, retro)
163163
$(SUBREF topology, slide)
164+
$(SUBREF topology, squeeze)
164165
$(SUBREF topology, stairs)
165166
$(SUBREF topology, stride)
166167
$(SUBREF topology, subSlices)
167168
$(SUBREF topology, triplets)
168169
$(SUBREF topology, universal)
170+
$(SUBREF topology, unsqueeze)
169171
$(SUBREF topology, unzip)
170172
$(SUBREF topology, windows)
171173
$(SUBREF topology, zip)

source/mir/ndslice/topology.d

+170-3
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ $(TR $(TH Function Name) $(TH Description))
6767
6868
$(T2 blocks, n-dimensional slice composed of n-dimensional non-overlapping blocks. If the slice has two dimensions, it is a block matrix.)
6969
$(T2 diagonal, 1-dimensional slice composed of diagonal elements)
70-
$(T2 reshape, New slice with changed dimensions for the same data)
70+
$(T2 reshape, New slice view with changed dimensions)
71+
$(T2 squeeze, New slice view of an n-dimensional slice with dimension removed)
72+
$(T2 unsqueeze, New slice view of an n-dimensional slice with a dimension added)
7173
$(T2 windows, n-dimensional slice of n-dimensional overlapping windows. If the slice has two dimensions, it is a sliding window.)
7274
7375
)
@@ -96,7 +98,7 @@ $(SUBREF slice, Slice.shape), and $(SUBREF slice, Slice.elementCount).
9698
9799
License: $(HTTP boost.org/LICENSE_1_0.txt, Boost License 1.0).
98100
Copyright: Copyright © 2016-, Ilya Yaroshenko
99-
Authors: Ilya Yaroshenko
101+
Authors: Ilya Yaroshenko, Shigeki Karita (original numir code)
100102
101103
Sponsors: Part of this work has been sponsored by $(LINK2 http://symmetryinvestments.com, Symmetry Investments) and Kaleidic Associates.
102104
@@ -3918,7 +3920,7 @@ template byDim(Dimensions...)
39183920
n-dimensional slice ipacked to allow iteration by dimension
39193921
+/
39203922
@optmath auto byDim(Iterator, size_t N, SliceKind kind)
3921-
(Slice!(Iterator, N, kind) slice)
3923+
(Slice!(Iterator, N, kind) slice)
39223924
{
39233925
import mir.ndslice.topology : ipack;
39243926
import mir.ndslice.internal : DimensionsCountCTError;
@@ -4240,6 +4242,171 @@ version(mir_test) unittest
42404242
assert(x == slice);
42414243
}
42424244

4245+
/++
4246+
Constructs a new view of an n-dimensional slice with dimension `axis` removed.
4247+
4248+
Throws:
4249+
`AssertError` if the length of the corresponding dimension doesn' equal 1.
4250+
Params:
4251+
axis = dimension to remove, if it is single-dimensional
4252+
slice = n-dimensional slice
4253+
Returns:
4254+
new view of a slice with dimension removed
4255+
See_also: $(LREF unsqueeze), $(LREF iota).
4256+
+/
4257+
Slice!(Iterator, N - 1, kind != Canonical ? kind : axis == 0 ? Universal : N == 2 ? Contiguous : kind)
4258+
squeeze(sizediff_t axis = 0, Iterator, size_t N, SliceKind kind)
4259+
(Slice!(Iterator, N, kind) slice)
4260+
if (-sizediff_t(N) <= axis && axis < sizediff_t(N) && N > 1)
4261+
in {
4262+
assert(slice._lengths[axis < 0 ? N + axis : axis] == 1);
4263+
}
4264+
do {
4265+
import mir.utility: swap;
4266+
enum sizediff_t a = axis < 0 ? N + axis : axis;
4267+
typeof(return) ret;
4268+
foreach (i; 0 .. a)
4269+
ret._lengths[i] = slice._lengths[i];
4270+
foreach (i; a + 1 .. N)
4271+
ret._lengths[i - 1] = slice._lengths[i];
4272+
static if (kind == Universal)
4273+
{
4274+
foreach (i; 0 .. a)
4275+
ret._strides[i] = slice._strides[i];
4276+
foreach (i; a + 1.. N)
4277+
ret._strides[i - 1] = slice._strides[i];
4278+
}
4279+
else
4280+
static if (kind == Canonical)
4281+
{
4282+
static if (a == 0)
4283+
{
4284+
foreach (i; 0 .. N - 1)
4285+
ret._strides[i] = slice._strides[i];
4286+
}
4287+
else
4288+
{
4289+
foreach (i; 0 .. a - 1)
4290+
ret._strides[i] = slice._strides[i];
4291+
foreach (i; a .. N - 1)
4292+
ret._strides[i - 1] = slice._strides[i];
4293+
}
4294+
}
4295+
swap(ret._iterator, slice._iterator);
4296+
return ret;
4297+
}
4298+
4299+
///
4300+
unittest
4301+
{
4302+
import mir.ndslice.topology : iota;
4303+
import mir.ndslice.allocation : slice;
4304+
4305+
// [[0, 1, 2]] -> [0, 1, 2]
4306+
assert([1, 3].iota.squeeze == [3].iota);
4307+
// [[0], [1], [2]] -> [0, 1, 2]
4308+
assert([3, 1].iota.squeeze!1 == [3].iota);
4309+
assert([3, 1].iota.squeeze!(-1) == [3].iota);
4310+
4311+
assert([1, 3].iota.canonical.squeeze == [3].iota);
4312+
assert([3, 1].iota.canonical.squeeze!1 == [3].iota);
4313+
assert([3, 1].iota.canonical.squeeze!(-1) == [3].iota);
4314+
4315+
assert([1, 3].iota.universal.squeeze == [3].iota);
4316+
assert([3, 1].iota.universal.squeeze!1 == [3].iota);
4317+
assert([3, 1].iota.universal.squeeze!(-1) == [3].iota);
4318+
4319+
assert([1, 3, 4].iota.squeeze == [3, 4].iota);
4320+
assert([3, 1, 4].iota.squeeze!1 == [3, 4].iota);
4321+
assert([3, 4, 1].iota.squeeze!(-1) == [3, 4].iota);
4322+
4323+
assert([1, 3, 4].iota.canonical.squeeze == [3, 4].iota);
4324+
assert([3, 1, 4].iota.canonical.squeeze!1 == [3, 4].iota);
4325+
assert([3, 4, 1].iota.canonical.squeeze!(-1) == [3, 4].iota);
4326+
4327+
assert([1, 3, 4].iota.universal.squeeze == [3, 4].iota);
4328+
assert([3, 1, 4].iota.universal.squeeze!1 == [3, 4].iota);
4329+
assert([3, 4, 1].iota.universal.squeeze!(-1) == [3, 4].iota);
4330+
}
4331+
4332+
/++
4333+
Constructs a view of an n-dimensional slice with a dimension added at `axis`. Used
4334+
to unsqueeze a squeezed slice.
4335+
4336+
Params:
4337+
slice = n-dimensional slice
4338+
axis = dimension to be unsqueezed (add new dimension), default values is 0, the first dimension
4339+
Returns:
4340+
unsqueezed n+1-dimensional slice of the same slice kind
4341+
See_also: $(LREF squeeze), $(LREF iota).
4342+
+/
4343+
Slice!(Iterator, N + 1, kind) unsqueeze(Iterator, size_t N, SliceKind kind)
4344+
(Slice!(Iterator, N, kind) slice, sizediff_t axis = 0)
4345+
in {
4346+
assert(-sizediff_t(N + 1) <= axis && axis <= sizediff_t(N));
4347+
}
4348+
do {
4349+
import mir.utility: swap;
4350+
typeof(return) ret;
4351+
if (axis < 0)
4352+
{
4353+
axis += N + 1;
4354+
}
4355+
foreach (i; 0 .. axis)
4356+
ret._lengths[i] = slice._lengths[i];
4357+
ret._lengths[axis] = 1;
4358+
foreach (i; axis .. N)
4359+
ret._lengths[i + 1] = slice._lengths[i];
4360+
static if (kind == Universal)
4361+
{
4362+
foreach (i; 0 .. axis)
4363+
ret._strides[i] = slice._strides[i];
4364+
foreach (i; axis .. N)
4365+
ret._strides[i + 1] = slice._strides[i];
4366+
}
4367+
else
4368+
static if (kind == Canonical)
4369+
{
4370+
if (axis == 0)
4371+
{
4372+
ret._strides[0] = 1;
4373+
foreach (i; 1 .. N)
4374+
ret._strides[i] = slice._strides[i - 1];
4375+
}
4376+
else
4377+
{
4378+
foreach (i; 1 .. axis)
4379+
ret._strides[i - 1] = slice._strides[i - 1];
4380+
foreach (i; axis .. N)
4381+
ret._strides[i + 0] = slice._strides[i - 1];
4382+
}
4383+
}
4384+
swap(ret._iterator, slice._iterator);
4385+
return ret;
4386+
}
4387+
4388+
///
4389+
version (mir_test)
4390+
@safe pure nothrow @nogc
4391+
unittest
4392+
{
4393+
// [0, 1, 2] -> [[0, 1, 2]]
4394+
assert([3].iota.unsqueeze == [1, 3].iota);
4395+
4396+
assert([3].iota.universal.unsqueeze == [1, 3].iota);
4397+
assert([3, 4].iota.unsqueeze == [1, 3, 4].iota);
4398+
assert([3, 4].iota.canonical.unsqueeze == [1, 3, 4].iota);
4399+
assert([3, 4].iota.universal.unsqueeze == [1, 3, 4].iota);
4400+
4401+
// [0, 1, 2] -> [[0], [1], [2]]
4402+
assert([3].iota.unsqueeze(-1) == [3, 1].iota);
4403+
4404+
assert([3].iota.universal.unsqueeze(-1) == [3, 1].iota);
4405+
assert([3, 4].iota.unsqueeze(-1) == [3, 4, 1].iota);
4406+
assert([3, 4].iota.canonical.unsqueeze(-1) == [3, 4, 1].iota);
4407+
assert([3, 4].iota.universal.unsqueeze(-1) == [3, 4, 1].iota);
4408+
}
4409+
42434410
/++
42444411
Field (element's member) projection.
42454412

0 commit comments

Comments
 (0)