6ec6ef39ee64 — Nathan Michaels 4 years ago
Cartesian Product of arbitrary iterators.
1 files changed, 51 insertions(+), 39 deletions(-)

M src/iterators.zig
M src/iterators.zig +51 -39
@@ 182,65 182,75 @@ pub fn SliceIter(comptime T: type) type 
     };
 }
 
-pub fn Product(comptime T: type) type {
+/// Get the cartesian product of iterators. All iterators' next
+/// methods must return the same type, passed in as T. The number of
+/// iterators to process is passed as num.
+///
+/// Sample usage:
+/// var a = try Range(u32, 1).init(0, 2);
+/// var b = try Range(u32, 2).init(4, 8);
+/// var iterators = [_]*Iterator(u32){ &a.iterator, &b.iterator };
+/// var product = Product(u32, 2).init(iterators);
+/// var iter = &product.iterator;
+/// while (iter.next()) |vals| {
+///     // vals will have these values: {0, 4}, {1, 4}, {0, 6}, {1, 6}
+/// }
+pub fn Product(comptime T: type, comptime num: usize) type {
     return struct {
         const Self = @This();
         // This would be better as a tuple. Note for when we get tuple
-        // types.
-        const Result = struct { a: T, b: T };
+        // types. Then the iterators being multiplied don't need to
+        // all return the same type.
+        const Result = [num]T;
 
         iterator: Iterator(Result),
-        a: *Iterator(T),
-        b: *Iterator(T),
-        next_a: T,
-        next_b: T,
+        children: [num]*Iterator(T),
+        next: [num]T,
         done: bool = false,
 
         pub fn next(iterator: *Iterator(Result)) ?Result {
             const self = @fieldParentPtr(Self, "iterator", iterator);
             if (self.done)
                 return null;
-            const prev_a = self.next_a;
-            const prev_b = self.next_b;
-            if (self.a.next()) |aa| {
-                self.next_a = aa;
-            } else {
-                self.a.reset();
-                if (self.a.next()) |aa| {
-                    self.next_a = aa;
+            const prev: Result = self.next;
+            for (self.children) |child, idx| {
+                if (child.next()) |val| {
+                    self.next[idx] = val;
+                    self.done = false;
+                    break;
                 } else {
-                    // a iterator is empty.
-                    return null;
+                    child.reset();
+                    self.next[idx] = child.next().?;
                 }
-                if (self.b.next()) |bb| {
-                    self.next_b = bb;
-                } else {
-                    // b iterator has been exhausted; all done
-                    self.done = true;
-                }
+            } else {
+                self.done = true;
             }
-            return Result{ .a = prev_a, .b = prev_b };
+            return prev;
         }
 
         pub fn reset(iterator: *Iterator(Result)) void {
             const self = @fieldParentPtr(Self, "iterator", iterator);
-            self.a.reset();
-            self.b.reset();
+            for (self.children) |child| {
+                child.reset();
+            }
         }
 
-        pub fn init(a: *Iterator(T), b: *Iterator(T)) Self {
-            const prev_a = a.next().?;
-            const prev_b = b.next().?;
-            return Self{
-                .a = a,
-                .b = b,
-                .next_a = prev_a,
-                .next_b = prev_b,
+        /// Invokes safety-checked illegal behavior if any iterators
+        /// are empty.
+        pub fn init(args: [num]*Iterator(T)) Self {
+            var rv = Self{
                 .iterator = Iterator(Result){
                     .nextFn = next,
                     .resetFn = reset,
                 },
+                .children = args,
+                .next = undefined,
             };
+
+            for (args) |iter, idx| {
+                rv.next[idx] = iter.next().?;
+            }
+            return rv;
         }
     };
 }

          
@@ 311,21 321,23 @@ test "reverse" {
 test "product" {
     const R = Range(u32, 1);
     var a = try R.init(0, 10);
-    var b = try R.init(10, 20);
-    var product = Product(u32).init(&a.iterator, &b.iterator);
+    var b = try R.init(10, 25);
+    var iterators = [_]*Iterator(u32){ &a.iterator, &b.iterator };
+    var product = Product(u32, 2).init(iterators);
+
     var pi = &product.iterator;
 
     var correct_a: u32 = 0;
     var correct_b: u32 = 10;
     while (pi.next()) |ab| {
-        testing.expectEqual(ab.a, correct_a);
-        testing.expectEqual(ab.b, correct_b);
+        testing.expectEqual(ab[0], correct_a);
+        testing.expectEqual(ab[1], correct_b);
         correct_a += 1;
         if (correct_a == 10) {
             correct_a = 0;
             correct_b += 1;
         }
     }
-    testing.expectEqual(correct_b, 20);
+    testing.expectEqual(correct_b, 25);
     testing.expectEqual(pi.next(), null);
 }