Skip to content

Commit f493634

Browse files
committed
fix orthogonalReduceField and member topologies
1 parent 06ab5a5 commit f493634

File tree

2 files changed

+34
-28
lines changed

2 files changed

+34
-28
lines changed

source/mir/ndslice/field.d

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ unittest
534534
}
535535

536536
///
537-
struct OrthogonalReduceField(FieldsIterator, alias fun)
537+
struct OrthogonalReduceField(FieldsIterator, alias fun, T)
538538
{
539539
import mir.ndslice.slice: Slice;
540540

@@ -543,34 +543,36 @@ struct OrthogonalReduceField(FieldsIterator, alias fun)
543543

544544
Slice!FieldsIterator _fields;
545545

546+
///
547+
T _initialValue;
548+
546549
///
547550
auto lightConst()() const @property
548551
{
549552
auto fields = _fields.lightConst;
550-
return OrthogonalReduceField!(fields.Iterator, fun)(fields);
553+
return OrthogonalReduceField!(fields.Iterator, fun, T)(fields, _initialValue);
551554
}
552555

553556
///
554557
auto lightImmutable()() immutable @property
555558
{
556559
auto fields = _fields.lightImmutable;
557-
return OrthogonalReduceField!(fields.Iterator, fun)(fields);
560+
return OrthogonalReduceField!(fields.Iterator, fun, T)(fields, _initialValue);
558561
}
559562

560563
/// `r = fun(r, fields[i][index]);` reduction by `i`
561564
auto opIndex()(size_t index)
562565
{
563566
import std.traits: Unqual;
564567
assert(_fields.length);
565-
auto fields = _fields.lightConst;
566-
Unqual!(typeof(fun(fields.front[index], fields.front[index]))) r = fields.front[index];
567-
for(;;)
568+
auto fields = _fields;
569+
T r = _initialValue;
570+
if (!fields.empty) do
568571
{
572+
r = cast(T) fun(r, fields.front[index]);
569573
fields.popFront;
570-
if (fields.empty)
571-
break;
572-
r = fun(r, fields.front[index]);
573574
}
575+
while(!fields.empty);
574576
return r;
575577
}
576578
}

source/mir/ndslice/topology.d

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4142,6 +4142,19 @@ template member(string name)
41424142
{
41434143
return typeof(return)(slice._lengths, slice._strides, MemberIterator!(Iterator, name)(slice._iterator));
41444144
}
4145+
4146+
/// ditto
4147+
Slice!(MemberIterator!(T*, name)) member(T)(T[] array)
4148+
{
4149+
return typeof(return)(array.length, sizediff_t[0].init, MemberIterator!(T*, name)(array.ptr));
4150+
}
4151+
4152+
/// ditto
4153+
auto member(T)(auto ref T withAsSlice)
4154+
if (hasAsSlice!T)
4155+
{
4156+
return member!name(withAsSlice.asSlice);
4157+
}
41454158
}
41464159

41474160
///
@@ -4181,9 +4194,6 @@ template member(string name)
41814194
assert(matrix.member!"y" == [2, 3].iota * 2);
41824195
}
41834196

4184-
version(D_Exceptions)
4185-
private immutable orthogonalReduceFieldException = new Exception("orthogonalReduceField: Slice composed of fields must not be empty");
4186-
41874197
/++
41884198
Functional deep-element wise reduce of a slice composed of fields or iterators.
41894199
+/
@@ -4199,29 +4209,22 @@ template orthogonalReduceField(alias fun)
41994209
Returns:
42004210
a lazy field with each element of which is reduced value of element of the same index of all iterators.
42014211
+/
4202-
OrthogonalReduceField!(Iterator, fun) orthogonalReduceField(Iterator)(Slice!Iterator slice)
4212+
OrthogonalReduceField!(Iterator, fun, I) orthogonalReduceField(I, Iterator)(I initialValue, Slice!Iterator slice)
42034213
{
4204-
if (_expect(slice.empty, false))
4205-
{
4206-
version(D_Exceptions)
4207-
throw orthogonalReduceFieldException;
4208-
else
4209-
assert(0);
4210-
}
4211-
return typeof(return)(slice);
4214+
return typeof(return)(slice, initialValue);
42124215
}
42134216

42144217
/// ditto
4215-
auto orthogonalReduceField(T)(T[] array)
4218+
OrthogonalReduceField!(T*, fun, I) orthogonalReduceField(I, T)(I initialValue, T[] array)
42164219
{
4217-
return orthogonalReduceField(array.sliced);
4220+
return orthogonalReduceField(initialValue, array.sliced);
42184221
}
42194222

42204223
/// ditto
4221-
auto orthogonalReduceField(T)(auto ref T withAsSlice)
4224+
auto orthogonalReduceField(I, T)(I initialValue, auto ref T withAsSlice)
42224225
if (hasAsSlice!T)
42234226
{
4224-
return orthogonalReduceField(withAsSlice.asSlice);
4227+
return orthogonalReduceField(initialValue, withAsSlice.asSlice);
42254228
}
42264229
}
42274230
else alias orthogonalReduceField = .orthogonalReduceField!(naryFun!fun);
@@ -4243,11 +4246,12 @@ unittest
42434246
c[len.iota.strided!0(13)][] = true;
42444247

42454248
// this is valid since bitslices above are oroginal slices of allocated memory.
4246-
auto and = [
4249+
auto and =
4250+
orthogonalReduceField!"a & b"(size_t.max, [
42474251
a.iterator._field._field, // get raw data pointers
42484252
b.iterator._field._field,
4249-
c.iterator._field._field]
4250-
.orthogonalReduceField!"a & b" // operation on size_t
4253+
c.iterator._field._field,
4254+
]) // operation on size_t
42514255
.bitwiseField
42524256
.slicedField(len);
42534257

0 commit comments

Comments
 (0)