@@ 2,6 2,7 @@ const std = @import("std");
const EncDecError = error{
BufferOverflow,
+ InvalidEncoding, // reported by decode when a data symbol is not in the base64 alphabet
};
const Base64 = struct {
@@ 10,37 11,159 @@ const Base64 = struct {
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" ++
"abcdefghijklmnopqrstuvwxyz" ++
"0123456789+/";
+    /// Lookup table from an input byte to its 6-bit base64 value.
+    ///
+    /// Every byte that does not appear in `alphabet` maps to 255; `decode` tests
+    /// the high bit of the looked-up values to reject such bytes.
+    ///
+    /// `pad` is intentionally left invalid. `decode` removes trailing padding
+    /// before it consults this table, so a pad byte that reaches a lookup sits
+    /// inside the data and must be rejected instead of decoding as zero bits.
+    const reverse_alphabet = blk: {
+        var rev: [256]u8 = .{255} ** 256;
+        for (alphabet) |l, i| {
+            rev[l] = i;
+        }
+        break :blk rev;
+    };
+
+    /// Splits the 24 bits of `a`, `b`, `c` into four 6-bit indices and writes the
+    /// matching `alphabet` symbols to `out`.
+    inline fn encode_mask(out: *[4]u8, a: u8, b: u8, c: u8) void {
+        // Top six bits of `a`.
+        const first = a >> 2;
+        // Low two bits of `a` followed by the top four bits of `b`.
+        const second = ((a & 0b000011) << 4) | (b >> 4);
+        // Low four bits of `b` followed by the top two bits of `c`.
+        const third = ((b & 0b001111) << 2) | (c >> 6);
+        // Low six bits of `c`.
+        const fourth = c & 0b111111;
+
+        out[0] = alphabet[first];
+        out[1] = alphabet[second];
+        out[2] = alphabet[third];
+        out[3] = alphabet[fourth];
+    }
fn encode(in: []const u8, out: []u8) EncDecError!usize {
+ // Encodes `in` as padded base64 into `out` and returns the symbol count. The size is checked once here so the loops below need no per-group bounds branch.
+ const sym = @divTrunc(in.len + 2, 3) * 4; // four symbols per three-byte group, a partial last group included
+ if (out.len < sym) {
+ return error.BufferOverflow;
+ }
+
+ const odd = @mod(in.len, 3); // input bytes left over after the last full triplet
+
+ // Encode only the full triplets in the loop; the leftover bytes are handled by the switch further down.
+ const limit = in.len - odd;
var i = @as(usize, 0);
var j = @as(usize, 0);
- while (i < in.len) {
- if (j + 4 > out.len) return EncDecError.BufferOverflow;
- const a = @intCast(u32, in[i]);
- const b = if (i + 1 < in.len) @intCast(u32, in[i + 1]) else 0;
- const c = if (i + 2 < in.len) @intCast(u32, in[i + 2]) else 0;
- out[j] = alphabet[a >> 2];
- out[j + 1] = alphabet[(a << 4 & 0b110000) | (b >> 4 & 0b001111)];
- out[j + 2] = alphabet[(b << 2 & 0b111100) | (c >> 6 & 0b000011)];
- out[j + 3] = alphabet[c & 0b111111];
- if (i + 1 >= in.len) {
+ while (i < limit) : ({
+ i += 3;
+ j += 4;
+ }) {
+ encode_mask(out[j..][0..4], in[i], in[i + 1], in[i + 2]);
+ }
+
+ // Encode the one or two leftover bytes with zero fill, then overwrite the unused symbols with `pad`.
+ switch (odd) {
+ 0 => {}, // input length is a multiple of three: nothing left to encode
+ 1 => { // one leftover byte: two data symbols followed by two pads
+ encode_mask(out[j..][0..4], in[i], 0, 0);
out[j + 2] = pad;
out[j + 3] = pad;
- } else if (i + 2 >= in.len) {
+ },
+ 2 => { // two leftover bytes: three data symbols followed by one pad
+ encode_mask(out[j..][0..4], in[i], in[i + 1], 0);
out[j + 3] = pad;
+ },
+ else => unreachable, // @mod(in.len, 3) only yields 0, 1 or 2
+ }
+
+ return sym;
+ }
+
+    /// Reassembles three bytes from the four 6-bit values `a`, `b`, `c`, `d` and
+    /// stores them in `tgt`.
+    /// Presumes that at most only the bottom 6 bits of a, b, c and d are set.
+    inline fn decode_unmask(tgt: *[3]u8, a: u8, b: u8, c: u8, d: u8) void {
+        // Output byte 0: all six bits of `a`, then the top two bits of `b`.
+        const first = (a << 2) | ((b >> 4) & 0b0000_0011);
+        // Output byte 1: the low four bits of `b`, then the top four bits of `c`.
+        const second = ((b & 0b0000_1111) << 4) | ((c >> 2) & 0b0000_1111);
+        // Output byte 2: the low two bits of `c`, then all six bits of `d`.
+        const third = ((c & 0b0000_0011) << 6) | d;
+
+        tgt[0] = first;
+        tgt[1] = second;
+        tgt[2] = third;
+    }
+    /// Decodes the base64 text in `in` into `out` and returns the number of bytes
+    /// written.
+    ///
+    /// Any run of trailing `pad` bytes is ignored, so padded and unpadded input
+    /// are both accepted. A final group of two or three symbols yields one or two
+    /// bytes (its leftover low bits are discarded); a final group of a single
+    /// symbol cannot form a byte and yields nothing.
+    ///
+    /// Returns `error.BufferOverflow` when `out` cannot hold the decoded bytes,
+    /// and `error.InvalidEncoding` when a data symbol maps to a
+    /// `reverse_alphabet` entry with the high bit set.
+    fn decode(in: []const u8, out: []u8) EncDecError!usize {
+        // Trailing padding carries no data, so shrink the view past it first.
+        var in_len = in.len;
+        while (in_len > 0 and in[in_len - 1] == pad) {
+            in_len -= 1;
}
+
+        // Symbols left over after the last complete group of four.
+        const odd = in_len % 4;
+        const limit = in_len - odd;
+
+        // Size the output exactly: three bytes per complete group, plus one byte
+        // for a two-symbol tail or two bytes for a three-symbol tail. Rounding the
+        // raw input length up to whole groups instead would reject buffers that
+        // are large enough whenever the input is unpadded.
+        const tail_len: usize = switch (odd) {
+            0, 1 => 0,
+            2 => 1,
+            3 => 2,
+            else => unreachable, // in_len % 4 is always below 4
+        };
+        if (out.len < limit / 4 * 3 + tail_len) return error.BufferOverflow;
- i += 3;
- j += 4;
+
+        var i: usize = 0;
+        var j: usize = 0;
+        while (i < limit) : ({
+            i += 4;
+            j += 3;
+        }) {
+            const a = reverse_alphabet[in[i]];
+            const b = reverse_alphabet[in[i + 1]];
+            const c = reverse_alphabet[in[i + 2]];
+            const d = reverse_alphabet[in[i + 3]];
+            // Normally only the bottom 6 bits are set. The rest are used to
+            // indicate invalid bytes.
+            if ((a | b | c | d) & 0x80 != 0) return error.InvalidEncoding;
+
+            decode_unmask(out[j..][0..3], a, b, c, d);
+        }
+
+        // Decode the partial final group, if any, through a scratch buffer so
+        // that only the bytes it really carries reach `out`.
+        var tmp: [3]u8 = undefined;
+        switch (odd) {
+            // Nothing left, or a lone symbol that cannot form a whole byte.
+            0, 1 => {},
+            2 => {
+                const a = reverse_alphabet[in[i]];
+                const b = reverse_alphabet[in[i + 1]];
+                if ((a | b) & 0x80 != 0) return error.InvalidEncoding;
+                decode_unmask(&tmp, a, b, 0, 0);
+                out[j] = tmp[0];
+                j += 1;
+            },
+            3 => {
+                const a = reverse_alphabet[in[i]];
+                const b = reverse_alphabet[in[i + 1]];
+                const c = reverse_alphabet[in[i + 2]];
+                if ((a | b | c) & 0x80 != 0) return error.InvalidEncoding;
+                decode_unmask(&tmp, a, b, c, 0);
+                out[j] = tmp[0];
+                out[j + 1] = tmp[1];
+                j += 2;
+            },
+            else => unreachable,
}
return j;
}
-
- fn decode(in: []const u8, out: []u8) EncDecError!usize {
- _ = in;
- _ = out;
- return EncDecError.BufferOverflow;
- }
};
export fn base_64_encode(in_buf: [*]const u8, in_len: usize, out_buf: [*]u8, out_len: usize) isize {
@@ 48,7 171,8 @@ export fn base_64_encode(in_buf: [*]cons
const out = out_buf[0..out_len];
return @intCast(isize, Base64.encode(in, out) catch |err| switch (err) {
- EncDecError.BufferOverflow => return -1,
+ error.BufferOverflow => return -1,
+ error.InvalidEncoding => unreachable,
});
}
@@ 57,6 181,84 @@ export fn base_64_decode(in_buf: [*]cons
const out = out_buf[0..out_len];
return @intCast(isize, Base64.decode(in, out) catch |err| switch (err) {
- EncDecError.BufferOverflow => return -1,
+ error.BufferOverflow => return -1,
+ error.InvalidEncoding => return -2,
});
}
+
+/// Converts one ASCII hex digit (`0-9`, `A-F`, `a-f`) to its 4-bit value.
+/// Any other byte reaches `unreachable`, so callers must pass a valid digit.
+fn hexDigitToNybble(digit: u8) u4 {
+    return switch (digit) {
+        '0'...'9' => @intCast(u4, digit - '0'),
+        'A'...'F' => @intCast(u4, digit - 'A' + 0xA),
+        'a'...'f' => @intCast(u4, digit - 'a' + 0xA),
+        else => unreachable, // invalid hex digit
+    };
+}
+/// Converts the comptime hex string `src` into the bytes it spells and returns
+/// them as a slice.
+///
+/// An odd-length `src` is a compile error. Each digit goes through
+/// `hexDigitToNybble`, so `src` must contain hex digits only.
+///
+/// The bytes are computed inside a `comptime` block, which makes the array a
+/// comptime-known constant. Returning a slice of a plain runtime local here
+/// would hand the caller a pointer into this function's dead stack frame.
+fn fromHex(comptime src: []const u8) []const u8 {
+    if (src.len % 2 != 0) {
+        @compileError("hex string length not even");
+    }
+    const raw = comptime blk: {
+        var buf: [src.len >> 1]u8 = undefined;
+        var i: usize = 0;
+        while (i < src.len) : (i += 2) {
+            // High nybble from the first digit of the pair, low nybble from the second.
+            buf[i >> 1] = @as(u8, hexDigitToNybble(src[i])) << 4 |
+                @as(u8, hexDigitToNybble(src[i + 1]));
+        }
+        break :blk buf;
+    };
+    return &raw;
+}
+const alloc = std.testing.allocator;
+const expectEqualSlices = std.testing.expectEqualSlices;
+test "base64 encode" {
+    // Each case pairs raw input bytes with the padded base64 text expected for
+    // them; the cases are pushed through the exported `base_64_encode` wrapper.
+    const cases = [_][2][]const u8{
+        .{ "", "" },
+        .{ ([_]u8{0})[0..] ** 1, "AA==" },
+        .{ ([_]u8{0})[0..] ** 2, "AAA=" },
+        .{ ([_]u8{0})[0..] ** 3, "AAAA" },
+        .{ "hey", "aGV5" },
+        .{ "Discard medicine more than two years old.", "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=" },
+    };
+
+    for (cases) |case| {
+        const input = case[0];
+        const want = case[1];
+
+        // The buffer is never shorter than one byte, even for the empty case
+        // (presumably so the wrapper always receives a real allocation).
+        const capacity = @maximum(want.len, 1);
+        const buf = try alloc.alloc(u8, capacity);
+        defer alloc.free(buf);
+
+        const written = base_64_encode(input.ptr, input.len, buf.ptr, capacity);
+        if (written < 0) {
+            std.debug.print("base64 encode: buffer overflow for {s}\n", .{input});
+            return error.BufferOverflow;
+        }
+
+        try expectEqualSlices(u8, want, buf[0..@intCast(usize, written)]);
+    }
+}
+
+test "base64 decode" {
+    // Each case pairs base64 text with the bytes it should decode to; the cases
+    // are pushed through the exported `base_64_decode` wrapper. The table covers
+    // every padding length, including a lone trailing symbol ("AAAAA===").
+    const cases = [_][2][]const u8{
+        .{ "", "" },
+        .{ "AA==", &.{0} },
+        .{ "AAA=", &.{ 0, 0 } },
+        .{ "AAAA", &.{ 0, 0, 0 } },
+        .{ "AAAAA===", &.{ 0, 0, 0 } },
+        .{ "AAAAAA==", &.{ 0, 0, 0, 0 } },
+        .{ "AAAAAAA=", &.{ 0, 0, 0, 0, 0 } },
+        .{ "AAAAAAAA", &.{ 0, 0, 0, 0, 0, 0 } },
+        .{ "aGV5", "hey" },
+        .{ "RGlzY2FyZCBtZWRpY2luZSBtb3JlIHRoYW4gdHdvIHllYXJzIG9sZC4=", "Discard medicine more than two years old." },
+    };
+
+    for (cases) |case| {
+        const input = case[0];
+        const want = case[1];
+
+        // The buffer is never shorter than one byte, even for the empty case
+        // (presumably so the wrapper always receives a real allocation).
+        const capacity = @maximum(want.len, 1);
+        const buf = try alloc.alloc(u8, capacity);
+        defer alloc.free(buf);
+
+        const written = base_64_decode(input.ptr, input.len, buf.ptr, capacity);
+        if (written < 0) {
+            std.debug.print("base64 decode: failure {} for {s}\n", .{ written, input });
+            return error.TestFailed;
+        }
+
+        try expectEqualSlices(u8, want, buf[0..@intCast(usize, written)]);
+    }
+}