From b6c47e796390924faabd236204bc620ea35c1d13 Mon Sep 17 00:00:00 2001 From: Laytan Laats Date: Sat, 16 Dec 2023 21:40:41 +0100 Subject: [PATCH] encoding/base64: add decode_into, add tests --- core/encoding/base64/base64.odin | 139 +++++++++++++++++-------- tests/core/Makefile | 3 + tests/core/build.bat | 2 + tests/core/encoding/base64/base64.odin | 60 +++++++++++ 4 files changed, 158 insertions(+), 46 deletions(-) create mode 100644 tests/core/encoding/base64/base64.odin diff --git a/core/encoding/base64/base64.odin b/core/encoding/base64/base64.odin index 793f22c57..535d457d5 100644 --- a/core/encoding/base64/base64.odin +++ b/core/encoding/base64/base64.odin @@ -44,21 +44,48 @@ DEC_TABLE := [128]int { } encode :: proc(data: []byte, ENC_TBL := ENC_TABLE, allocator := context.allocator) -> (encoded: string, err: mem.Allocator_Error) #optional_allocator_error { - out_length := encoded_length(data) + out_length := encoded_len(data) if out_length == 0 { return } - out: strings.Builder - strings.builder_init(&out, 0, out_length, allocator) or_return - + out := strings.builder_make(0, out_length, allocator) or_return ioerr := encode_into(strings.to_stream(&out), data, ENC_TBL) - assert(ioerr == nil) + + assert(ioerr == nil, "string builder should not IO error") + assert(strings.builder_cap(out) == out_length, "buffer resized, `encoded_len` was wrong") return strings.to_string(out), nil } -encoded_length :: #force_inline proc(data: []byte) -> int { +encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> io.Error { + length := len(data) + if length == 0 { + return nil + } + + c0, c1, c2, block: int + out: [4]byte + for i := 0; i < length; i += 3 { + #no_bounds_check { + c0, c1, c2 = int(data[i]), -1, -1 + + if i + 1 < length { c1 = int(data[i + 1]) } + if i + 2 < length { c2 = int(data[i + 2]) } + + block = (c0 << 16) | (max(c1, 0) << 8) | max(c2, 0) + + out[0] = ENC_TBL[block >> 18 & 63] + out[1] = ENC_TBL[block >> 12 & 63] + out[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63] + out[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63] + } + io.write_full(w, out[:]) or_return + } + return nil +} + +encoded_len :: proc(data: []byte) -> int { length := len(data) if length == 0 { return 0 @@ -67,48 +94,30 @@ encoded_length :: #force_inline proc(data: []byte) -> int { return ((4 * length / 3) + 3) &~ 3 } -encode_into :: proc(w: io.Writer, data: []byte, ENC_TBL := ENC_TABLE) -> (err: io.Error) #no_bounds_check { - length := len(data) - if length == 0 { - return - } +decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocator) -> (decoded: []byte, err: mem.Allocator_Error) #optional_allocator_error { + out_length := decoded_len(data) - c0, c1, c2, block: int + out := strings.builder_make(0, out_length, allocator) or_return + ioerr := decode_into(strings.to_stream(&out), data, DEC_TBL) - for i, d := 0, 0; i < length; i, d = i + 3, d + 4 { - c0, c1, c2 = int(data[i]), -1, -1 + assert(ioerr == nil, "string builder should not IO error") + assert(strings.builder_cap(out) == out_length, "buffer resized, `decoded_len` was wrong") - if i + 1 < length { c1 = int(data[i + 1]) } - if i + 2 < length { c2 = int(data[i + 2]) } - - block = (c0 << 16) | (max(c1, 0) << 8) | max(c2, 0) - - out: [4]byte - out[0] = ENC_TBL[block >> 18 & 63] - out[1] = ENC_TBL[block >> 12 & 63] - out[2] = c1 == -1 ? PADDING : ENC_TBL[block >> 6 & 63] - out[3] = c2 == -1 ? PADDING : ENC_TBL[block & 63] - - #bounds_check { io.write_full(w, out[:]) or_return } - } - return + return out.buf[:], nil } -decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocator) -> (out: []byte, err: mem.Allocator_Error) #optional_allocator_error { - #no_bounds_check { - length := len(data) - if length == 0 { - return - } +decode_into :: proc(w: io.Writer, data: string, DEC_TBL := DEC_TABLE) -> io.Error { + length := decoded_len(data) + if length == 0 { + return nil + } - pad_count := data[length - 1] == PADDING ? (data[length - 2] == PADDING ? 2 : 1) : 0 - out_length := ((length * 6) >> 3) - pad_count - out = make([]byte, out_length, allocator) or_return - - c0, c1, c2, c3: int - b0, b1, b2: int - - for i, j := 0, 0; i < length; i, j = i + 4, j + 3 { + c0, c1, c2, c3: int + b0, b1, b2: int + buf: [3]byte + i, j: int + for ; j + 3 <= length; i, j = i + 4, j + 3 { + #no_bounds_check { c0 = DEC_TBL[data[i]] c1 = DEC_TBL[data[i + 1]] c2 = DEC_TBL[data[i + 2]] @@ -118,10 +127,48 @@ decode :: proc(data: string, DEC_TBL := DEC_TABLE, allocator := context.allocato b1 = (c1 << 4) | (c2 >> 2) b2 = (c2 << 6) | c3 - out[j] = byte(b0) - out[j + 1] = byte(b1) - out[j + 2] = byte(b2) + buf[0] = byte(b0) + buf[1] = byte(b1) + buf[2] = byte(b2) } - return + + io.write_full(w, buf[:]) or_return } + + rest := length - j + if rest > 0 { + #no_bounds_check { + c0 = DEC_TBL[data[i]] + c1 = DEC_TBL[data[i + 1]] + c2 = DEC_TBL[data[i + 2]] + + b0 = (c0 << 2) | (c1 >> 4) + b1 = (c1 << 4) | (c2 >> 2) + } + + switch rest { + case 1: io.write_byte(w, byte(b0)) or_return + case 2: io.write_full(w, {byte(b0), byte(b1)}) or_return + } + } + + return nil +} + +decoded_len :: proc(data: string) -> int { + length := len(data) + if length == 0 { + return 0 + } + + padding: int + if data[length - 1] == PADDING { + if length > 1 && data[length - 2] == PADDING { + padding = 2 + } else { + padding = 1 + } + } + + return ((length * 6) >> 3) - padding } diff --git a/tests/core/Makefile b/tests/core/Makefile index 1fca7bf97..3fa38cd34 100644 --- a/tests/core/Makefile +++ b/tests/core/Makefile @@ -51,11 +51,14 @@ noise_test: $(ODIN) run math/noise $(COMMON) -out:test_noise encoding_test: +<<<<<<< HEAD $(ODIN) run encoding/hxa $(COMMON) $(COLLECTION) -out:test_hxa $(ODIN) run encoding/json $(COMMON) -out:test_json $(ODIN) run encoding/varint $(COMMON) -out:test_varint $(ODIN) run encoding/xml $(COMMON) -out:test_xml $(ODIN) run encoding/cbor $(COMMON) -out:test_cbor + $(ODIN) run encoding/hex $(COMMON) -out:test_hex + $(ODIN) run encoding/base64 $(COMMON) -out:test_base64 math_test: $(ODIN) run math $(COMMON) $(COLLECTION) -out:test_core_math diff --git a/tests/core/build.bat b/tests/core/build.bat index 5bf8e1ead..b9fc4e828 100644 --- a/tests/core/build.bat +++ b/tests/core/build.bat @@ -41,6 +41,8 @@ rem %PATH_TO_ODIN% run encoding/hxa %COMMON% %COLLECTION% -out:test_hxa.exe | %PATH_TO_ODIN% run encoding/varint %COMMON% -out:test_varint.exe || exit /b %PATH_TO_ODIN% run encoding/xml %COMMON% -out:test_xml.exe || exit /b %PATH_TO_ODIN% test encoding/cbor %COMMON% -out:test_cbor.exe || exit /b +%PATH_TO_ODIN% run encoding/hex %COMMON% -out:test_hex.exe || exit /b +%PATH_TO_ODIN% run encoding/base64 %COMMON% -out:test_base64.exe || exit /b echo --- echo Running core:math/noise tests diff --git a/tests/core/encoding/base64/base64.odin b/tests/core/encoding/base64/base64.odin new file mode 100644 index 000000000..41dbba683 --- /dev/null +++ b/tests/core/encoding/base64/base64.odin @@ -0,0 +1,60 @@ +package test_encoding_base64 + +import "core:encoding/base64" +import "core:fmt" +import "core:intrinsics" +import "core:os" +import "core:reflect" +import "core:testing" + +TEST_count := 0 +TEST_fail := 0 + +when ODIN_TEST { + expect_value :: testing.expect_value + +} else { + expect_value :: proc(t: ^testing.T, value, expected: $T, loc := #caller_location) -> bool where intrinsics.type_is_comparable(T) { + TEST_count += 1 + ok := value == expected || reflect.is_nil(value) && reflect.is_nil(expected) + if !ok { + TEST_fail += 1 + fmt.printf("[%v] expected %v, got %v\n", loc, expected, value) + } + return ok + } +} + +main :: proc() { + t := testing.T{} + + test_encoding(&t) + test_decoding(&t) + + fmt.printf("%v/%v tests successful.\n", TEST_count - TEST_fail, TEST_count) + if TEST_fail > 0 { + os.exit(1) + } +} + +@(test) +test_encoding :: proc(t: ^testing.T) { + expect_value(t, base64.encode(transmute([]byte)string("")), "") + expect_value(t, base64.encode(transmute([]byte)string("f")), "Zg==") + expect_value(t, base64.encode(transmute([]byte)string("fo")), "Zm8=") + expect_value(t, base64.encode(transmute([]byte)string("foo")), "Zm9v") + expect_value(t, base64.encode(transmute([]byte)string("foob")), "Zm9vYg==") + expect_value(t, base64.encode(transmute([]byte)string("fooba")), "Zm9vYmE=") + expect_value(t, base64.encode(transmute([]byte)string("foobar")), "Zm9vYmFy") +} + +@(test) +test_decoding :: proc(t: ^testing.T) { + expect_value(t, string(base64.decode("")), "") + expect_value(t, string(base64.decode("Zg==")), "f") + expect_value(t, string(base64.decode("Zm8=")), "fo") + expect_value(t, string(base64.decode("Zm9v")), "foo") + expect_value(t, string(base64.decode("Zm9vYg==")), "foob") + expect_value(t, string(base64.decode("Zm9vYmE=")), "fooba") + expect_value(t, string(base64.decode("Zm9vYmFy")), "foobar") +}