From 82a18797b4f54b5ece243c8aa28c649911243a29 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Mon, 25 May 2026 23:01:27 +0900 Subject: [PATCH] core/encoding/base64: Misc fixes and improvements - Add error checking to `decode` - Add `encode_into_buf`/`decode_into_buf` --- core/encoding/base64/base64.odin | 319 +++++++++++++++++-------- core/encoding/cbor/tags.odin | 46 +++- tests/core/encoding/base64/base64.odin | 32 ++- 3 files changed, 285 insertions(+), 112 deletions(-) diff --git a/core/encoding/base64/base64.odin b/core/encoding/base64/base64.odin index 1488e2201..e97652cd4 100644 --- a/core/encoding/base64/base64.odin +++ b/core/encoding/base64/base64.odin @@ -9,10 +9,11 @@ truncate it from the encoded output. */ package encoding_base64 +import "base:intrinsics" import "base:runtime" import "core:io" -import "core:strings" +@(rodata) ENC_TABLE := [64]byte { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', @@ -25,6 +26,7 @@ ENC_TABLE := [64]byte { } // Encoding table for Base64url variant +@(rodata) ENC_URL_TABLE := [64]byte { 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', @@ -38,77 +40,89 @@ ENC_URL_TABLE := [64]byte { PADDING :: '=' -DEC_TABLE := [256]u8 { - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 62, 0, 0, 0, 63, +@(rodata) +DEC_TABLE := [256]i8 { + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, 62, -1, -1, -1, 63, 52, 53, 54, 55, 56, 57, 58, 59, - 60, 61, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 2, 3, 4, 5, 6, + 60, 61, -1, -1, -1, -1, -1, -1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 23, 24, 25, 0, 0, 0, 0, 0, - 0, 26, 27, 28, 29, 30, 31, 32, + 23, 24, 25, -1, -1, -1, -1, -1, + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 49, 50, 51, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 49, 50, 51, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, } // Decoding table for Base64url variant -DEC_URL_TABLE := [256]u8 { - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 62, 0, 0, +@(rodata) +DEC_URL_TABLE := [256]i8 { + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, 62, -1, -1, 52, 53, 54, 55, 56, 57, 58, 59, - 60, 61, 0, 0, 0, 0, 0, 0, - 0, 0, 1, 2, 3, 4, 5, 6, + 60, 61, -1, -1, -1, -1, -1, -1, + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, - 23, 24, 25, 0, 0, 0, 0, 63, - 0, 26, 27, 28, 29, 30, 31, 32, + 23, 24, 25, -1, -1, -1, -1, 63, + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, - 49, 50, 51, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 0, + 49, 50, 51, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, } +Error :: union #shared_nil { + runtime.Allocator_Error, + io.Error, + Decode_Error, +} + +Decode_Error :: enum { + None, + Invalid_Character, +} encode :: proc(data: []byte, ENC_TBL := ENC_TABLE, allocator := context.allocator) -> (encoded: string, err: runtime.Allocator_Error) #optional_allocator_error { out_length := encoded_len(data) @@ -116,23 +130,47 @@ encode :: proc(data: []byte, ENC_TBL := ENC_TABLE, allocator := context.allocato return } - out := strings.builder_make(0, out_length, allocator) or_return - ioerr := encode_into(strings.to_stream(&out), data, ENC_TBL) + out := make([]byte, out_length, allocator) or_return + _, ioerr := encode_impl(out, data, ENC_TBL) + assert(ioerr == nil, "encode should not IO error") + assert(len(out) == out_length, "buffer resized, `encoded_len` was wrong") - assert(ioerr == nil, "string builder should not IO error") - assert(strings.builder_cap(out) == out_length, "buffer resized, `encoded_len` was wrong") + encoded = transmute(string)(out) - return strings.to_string(out), nil + return +} + +encode_into_buf :: proc(dst, data: []byte, ENC_TBL := ENC_TABLE) -> (encoded: []byte, err: Error) { + out_length := encoded_len(data) + if out_length == 0 { + return + } + + return encode_impl(dst, data, ENC_TBL) } encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> io.Error { + _, err := encode_impl(w, data, ENC_TBL) + return err +} + +@(private) +encode_impl :: proc(dst: $T, data: []byte, ENC_TBL := ENC_TABLE) -> ([]byte, io.Error) where T == io.Writer || T == []byte { length := len(data) - if length == 0 { - return nil + when T == []byte { + out_length := encoded_len(data) + if len(dst) < out_length { + return nil, io.Error.Short_Buffer + } + buf := dst + } else { + if length == 0 { + return nil, nil + } + buf: [4]byte } c0, c1, c2, block: int - out: [4]byte for i := 0; i < length; i += 3 { #no_bounds_check { c0, c1, c2 = int(data[i]), -1, -1 @@ -141,15 +179,27 @@ encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> io.Erro if i + 2 < length { c2 = int(data[i + 2]) } block = (c0 << 16) | (max(c1, 0) << 8) | max(c2, 0) - - out[0] = ENC_TBL[block >> 18 & 63] - out[1] = ENC_TBL[block >> 12 & 63] - out[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63] - out[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63] + + buf[0] = ENC_TBL[block >> 18 & 63] + buf[1] = ENC_TBL[block >> 12 & 63] + buf[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63] + buf[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63] + when T == []byte { + buf = buf[4:] + } + } + when T == io.Writer { + if _, err := io.write_full(dst, buf[:]); err != nil { + return nil, err + } } - io.write_full(w, out[:]) or_return } - return nil + + when T == io.Writer { + return nil, nil + } else { + return dst[:out_length], nil + } } encoded_len :: proc(data: []byte) -> int { @@ -161,65 +211,140 @@ encoded_len :: proc(data: []byte) -> int { return ((4 * length / 3) + 3) &~ 3 } -decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocator) -> (decoded: []byte, err: runtime.Allocator_Error) #optional_allocator_error { +decode :: proc(data: string, DEC_TBL := DEC_TABLE, dst: []byte = nil, allocator := context.allocator) -> (decoded: []byte, err: Error) { out_length := decoded_len(data) + if out_length == 0 { + return nil, nil + } - out := strings.builder_make(0, out_length, allocator) or_return - ioerr := decode_into(strings.to_stream(&out), data, DEC_TBL) + buf: []byte + if buf, err = make([]byte, out_length, allocator); err != nil { + return + } - assert(ioerr == nil, "string builder should not IO error") - assert(strings.builder_cap(out) == out_length, "buffer resized, `decoded_len` was wrong") + decoded, err = decode_impl(buf, data, DEC_TBL) + if err != nil { + delete(buf, allocator) + } + assert(err != nil || len(decoded) == out_length, "buffer unexpectedly resized, `decoded_len` was wrong") - return out.buf[:], nil + return } -decode_into :: proc(w: io.Writer, data: string, DEC_TBL := DEC_TABLE) -> io.Error { +decode_into_buf :: proc(dst: []byte, data: string, DEC_TBL := DEC_TABLE) -> (decoded: []byte, err: Error) { + out_length := decoded_len(data) + if out_length == 0 { + return + } + + return decode_impl(dst, data, DEC_TBL) +} + +decode_into :: proc(w: io.Writer, data: string, DEC_TBL := DEC_TABLE) -> Error { + _, err := decode_impl(w, data, DEC_TBL) + return err +} + +@(private) +decode_impl :: proc(dst: $T, data: string, DEC_TBL := DEC_TABLE) -> ([]byte, Error) where T == io.Writer || T == []byte { length := decoded_len(data) - if length == 0 { - return nil + when T == []byte { + if len(dst) < length { + return nil, io.Error.Short_Buffer + } + off: int + } else { + if length == 0 { + return nil, nil + } + buf: [3]byte } c0, c1, c2, c3: int + d0, d1, d2, d3: i8 b0, b1, b2: int - buf: [3]byte i, j: int for ; j + 3 <= length; i, j = i + 4, j + 3 { #no_bounds_check { - c0 = int(DEC_TBL[data[i]]) - c1 = int(DEC_TBL[data[i + 1]]) - c2 = int(DEC_TBL[data[i + 2]]) - c3 = int(DEC_TBL[data[i + 3]]) + d0 = DEC_TBL[data[i]] + d1 = DEC_TBL[data[i + 1]] + d2 = DEC_TBL[data[i + 2]] + d3 = DEC_TBL[data[i + 3]] + + if intrinsics.unlikely((d0 | d1 | d2 | d3) & ~i8(0x3f) != 0) { + return nil, Decode_Error.Invalid_Character + } + + c0, c1, c2, c3 = int(d0), int(d1), int(d2), int(d3) b0 = (c0 << 2) | (c1 >> 4) b1 = (c1 << 4) | (c2 >> 2) b2 = (c2 << 6) | c3 - buf[0] = byte(b0) - buf[1] = byte(b1) - buf[2] = byte(b2) + when T == []byte { + dst[off+0] = byte(b0) + dst[off+1] = byte(b1) + dst[off+2] = byte(b2) + off += 3 + } else { + buf[0] = byte(b0) + buf[1] = byte(b1) + buf[2] = byte(b2) + } } - io.write_full(w, buf[:]) or_return + when T == io.Writer { + if _, err := io.write_full(dst, buf[:]); err != .None { + return nil, err + } + } } rest := length - j if rest > 0 { #no_bounds_check { - c0 = int(DEC_TBL[data[i]]) - c1 = int(DEC_TBL[data[i + 1]]) - c2 = int(DEC_TBL[data[i + 2]]) + // Note: decoded_len handles removing padding. + d0 = DEC_TBL[data[i]] + d1 = DEC_TBL[data[i + 1]] + if d2 = 0; rest == 2 { + d2 = DEC_TBL[data[i + 2]] + } + + if intrinsics.unlikely((d0 | d1 | d2) & ~i8(0x3f) != 0) { + return nil, Decode_Error.Invalid_Character + } + + c0, c1, c2 = int(d0), int(d1), int(d2) b0 = (c0 << 2) | (c1 >> 4) b1 = (c1 << 4) | (c2 >> 2) + + when T == []byte { + switch rest { + case 2: + dst[off+1] = byte(b1) + fallthrough + case 1: + dst[off] = byte(b0) + } + } else { + buf[0] = byte(b0) + buf[1] = byte(b1) + } } - switch rest { - case 1: io.write_byte(w, byte(b0)) or_return - case 2: io.write_full(w, {byte(b0), byte(b1)}) or_return + when T == io.Writer { + if _, err := io.write_full(dst, buf[:rest]); err != .None { + return nil, err + } } } - return nil + when T == io.Writer { + return nil, nil + } else { + return dst[:length], nil + } } decoded_len :: proc(data: string) -> int { diff --git a/core/encoding/cbor/tags.odin b/core/encoding/cbor/tags.odin index fa456673d..ad0a85913 100644 --- a/core/encoding/cbor/tags.odin +++ b/core/encoding/cbor/tags.odin @@ -308,13 +308,13 @@ tag_base64_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, if t.is_cstring { length := base64.decoded_len(bytes) builder := strings.builder_make(0, length+1) - base64.decode_into(strings.to_stream(&builder), bytes) or_return + b64_decode_into(strings.to_stream(&builder), bytes) or_return raw := (^cstring)(v.data) raw^ = cstring(raw_data(builder.buf)) } else { raw := (^string)(v.data) - raw^ = string(base64.decode(bytes) or_return) + raw^ = string(b64_decode(bytes) or_return) } return @@ -325,16 +325,16 @@ tag_base64_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, if elem_base.id != byte { return _unsupported(v, hdr) } raw := (^[]byte)(v.data) - raw^ = base64.decode(bytes) or_return + raw^ = b64_decode(bytes) or_return return - + case reflect.Type_Info_Dynamic_Array: elem_base := reflect.type_info_base(t.elem) if elem_base.id != byte { return _unsupported(v, hdr) } - decoded := base64.decode(bytes) or_return - + decoded := b64_decode(bytes) or_return + raw := (^mem.Raw_Dynamic_Array)(v.data) raw.data = raw_data(decoded) raw.len = len(decoded) @@ -348,9 +348,9 @@ tag_base64_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, if elem_base.id != byte { return _unsupported(v, hdr) } if base64.decoded_len(bytes) > t.count { return _unsupported(v, hdr) } - + slice := ([^]byte)(v.data)[:len(bytes)] - copy(slice, base64.decode(bytes) or_return) + copy(slice, b64_decode(bytes) or_return) return } @@ -384,3 +384,33 @@ tag_base64_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marsha err_conv(_encode_u64(e, u64(out_len), .Text)) or_return return base64.encode_into(e.writer, bytes) } + +@(private="file") +err_from_b64 :: proc(err: base64.Error) -> Unmarshal_Error { + switch e in err { + case runtime.Allocator_Error: + return e + case io.Error: + return e + case base64.Decode_Error: + return Decode_Data_Error.Bad_Tag_Value + } + + // Should NEVER happen, but fail gracefully. + return io.Error.Unknown +} + +@(private="file") +b64_decode :: proc(data: string) -> ([]byte, Unmarshal_Error) { + decoded, err := base64.decode(data) + if err == nil { + return decoded, nil + } + + return nil, err_from_b64(err) +} + +@(private="file") +b64_decode_into :: proc(w: io.Writer, data: string) -> Unmarshal_Error { + return err_from_b64(base64.decode_into(w, data)) +} diff --git a/tests/core/encoding/base64/base64.odin b/tests/core/encoding/base64/base64.odin index 93b3afb59..ba7572139 100644 --- a/tests/core/encoding/base64/base64.odin +++ b/tests/core/encoding/base64/base64.odin @@ -31,12 +31,22 @@ test_encoding :: proc(t: ^testing.T) { @(test) test_decoding :: proc(t: ^testing.T) { for test in tests { - v := string(base64.decode(test.base64)) + v, err := base64.decode(test.base64) + if !testing.expect_value(t, err, nil) { + continue + } defer delete(v) - testing.expect_value(t, v, test.vector) + testing.expect_value(t, string(v), test.vector) } } +@(test) +test_decoding_failure :: proc(t: ^testing.T) { + v, err := base64.decode("!#$%") + testing.expect(t, v == nil) + testing.expect(t, err == base64.Decode_Error.Invalid_Character) +} + @(test) test_roundtrip :: proc(t: ^testing.T) { values: [1024]u8 @@ -44,8 +54,14 @@ test_roundtrip :: proc(t: ^testing.T) { v = u8(i) } - encoded := base64.encode(values[:]); defer delete(encoded) - decoded := base64.decode(encoded); defer delete(decoded) + encoded := base64.encode(values[:]) + defer delete(encoded) + + decoded, err := base64.decode(encoded) + if !testing.expect_value(t, err, nil) { + return + } + defer delete(decoded) for v, i in decoded { testing.expect_value(t, v, values[i]) @@ -61,8 +77,10 @@ test_base64url :: proc(t: ^testing.T) { defer delete(encoded) testing.expect_value(t, encoded, url) - decoded := string(base64.decode(url, base64.DEC_URL_TABLE)) + decoded, err := base64.decode(url, base64.DEC_URL_TABLE) + if !testing.expect_value(t, err, nil) { + return + } defer delete(decoded) - testing.expect_value(t, decoded, plain) - + testing.expect_value(t, string(decoded), plain) }