42d7a07d4256 — Nathan Michaels 4 years ago
Update for new allocator interface.
2 files changed, 49 insertions(+), 41 deletions(-)

M src/crypto_box.zig
M src/mem.zig
M src/crypto_box.zig +1 -1
@@ 3,7 3,7 @@ const testing = std.testing;
 const Allocator = std.mem.Allocator;
 const randombytes = @import("randombytes.zig");
 const sodium = @import("sodium.zig");
-const SodiumError = sodium.SodiumError;
+const SodiumError = @import("errors.zig").SodiumError;
 
 const c = @cImport({
     @cInclude("sodium.h");

          
M src/mem.zig +48 -40
@@ 35,61 35,69 @@ pub fn zero(arr: []u8) void {
 /// does, however, mlock allocated pages and zero freed memory.
 pub const sodium_allocator: *Allocator = &sodium_allocator_state;
 var sodium_allocator_state = Allocator{
-    .reallocFn = sodiumRealloc,
-    .shrinkFn = sodiumShrink,
+    .allocFn = sodiumAlloc,
+    .resizeFn = sodiumResize,
 };
 
-fn sodiumRealloc(
+fn sodiumAlloc(
     self: *Allocator,
-    old_mem: []u8,
-    old_alignment: u29,
-    new_byte_count: usize,
-    new_alignment: u29,
+    len: usize,
+    ptr_align: u29,
+    len_align: u29,
 ) Error![]u8 {
-    if (old_mem.len != 0) {
-        return Error.OutOfMemory; // Not really, but we don't have a realloc.
+    var bytes = len;
+    const alignment = std.math.max(ptr_align, len_align);
+    if (alignment != 0) {
+        // Slow but easy. Fix with math later.
+        while (bytes % alignment != 0) {
+            bytes += 1;
+        }
     }
 
-    // new_alignment is guaranteed to be a power of 2, >= 1, so
-    // subtracting 1 will give all 1s to the right of the set bit.
-    const aligned_zeros = new_alignment - 1;
-    var bytes = new_byte_count;
-    while (bytes & aligned_zeros != 0) {
-        // sodium_malloc doesn't align the return address if size
-        // isn't a multiple of the required alignment, so we have to
-        // round up. This is the dumb slow way, but it doesn't involve
-        // any math so we can fix it later.
-        bytes += 1;
-    }
+    // Zig helps prevent access past the end of the slice, but that's
+    // the only mechanism in cases where len_align < ptr_align. Sodium
+    // will catch writes past the end of the allocated buffer.
+    const slice_len = if (len_align >= ptr_align)
+        bytes
+    else if (len_align == 0)
+        len
+    else
+        len + (len_align - (len % len_align));
+
     const allocated = c.sodium_malloc(bytes);
     if (allocated) |ptr| {
-        return @ptrCast([*]u8, ptr)[0..new_byte_count];
+        return @ptrCast([*]u8, ptr)[0..slice_len];
     } else {
         return Error.OutOfMemory;
     }
 }
 
-fn sodiumShrink(
+fn sodiumResize(
     self: *Allocator,
-    old_mem: []u8,
-    old_alignment: u29,
-    new_byte_count: usize,
-    new_alignment: u29,
-) []u8 {
-    var new_mem: []u8 = undefined;
-    if (new_byte_count != 0) {
-        // This function isn't allowed to fail, so instead of trying
-        // to allocate additional memory and move the old buffer into
-        // it, we'll just pretend the buffer got smaller. This loses
-        // some of libsodium's nice fencing, and it doesn't free any
-        // memory, but Zig should prevent casual access to memory past
-        // the end of the buffer. We'll zero the bytes we're losing,
-        // just to avoid keeping potentially sensitive but unneeded
-        // data in memory.
-        zero(old_mem[new_byte_count..]);
-        new_mem = old_mem[0..new_byte_count];
+    buf: []u8,
+    new_len: usize,
+    len_align: u29,
+) Error!usize {
+    if (new_len > buf.len) {
+        return Error.OutOfMemory;
+    }
+    if (new_len == 0) {
+        c.sodium_free(buf.ptr);
+        return 0;
     }
-    return old_mem[0..new_byte_count];
+    const real_new_len = len: {
+        if (len_align == 0) {
+            break :len new_len;
+        }
+        var bytes = new_len;
+        while (bytes % len_align != 0) {
+            bytes += 1;
+        }
+        break :len bytes;
+    };
+
+    zero(buf[real_new_len..]);
+    return real_new_len;
 }
 
 test "mem lock" {