# HG changeset patch # User paulsnar # Date 1657543992 -10800 # Mon Jul 11 15:53:12 2022 +0300 # Node ID 349ccd61c7ab35a23f69a1d78a2ed5fb56c8e13c # Parent 4f1a0d9faabc9b81646163ce9e12eedbfe29a61e feat: Decoding + tests + better error checking diff --git a/src/baseenc.c b/src/baseenc.c --- a/src/baseenc.c +++ b/src/baseenc.c @@ -22,8 +22,30 @@ return janet_wrap_buffer(buf); } +static Janet Baseenc_base64Decode(int argc, Janet argv[]) { + janet_arity(argc, 1, 2); + JanetByteView data = janet_getbytes(argv, 0); + JanetBuffer *buf; + if (argc > 1) { + buf = janet_getbuffer(argv, 1); + } else { + buf = janet_buffer((data.len + 3) / 4 * 3); + } + + ssize_t len = base_64_decode(&data.bytes[0], (size_t)data.len, + &buf->data[0], (size_t)buf->capacity); + if (len == -1) { + janet_panic("buffer overflow"); + } else if (len == -2) { + janet_panic("invalid base64 string"); + } + buf->count = (int32_t)len; + return janet_wrap_buffer(buf); +} + static const JanetReg Baseenc_cfuns[] = { { "base64-encode", Baseenc_base64Encode, NULL }, + { "base64-decode", Baseenc_base64Decode, NULL }, { NULL, NULL, NULL }, }; diff --git a/src/baseenc.zig b/src/baseenc.zig --- a/src/baseenc.zig +++ b/src/baseenc.zig @@ -2,6 +2,7 @@ const EncDecError = error{ BufferOverflow, + InvalidEncoding, }; const Base64 = struct { @@ -10,37 +11,159 @@ "ABCDEFGHIJKLMNOPQRSTUVWXYZ" ++ "abcdefghijklmnopqrstuvwxyz" ++ "0123456789+/"; + const reverse_alphabet = blk: { + var rev: [256]u8 = .{255} ** 256; + for (alphabet) |l, i| { + rev[l] = i; + } + rev[pad] = 0; + break :blk rev; + }; + + inline fn encode_mask(out: *[4]u8, a: u8, b: u8, c: u8) void { + //const ai = @intCast(u32, a); + //const bi = @intCast(u32, b); + //const ci = @intCast(u32, c); + out[0] = alphabet[a >> 2]; + out[1] = alphabet[(a << 4 & 0b110000) | (b >> 4 & 0b001111)]; + out[2] = alphabet[(b << 2 & 0b111100) | (c >> 6 & 0b000011)]; + out[3] = alphabet[c & 0b111111]; + } fn encode(in: []const u8, out: []u8) EncDecError!usize { + // Pre-check sizes to avoid unnecessary branching. + const sym = @divTrunc(in.len + 2, 3) * 4; + if (out.len < sym) { + return error.BufferOverflow; + } + + const odd = @mod(in.len, 3); + + // Encode triplets as long as branching isn't needed. + const limit = in.len - odd; var i = @as(usize, 0); var j = @as(usize, 0); - while (i < in.len) { - if (j + 4 > out.len) return EncDecError.BufferOverflow; - const a = @intCast(u32, in[i]); - const b = if (i + 1 < in.len) @intCast(u32, in[i + 1]) else 0; - const c = if (i + 2 < in.len) @intCast(u32, in[i + 2]) else 0; - out[j] = alphabet[a >> 2]; - out[j + 1] = alphabet[(a << 4 & 0b110000) | (b >> 4 & 0b001111)]; - out[j + 2] = alphabet[(b << 2 & 0b111100) | (c >> 6 & 0b000011)]; - out[j + 3] = alphabet[c & 0b111111]; - if (i + 1 >= in.len) { + while (i < limit) : ({ + i += 3; + j += 4; + }) { + encode_mask(out[j..][0..4], in[i], in[i + 1], in[i + 2]); + } + + // Special case the remaining lengths. + switch (odd) { + 0 => {}, // noop + 1 => { + encode_mask(out[j..][0..4], in[i], 0, 0); out[j + 2] = pad; out[j + 3] = pad; - } else if (i + 2 >= in.len) { + }, + 2 => { + encode_mask(out[j..][0..4], in[i], in[i + 1], 0); out[j + 3] = pad; + }, + else => unreachable, + } + + return sym; + } + + /// Presumes that at most only the bottom 6 bits of a, b, c and d are set. + inline fn decode_unmask(tgt: *[3]u8, a: u8, b: u8, c: u8, d: u8) void { + // 0 1 2 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 + // <=============> <=============> <=============> + // <=========> <=========> <=========> <=========> + tgt[0] = (a << 2) | (b >> 4 & 0b0000_0011); + tgt[1] = (b << 4 & 0b1111_0000) | (c >> 2 & 0b0000_1111); + tgt[2] = (c << 6 & 0b1100_0000) | d; + } + fn decode(in: []const u8, out: []u8) EncDecError!usize { + if (in.len == 0) { + return 0; + } + + var out_len = 3 * @divTrunc(in.len + 3, 4); + var padding = @as(usize, 0); + var i = in.len - 1; + while (i > 0) : (i -= 1) { + if (in[i] == pad) { + padding += 1; + } else { + break; } + } + if (padding == 1) { + out_len -= 1; + } else if (padding == 2) { + out_len -= 2; + } else if (padding >= 3) { + out_len -= 3; + } + if (out.len < out_len) return error.BufferOverflow; - i += 3; - j += 4; + i = 0; + var j = @as(usize, 0); + + const in_len = in.len - padding; + const odd = @mod(in_len, 4); + const limit = in_len - odd; + while (i < limit) : ({ + i += 4; + j += 3; + }) { + const a = reverse_alphabet[in[i]]; + const b = reverse_alphabet[in[i + 1]]; + const c = reverse_alphabet[in[i + 2]]; + const d = reverse_alphabet[in[i + 3]]; + // Normally only the bottom 6 bits are set. The rest are used to + // indicate invalid bytes. + if ((a | b | c | d) & 0x80 != 0) return error.InvalidEncoding; + + decode_unmask(out[j..][0..3], a, b, c, d); + } + + var tmp: [3]u8 = undefined; + switch (odd) { + 0 => { + // Strip padding, if present. + while (true) { + i -= 4; + if (in[i + 1] == pad and in[i + 2] == pad and in[i + 3] == pad) { + j -= 3; + } else if (in[i + 2] == pad and in[i + 3] == pad) { + j -= 2; + break; + } else if (in[i + 3] == pad) { + j -= 1; + break; + } else { + break; + } + } + }, // noop + 1 => {}, // noop -- can't assemble a whole byte? + 2 => { + const a = reverse_alphabet[in[i]]; + const b = reverse_alphabet[in[i + 1]]; + if ((a | b) & 0x80 != 0) return error.InvalidEncoding; + decode_unmask(&tmp, a, b, 0, 0); + std.mem.copy(u8, out[j .. j + 1], tmp[0..1]); + j += 1; + }, + 3 => { + const a = reverse_alphabet[in[i]]; + const b = reverse_alphabet[in[i + 1]]; + const c = reverse_alphabet[in[i + 2]]; + if ((a | b | c) & 0x80 != 0) return error.InvalidEncoding; + decode_unmask(&tmp, a, b, c, 0); + std.mem.copy(u8, out[j .. j + 2], tmp[0..2]); + j += 2; + }, + else => unreachable, } return j; } - - fn decode(in: []const u8, out: []u8) EncDecError!usize { - _ = in; - _ = out; - return EncDecError.BufferOverflow; - } }; export fn base_64_encode(in_buf: [*]const u8, in_len: usize, out_buf: [*]u8, out_len: usize) isize { @@ -48,7 +171,8 @@ const out = out_buf[0..out_len]; return @intCast(isize, Base64.encode(in, out) catch |err| switch (err) { - EncDecError.BufferOverflow => return -1, + error.BufferOverflow => return -1, + error.InvalidEncoding => unreachable, }); } @@ -57,6 +181,84 @@ const out = out_buf[0..out_len]; return @intCast(isize, Base64.decode(in, out) catch |err| switch (err) { - EncDecError.BufferOverflow => return -1, + error.BufferOverflow => return -1, + error.InvalidEncoding => return -2, }); } + +fn hexDigitToNybble(digit: u8) u4 { + if (digit >= '0' and digit <= '9') { + return @intCast(u4, digit - '0'); + } else if (digit >= 'A' and digit <= 'F') { + return @intCast(u4, 0xA + digit - 'A'); + } else if (digit >= 'a' and digit <= 'f') { + return @intCast(u4, 0xA + digit - 'a'); + } else { + unreachable; // invalid hex digit + } +} +fn fromHex(comptime src: []const u8) []const u8 { + if (src.len % 2 != 0) { + @compileError("hex string length not even"); + } + var raw: [src.len >> 1]u8 = undefined; + var i = @as(usize, 0); + while (i < src.len) : (i += 2) { + raw[i >> 1] = @intCast(u8, hexDigitToNybble(src[i])) << 4 | + @intCast(u8, hexDigitToNybble(src[i + 1])); + } + return raw[0..]; +} +const alloc = std.testing.allocator; +const expectEqualSlices = std.testing.expectEqualSlices; +test "base64 encode" { + for ([_][2][]const u8{ + .{ "", "" }, + .{ ([_]u8{0})[0..] ** 1, "AA==" }, + .{ ([_]u8{0})[0..] ** 2, "AAA=" }, + .{ ([_]u8{0})[0..] ** 3, "AAAA" }, + .{ "hey", "aGV5" }, + .{ "Discard medicine more than two years old.", "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=" }, + }) |c| { + const vector = c[0]; + const expected = c[1]; + const max_len = @maximum(expected.len, 1); + const actual = try alloc.alloc(u8, max_len); + defer alloc.free(actual); + const real_len = base_64_encode(vector.ptr, vector.len, actual.ptr, max_len); + if (real_len < 0) { + std.debug.print("base64 encode: buffer overflow for {s}\n", .{vector}); + return error.BufferOverflow; + } + + try expectEqualSlices(u8, expected, actual[0..@intCast(usize, real_len)]); + } +} + +test "base64 decode" { + for ([_][2][]const u8{ + .{ "", "" }, + .{ "AA==", &.{0} }, + .{ "AAA=", &.{ 0, 0 } }, + .{ "AAAA", &.{ 0, 0, 0 } }, + .{ "AAAAA===", &.{ 0, 0, 0 } }, + .{ "AAAAAA==", &.{ 0, 0, 0, 0 } }, + .{ "AAAAAAA=", &.{ 0, 0, 0, 0, 0 } }, + .{ "AAAAAAAA", &.{ 0, 0, 0, 0, 0, 0 } }, + .{ "aGV5", "hey" }, + .{ "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=", "Discard medicine more than two years old." }, + }) |c| { + const vector = c[0]; + const expected = c[1]; + const max_len = @maximum(expected.len, 1); + const actual = try alloc.alloc(u8, max_len); + defer alloc.free(actual); + const real_len = base_64_decode(vector.ptr, vector.len, actual.ptr, max_len); + if (real_len < 0) { + std.debug.print("base64 decode: failure {} for {s}\n", .{ real_len, vector }); + return error.TestFailed; + } + + try expectEqualSlices(u8, expected, actual[0..@intCast(usize, real_len)]); + } +}