feat: Decoding + tests + better error checking
2 files changed, 245 insertions(+), 21 deletions(-)

M src/baseenc.c
M src/baseenc.zig
M src/baseenc.c +22 -0
@@ 22,8 22,30 @@ static Janet Baseenc_base64Encode(int ar
   return janet_wrap_buffer(buf);
 }
 
+static Janet Baseenc_base64Decode(int argc, Janet argv[]) {
+  janet_arity(argc, 1, 2);
+  JanetByteView data = janet_getbytes(argv, 0);
+  JanetBuffer *buf;
+  if (argc > 1) {
+    buf = janet_getbuffer(argv, 1);
+  } else {
+    buf = janet_buffer((data.len + 3) / 4 * 3);
+  }
+
+  ssize_t len = base_64_decode(&data.bytes[0], (size_t)data.len,
+    &buf->data[0], (size_t)buf->capacity);
+  if (len == -1) {
+    janet_panic("buffer overflow");
+  } else if (len == -2) {
+    janet_panic("invalid base64 string");
+  }
+  buf->count = (int32_t)len;
+  return janet_wrap_buffer(buf);
+}
+
 static const JanetReg Baseenc_cfuns[] = {
   { "base64-encode", Baseenc_base64Encode, NULL },
+  { "base64-decode", Baseenc_base64Decode, NULL },
   { NULL, NULL, NULL },
 };
 

          
M src/baseenc.zig +223 -21
@@ 2,6 2,7 @@ const std = @import("std");
 
 const EncDecError = error{
     BufferOverflow,
+    InvalidEncoding,
 };
 
 const Base64 = struct {

          
@@ 10,37 11,159 @@ const Base64 = struct {
         "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ++
         "abcdefghijklmnopqrstuvwxyz" ++
         "0123456789+/";
+    const reverse_alphabet = blk: {
+        var rev: [256]u8 = .{255} ** 256;
+        for (alphabet) |l, i| {
+            rev[l] = i;
+        }
+        rev[pad] = 0;
+        break :blk rev;
+    };
+
+    inline fn encode_mask(out: *[4]u8, a: u8, b: u8, c: u8) void {
+        //const ai = @intCast(u32, a);
+        //const bi = @intCast(u32, b);
+        //const ci = @intCast(u32, c);
+        out[0] = alphabet[a >> 2];
+        out[1] = alphabet[(a << 4 & 0b110000) | (b >> 4 & 0b001111)];
+        out[2] = alphabet[(b << 2 & 0b111100) | (c >> 6 & 0b000011)];
+        out[3] = alphabet[c & 0b111111];
+    }
     fn encode(in: []const u8, out: []u8) EncDecError!usize {
+        // Pre-check sizes to avoid unnecessary branching.
+        const sym = @divTrunc(in.len + 2, 3) * 4;
+        if (out.len < sym) {
+            return error.BufferOverflow;
+        }
+
+        const odd = @mod(in.len, 3);
+
+        // Encode triplets as long as branching isn't needed.
+        const limit = in.len - odd;
         var i = @as(usize, 0);
         var j = @as(usize, 0);
-        while (i < in.len) {
-            if (j + 4 > out.len) return EncDecError.BufferOverflow;
-            const a = @intCast(u32, in[i]);
-            const b = if (i + 1 < in.len) @intCast(u32, in[i + 1]) else 0;
-            const c = if (i + 2 < in.len) @intCast(u32, in[i + 2]) else 0;
-            out[j] = alphabet[a >> 2];
-            out[j + 1] = alphabet[(a << 4 & 0b110000) | (b >> 4 & 0b001111)];
-            out[j + 2] = alphabet[(b << 2 & 0b111100) | (c >> 6 & 0b000011)];
-            out[j + 3] = alphabet[c & 0b111111];
-            if (i + 1 >= in.len) {
+        while (i < limit) : ({
+            i += 3;
+            j += 4;
+        }) {
+            encode_mask(out[j..][0..4], in[i], in[i + 1], in[i + 2]);
+        }
+
+        // Special case the remaining lengths.
+        switch (odd) {
+            0 => {}, // noop
+            1 => {
+                encode_mask(out[j..][0..4], in[i], 0, 0);
                 out[j + 2] = pad;
                 out[j + 3] = pad;
-            } else if (i + 2 >= in.len) {
+            },
+            2 => {
+                encode_mask(out[j..][0..4], in[i], in[i + 1], 0);
                 out[j + 3] = pad;
+            },
+            else => unreachable,
+        }
+
+        return sym;
+    }
+
+    /// Presumes that at most only the bottom 6 bits of a, b, c and d are set.
+    inline fn decode_unmask(tgt: *[3]u8, a: u8, b: u8, c: u8, d: u8) void {
+        // 0                   1                   2
+        // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
+        // <=============> <=============> <=============>
+        // <=========> <=========> <=========> <=========>
+        tgt[0] = (a << 2) | (b >> 4 & 0b0000_0011);
+        tgt[1] = (b << 4 & 0b1111_0000) | (c >> 2 & 0b0000_1111);
+        tgt[2] = (c << 6 & 0b1100_0000) | d;
+    }
+    fn decode(in: []const u8, out: []u8) EncDecError!usize {
+        if (in.len == 0) {
+            return 0;
+        }
+
+        var out_len = 3 * @divTrunc(in.len + 3, 4);
+        var padding = @as(usize, 0);
+        var i = in.len - 1;
+        while (i > 0) : (i -= 1) {
+            if (in[i] == pad) {
+                padding += 1;
+            } else {
+                break;
             }
+        }
+        if (padding == 1) {
+            out_len -= 1;
+        } else if (padding == 2) {
+            out_len -= 2;
+        } else if (padding >= 3) {
+            out_len -= 3;
+        }
+        if (out.len < out_len) return error.BufferOverflow;
 
-            i += 3;
-            j += 4;
+        i = 0;
+        var j = @as(usize, 0);
+
+        const in_len = in.len - padding;
+        const odd = @mod(in_len, 4);
+        const limit = in_len - odd;
+        while (i < limit) : ({
+            i += 4;
+            j += 3;
+        }) {
+            const a = reverse_alphabet[in[i]];
+            const b = reverse_alphabet[in[i + 1]];
+            const c = reverse_alphabet[in[i + 2]];
+            const d = reverse_alphabet[in[i + 3]];
+            // Normally only the bottom 6 bits are set. The rest are used to
+            // indicate invalid bytes.
+            if ((a | b | c | d) & 0x80 != 0) return error.InvalidEncoding;
+
+            decode_unmask(out[j..][0..3], a, b, c, d);
+        }
+
+        var tmp: [3]u8 = undefined;
+        switch (odd) {
+            0 => {
+                // Strip padding, if present.
+                while (true) {
+                    i -= 4;
+                    if (in[i + 1] == pad and in[i + 2] == pad and in[i + 3] == pad) {
+                        j -= 3;
+                    } else if (in[i + 2] == pad and in[i + 3] == pad) {
+                        j -= 2;
+                        break;
+                    } else if (in[i + 3] == pad) {
+                        j -= 1;
+                        break;
+                    } else {
+                        break;
+                    }
+                }
+            }, // noop
+            1 => {}, // noop -- can't assemble a whole byte?
+            2 => {
+                const a = reverse_alphabet[in[i]];
+                const b = reverse_alphabet[in[i + 1]];
+                if ((a | b) & 0x80 != 0) return error.InvalidEncoding;
+                decode_unmask(&tmp, a, b, 0, 0);
+                std.mem.copy(u8, out[j .. j + 1], tmp[0..1]);
+                j += 1;
+            },
+            3 => {
+                const a = reverse_alphabet[in[i]];
+                const b = reverse_alphabet[in[i + 1]];
+                const c = reverse_alphabet[in[i + 2]];
+                if ((a | b | c) & 0x80 != 0) return error.InvalidEncoding;
+                decode_unmask(&tmp, a, b, c, 0);
+                std.mem.copy(u8, out[j .. j + 2], tmp[0..2]);
+                j += 2;
+            },
+            else => unreachable,
         }
 
         return j;
     }
-
-    fn decode(in: []const u8, out: []u8) EncDecError!usize {
-        _ = in;
-        _ = out;
-        return EncDecError.BufferOverflow;
-    }
 };
 
 export fn base_64_encode(in_buf: [*]const u8, in_len: usize, out_buf: [*]u8, out_len: usize) isize {

          
@@ 48,7 171,8 @@ export fn base_64_encode(in_buf: [*]cons
     const out = out_buf[0..out_len];
 
     return @intCast(isize, Base64.encode(in, out) catch |err| switch (err) {
-        EncDecError.BufferOverflow => return -1,
+        error.BufferOverflow => return -1,
+        error.InvalidEncoding => unreachable,
     });
 }
 

          
@@ 57,6 181,84 @@ export fn base_64_decode(in_buf: [*]cons
     const out = out_buf[0..out_len];
 
     return @intCast(isize, Base64.decode(in, out) catch |err| switch (err) {
-        EncDecError.BufferOverflow => return -1,
+        error.BufferOverflow => return -1,
+        error.InvalidEncoding => return -2,
     });
 }
+
+fn hexDigitToNybble(digit: u8) u4 {
+    if (digit >= '0' and digit <= '9') {
+        return @intCast(u4, digit - '0');
+    } else if (digit >= 'A' and digit <= 'F') {
+        return @intCast(u4, 0xA + digit - 'A');
+    } else if (digit >= 'a' and digit <= 'f') {
+        return @intCast(u4, 0xA + digit - 'a');
+    } else {
+        unreachable; // invalid hex digit
+    }
+}
+fn fromHex(comptime src: []const u8) []const u8 {
+    if (src.len % 2 != 0) {
+        @compileError("hex string length not even");
+    }
+    var raw: [src.len >> 1]u8 = undefined;
+    var i = @as(usize, 0);
+    while (i < src.len) : (i += 2) {
+        raw[i >> 1] = @intCast(u8, hexDigitToNybble(src[i])) << 4 |
+            @intCast(u8, hexDigitToNybble(src[i + 1]));
+    }
+    return raw[0..];
+}
+const alloc = std.testing.allocator;
+const expectEqualSlices = std.testing.expectEqualSlices;
+test "base64 encode" {
+    for ([_][2][]const u8{
+        .{ "", "" },
+        .{ ([_]u8{0})[0..] ** 1, "AA==" },
+        .{ ([_]u8{0})[0..] ** 2, "AAA=" },
+        .{ ([_]u8{0})[0..] ** 3, "AAAA" },
+        .{ "hey", "aGV5" },
+        .{ "Discard medicine more than two years old.", "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=" },
+    }) |c| {
+        const vector = c[0];
+        const expected = c[1];
+        const max_len = @maximum(expected.len, 1);
+        const actual = try alloc.alloc(u8, max_len);
+        defer alloc.free(actual);
+        const real_len = base_64_encode(vector.ptr, vector.len, actual.ptr, max_len);
+        if (real_len < 0) {
+            std.debug.print("base64 encode: buffer overflow for {s}\n", .{vector});
+            return error.BufferOverflow;
+        }
+
+        try expectEqualSlices(u8, expected, actual[0..@intCast(usize, real_len)]);
+    }
+}
+
+test "base64 decode" {
+    for ([_][2][]const u8{
+        .{ "", "" },
+        .{ "AA==", &.{0} },
+        .{ "AAA=", &.{ 0, 0 } },
+        .{ "AAAA", &.{ 0, 0, 0 } },
+        .{ "AAAAA===", &.{ 0, 0, 0 } },
+        .{ "AAAAAA==", &.{ 0, 0, 0, 0 } },
+        .{ "AAAAAAA=", &.{ 0, 0, 0, 0, 0 } },
+        .{ "AAAAAAAA", &.{ 0, 0, 0, 0, 0, 0 } },
+        .{ "aGV5", "hey" },
+        .{ "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=", "Discard medicine more than two years old." },
+    }) |c| {
+        const vector = c[0];
+        const expected = c[1];
+        const max_len = @maximum(expected.len, 1);
+        const actual = try alloc.alloc(u8, max_len);
+        defer alloc.free(actual);
+        const real_len = base_64_decode(vector.ptr, vector.len, actual.ptr, max_len);
+        if (real_len < 0) {
+            std.debug.print("base64 decode: failure {} for {s}\n", .{ real_len, vector });
+            return error.TestFailed;
+        }
+
+        try expectEqualSlices(u8, expected, actual[0..@intCast(usize, real_len)]);
+    }
+}