a6a2b0f0fb73 — Nathan Michaels 5 years ago
Fix the stream API.
3 files changed, 96 insertions(+), 72 deletions(-)

M .hgignore
M build.zig
M src/secretstream.zig
M .hgignore +2 -0
@@ 1,2 1,4 @@ 
 syntax: glob
 zig-cache
+*.a
+*.o

          
M build.zig +1 -1
@@ 1,4 1,5 @@ 
 const Builder = @import("std").build.Builder;
+const std = @import("std");
 
 pub fn build(b: *Builder) void {
     const mode = b.standardReleaseOptions();

          
@@ 6,7 7,6 @@ pub fn build(b: *Builder) void {
     lib.setBuildMode(mode);
     lib.install();
 
-
     var tests = b.addTest("src/sodium.zig");
     tests.setBuildMode(mode);
     tests.linkSystemLibrary("c");

          
M src/secretstream.zig +93 -71
@@ 30,7 30,7 @@ pub fn keygen(key: *[KEYBYTES]u8) void {
 
 /// Initialize state and header with key for writing. Call before
 /// trying to encrypt things.
-pub fn init_push(
+pub fn initPush(
     state: *StreamState,
     header: *[HEADERBYTES]u8,
     key: *const [KEYBYTES]u8,

          
@@ 49,7 49,7 @@ pub fn push(
     message: []const u8,
     additional_data: ?[]const u8,
     tag: Tag,
-) !void {
+) !c_ulonglong {
     var clen: c_ulonglong = undefined;
     // ciphertext length is guaranteed to always be mlen +
     // crypto_secretstream_xchacha20poly1305_ABYTES, so let's make

          
@@ 74,11 74,12 @@ pub fn push(
     if (res != 0) {
         return SodiumError.EncryptError;
     }
+    return clen;
 }
 
 /// Initialize state and header with key for reading. Call before
 /// trying to decrypt things.
-pub fn init_pull(
+pub fn initPull(
     state: *StreamState,
     header: *const [HEADERBYTES]u8,
     key: *const [KEYBYTES]u8,

          
@@ 97,7 98,7 @@ pub fn pull(
     tag: ?*Tag,
     ciphertext: []const u8,
     additional_data: ?[]const u8,
-) !void {
+) !c_ulonglong {
     if (message.len < ciphertext.len - ABYTES) {
         return SodiumError.BufferTooSmall;
     }

          
@@ 118,41 119,37 @@ pub fn pull(
     if (mlen > message.len) {
         return SodiumError.BufferTooSmall;
     }
+    return mlen;
 }
 
 /// Encrypt fixed-sized chunks of data, as many as you
 /// want. Initialize with init, then push until you run out of data to
-/// encrypt. Ciphertext will end up in a dynamically allocated
-/// ArrayList called data. Pass these (in order) along with the header
-/// stored in the hdr member through a ChunkDecrypter to get them
+/// encrypt. Ciphertext will end up in a stream passed to the
+/// initializer. Pass these to a ChunkDecrypter to get the plaintext
 /// back.
-///
-/// A better way to do this would be to have it accept an output
-/// stream that it can just dump stuff to. That's on the TODO list.
-pub fn ChunkEncrypter(chunk_size: usize) type {
+pub fn ChunkEncrypter(chunk_size: usize, comptime StreamType: type) type {
     return struct {
         const Self = @This();
         key: [KEYBYTES]u8,
         hdr: [HEADERBYTES]u8,
         state: StreamState,
-        allocator: *Allocator,
-        data: ArrayList([chunk_size + ABYTES]u8),
+        out: StreamType,
 
         /// Create a new encrypter. Push chunks to encrypt.
-        pub fn init(allocator: *Allocator, key: [KEYBYTES]u8) !Self {
+        pub fn init(out_stream: StreamType, key: [KEYBYTES]u8) !Self {
             var st = Self{
                 .key = key,
-                .allocator = allocator,
+                .out = out_stream,
                 .hdr = undefined,
                 .state = undefined,
-                .data = ArrayList([chunk_size + ABYTES]u8).init(allocator),
             };
-            try init_push(&st.state, &st.hdr, &key);
+            try initPush(&st.state, &st.hdr, &key);
+            try out_stream.writeAll(&st.hdr);
             return st;
         }
 
         /// Call with data to encrypt it.
-        pub fn push_chunk(self: *Self, msg: []const u8) !void {
+        pub fn pushChunk(self: *Self, msg: []const u8) !void {
             if (msg.len > chunk_size) {
                 return SodiumError.ChunkTooBig;
             }

          
@@ 163,61 160,76 @@ pub fn ChunkEncrypter(chunk_size: usize)
                 m[idx] = val;
             }
             var ctxt: [chunk_size + ABYTES]u8 = undefined;
-            try push(&self.state, ctxt[0..], m[0..], null, Tag.MESSAGE);
-            try self.data.append(ctxt);
+            const clen = try push(
+                &self.state,
+                ctxt[0..],
+                msg,
+                null,
+                Tag.MESSAGE,
+            );
+            try self.out.writeAll(ctxt[0..clen]);
         }
 
         /// Free up all held resources.
-        pub fn deinit(self: *Self) void {
-            self.data.deinit();
-        }
+        pub fn deinit(self: *Self) void {}
     };
 }
 
 /// ChunkEncrypter's buddy. Initialize with the same key and header
 /// from one, and use it to decrypt chunks.
-pub fn ChunkDecrypter(chunk_size: usize) type {
+pub fn ChunkDecrypter(
+    chunk_size: usize,
+    comptime InStreamType: type,
+    comptime OutStreamType: type,
+) type {
     return struct {
         const Self = @This();
         key: [KEYBYTES]u8,
         hdr: [HEADERBYTES]u8,
         state: StreamState,
-        allocator: *Allocator,
-        data: ArrayList([chunk_size]u8),
+        in_stream: InStreamType,
+        out_stream: OutStreamType,
 
         /// Create a new decrypter. Use the header as read from the
         /// beginning of the ciphertext.
         pub fn init(
-            allocator: *Allocator,
+            in_stream: InStreamType,
+            out_stream: OutStreamType,
             key: [KEYBYTES]u8,
-            hdr: [HEADERBYTES]u8,
         ) !Self {
             var st = Self{
+                .in_stream = in_stream,
+                .out_stream = out_stream,
                 .key = key,
-                .allocator = allocator,
-                .hdr = hdr,
+                .hdr = undefined,
                 .state = undefined,
-                .data = ArrayList([chunk_size]u8).init(allocator),
             };
-            try init_pull(&st.state, &st.hdr, &key);
+            try st.in_stream.readNoEof(&st.hdr);
+            try initPull(&st.state, &st.hdr, &key);
             return st;
         }
 
-        /// Decrypt chunks, one at a time. Put decrypted data in
-        /// self.data.
-        pub fn pull_chunk(
-            self: *Self,
-            ciphertext: [chunk_size + ABYTES]u8,
-        ) !void {
+        /// Decrypt chunks, one at a time. Returns null on the last
+        /// chunk.
+        pub fn pullChunk(self: *Self) !?void {
+            var ciphertext: [chunk_size + ABYTES]u8 = undefined;
             var msg: [chunk_size]u8 = undefined;
-            try pull(&self.state, msg[0..], null, ciphertext[0..], null);
-            try self.data.append(msg);
+            const clen = try self.in_stream.readAll(&ciphertext);
+            const mlen = try pull(
+                &self.state,
+                &msg,
+                null,
+                ciphertext[0..clen],
+                null,
+            );
+            try self.out_stream.writeAll(msg[0..mlen]);
+            if (clen < ciphertext.len) {
+                return null;
+            }
         }
 
         /// Free all held resources.
-        pub fn deinit(self: *Self) void {
-            self.data.deinit();
-        }
+        pub fn deinit(self: *Self) void {}
     };
 }
 

          
@@ 232,56 244,66 @@ test "stream" {
     var state: StreamState = undefined;
 
     keygen(&key);
-    try init_push(&state, &hdr, &key);
-    try push(&state, ciphertext[0..], msg[0..], null, Tag.MESSAGE);
-    try push(&state, ciphertext2[0..], msg2[0..], null, Tag.MESSAGE);
+    try initPush(&state, &hdr, &key);
+    const cl = try push(&state, ciphertext[0..], msg[0..], null, Tag.MESSAGE);
+    const c2 = try push(&state, ciphertext2[0..], msg2[0..], null, Tag.MESSAGE);
 
     var clear: [ciphertext.len - ABYTES]u8 = undefined;
     var clear2: [ciphertext2.len - ABYTES]u8 = undefined;
-    try init_pull(&state, &hdr, &key);
-    try pull(&state, clear[0..], null, ciphertext[0..], null);
-    try pull(&state, clear2[0..], null, ciphertext2[0..], null);
+    try initPull(&state, &hdr, &key);
+    const len = try pull(&state, clear[0..], null, ciphertext[0..cl], null);
+    testing.expectEqual(msg.len, len);
+    const len2 = try pull(&state, clear2[0..], null, ciphertext2[0..c2], null);
+    testing.expectEqual(msg2.len, len2);
     testing.expectEqualSlices(u8, msg[0..], clear[0..]);
     testing.expectEqualSlices(u8, msg2[0..], clear2[0..]);
 }
 
 test "chunks" {
     try sodium.init();
-    const msg = "This message is longer than my chunk size.";
+    const msg = "This message is longer than my chunk size!!";
     const malloc = std.heap.c_allocator;
     var key: [KEYBYTES]u8 = undefined;
 
     keygen(&key);
     const chunk_size = 4;
-    const Encrypter = ChunkEncrypter(chunk_size);
-    var encrypter = try Encrypter.init(malloc, key);
+    var buf: [chunk_size * 1024]u8 = undefined;
+    var stream = std.io.fixedBufferStream(&buf);
+    const out = stream.outStream();
+    const Encrypter = ChunkEncrypter(chunk_size, @TypeOf(out));
+    var encrypter = try Encrypter.init(out, key);
     defer encrypter.deinit();
     var start: usize = 0;
+    var chunk_count: usize = 0;
     while (start < msg.len) : (start += chunk_size) {
         const off = start + chunk_size;
         const end = if (off < msg.len) off else msg.len;
-        try encrypter.push_chunk(msg[start..end]);
-    }
-
-    const Decrypter = ChunkDecrypter(chunk_size);
-    var decrypter = try Decrypter.init(malloc, key, encrypter.hdr);
-    defer decrypter.deinit();
-
-    for (encrypter.data.items) |chunk| {
-        try decrypter.pull_chunk(chunk);
+        try encrypter.pushChunk(msg[start..end]);
+        chunk_count += 1;
     }
 
-    const decrypted = try malloc.alloc(u8, decrypter.data.items.len *
-        chunk_size);
-    defer malloc.free(decrypted);
+    const cipherlen = try stream.getPos();
+    var cipherstream = std.io.fixedBufferStream(buf[0..cipherlen]).inStream();
+    var decrypted: [chunk_size * 1024]u8 = undefined;
+    var cleartext = std.io.fixedBufferStream(&decrypted);
+    var clearstream = cleartext.outStream();
+    const Decrypter = ChunkDecrypter(
+        chunk_size,
+        @TypeOf(cipherstream),
+        @TypeOf(clearstream),
+    );
+    var decrypter = try Decrypter.init(cipherstream, clearstream, key);
+    defer decrypter.deinit();
 
-    for (decrypter.data.items) |chunk, off| {
-        const base: usize = off * chunk_size;
-        var idx: usize = 0;
-        while (idx < chunk_size) : (idx += 1) {
-            decrypted[idx + base] = chunk[idx];
-        }
+    while (chunk_count > 0) {
+        decrypter.pullChunk() catch |err| {
+            std.debug.warn("Unexpected error: {}\n", .{err});
+            return err;
+        } orelse break;
+        chunk_count -= 1;
     }
 
-    testing.expectEqualSlices(u8, msg[0..], decrypted[0..msg.len]);
+    const len = try cleartext.getPos();
+    testing.expectEqual(len, msg.len);
+    testing.expectEqualSlices(u8, msg[0..], decrypted[0..len]);
 }