encoding/cbor: add decoder flags and protect from malicious untrusted input

Laytan Laats
2023-12-16 23:02:30 +01:00
parent d77ae9abab
commit 21e6e28a3a
5 changed files with 351 additions and 231 deletions

View File

@@ -10,8 +10,13 @@ import "core:strings"
// If we are decoding a stream of either a map or list, the initial capacity will be this value.
INITIAL_STREAMED_CONTAINER_CAPACITY :: 8
// If we are decoding a stream of either text or bytes, the initial capacity will be this value.
INITIAL_STREAMED_BYTES_CAPACITY :: 16
// The default maximum amount of bytes to allocate on a buffer/container at once to prevent
// malicious input from causing massive allocations.
DEFAULT_MAX_PRE_ALLOC :: mem.Kilobyte
// Known/common headers are defined; undefined headers can still be valid.
// The higher 3 bits hold the major type and the lower 5 bits the additional information.
@@ -157,6 +162,7 @@ Decode_Data_Error :: enum {
Nested_Indefinite_Length, // When a streamed/indefinite-length container nests another; this is not allowed.
Nested_Tag, // When a tag's value is another tag; this is not allowed.
Length_Too_Big, // When the length of a container (map, array, bytes, string) is more than `max(int)`.
Disallowed_Streaming, // When the `.Disallow_Streaming` flag is set and a streaming header is encountered.
Break,
}
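Illustrative usage of the new flag and error (a sketch, not code from this commit); the bytes are a hand-written indefinite-length array:

package example_disallow_streaming

import "core:encoding/cbor"
import "core:fmt"

main :: proc() {
	// 0x9f opens an indefinite-length (streamed) array, 0xff is the break marker.
	streamed := "\x9f\x01\x02\xff"

	_, err := cbor.decode(streamed, cbor.Decoder_Flags{.Disallow_Streaming})
	fmt.println(err) // expected: Disallowed_Streaming
}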

View File

@@ -33,16 +33,40 @@ Encoder_Flags :: bit_set[Encoder_Flag]
// Flags for fully deterministic output (if you are not using streaming/indeterminate length).
ENCODE_FULLY_DETERMINISTIC :: Encoder_Flags{.Deterministic_Int_Size, .Deterministic_Float_Size, .Deterministic_Map_Sorting}
// Flags for the smallest encoding output.
ENCODE_SMALL :: Encoder_Flags{.Deterministic_Int_Size, .Deterministic_Float_Size}
// Flags for the fastest encoding output.
ENCODE_FAST :: Encoder_Flags{}
Encoder :: struct {
flags: Encoder_Flags,
writer: io.Writer,
}
Decoder_Flag :: enum {
// Rejects (with the error `.Disallowed_Streaming`) any streaming CBOR header that is encountered.
Disallow_Streaming,
// Pre-allocates buffers and containers with the size that was set in the CBOR header.
// This should only be enabled when you control both ends of the encoding; if you don't,
// attackers can craft input that causes massive (`max(u64)`) byte allocations from only a few bytes of
// CBOR.
Trusted_Input,
// Makes the decoder shrink excess capacity from allocated buffers/containers before returning.
Shrink_Excess,
}
Decoder_Flags :: bit_set[Decoder_Flag]
Decoder :: struct {
// The max amount of bytes allowed to pre-allocate when `.Trusted_Input` is not set on the
// flags.
max_pre_alloc: int,
flags: Decoder_Flags,
reader: io.Reader,
}
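For illustration (a sketch, not code from this commit), a hand-built `Decoder` with a custom `max_pre_alloc`; the input bytes and the 4 KiB cap are arbitrary choices:

package example_decoder

import "core:encoding/cbor"
import "core:fmt"
import "core:mem"
import "core:strings"

main :: proc() {
	input := "\xa1\x63foo\x01" // CBOR for {"foo": 1}

	sr: strings.Reader
	r := strings.to_reader(&sr, input)

	d := cbor.Decoder{
		max_pre_alloc = 4 * mem.Kilobyte, // cap on any single pre-allocation for untrusted input
		flags         = {.Disallow_Streaming, .Shrink_Excess},
		reader        = r,
	}

	v, err := cbor.decode(d)
	if err != nil {
		fmt.eprintln("decode failed:", err)
		return
	}
	defer cbor.destroy(v)
	fmt.println(v)
}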
/*
Decodes both deterministic and non-deterministic CBOR into a `Value` variant.
@@ -52,28 +76,60 @@ Allocations are done using the given allocator,
*no* allocations are done on the `context.temp_allocator`.
A value can be (fully and recursively) deallocated using the `destroy` proc in this package.
Disable streaming/indeterminate lengths with the `.Disallow_Streaming` flag.
Shrink excess bytes in buffers and containers with the `.Shrink_Excess` flag.
Mark the input as trusted with the `.Trusted_Input` flag; this turns off the safety feature of never
pre-allocating more than `max_pre_alloc` bytes before reading into a buffer. You should only do this
when you own both sides of the encoding and are sure the input can't contain malicious bytes.
*/
decode :: proc {
decode_string,
decode_reader,
decode_from :: proc {
decode_from_string,
decode_from_reader,
decode_from_decoder,
}
decode :: decode_from
// Decodes the given string as CBOR.
// See docs on the proc group `decode` for more information.
decode_string :: proc(s: string, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
decode_from_string :: proc(s: string, flags: Decoder_Flags = {}, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
context.allocator = allocator
r: strings.Reader
strings.reader_init(&r, s)
return decode(strings.reader_to_stream(&r), allocator=allocator)
return decode_from_reader(strings.reader_to_stream(&r), flags)
}
// Reads a CBOR value from the given reader.
// See docs on the proc group `decode` for more information.
decode_reader :: proc(r: io.Reader, hdr: Header = Header(0), allocator := context.allocator) -> (v: Value, err: Decode_Error) {
decode_from_reader :: proc(r: io.Reader, flags: Decoder_Flags = {}, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
return decode_from_decoder(
Decoder{ DEFAULT_MAX_PRE_ALLOC, flags, r },
allocator=allocator,
)
}
// Reads a CBOR value from the given decoder.
// See docs on the proc group `decode` for more information.
decode_from_decoder :: proc(d: Decoder, allocator := context.allocator) -> (v: Value, err: Decode_Error) {
context.allocator = allocator
d := d
if d.max_pre_alloc <= 0 {
d.max_pre_alloc = DEFAULT_MAX_PRE_ALLOC
}
v, err = _decode_from_decoder(d)
// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
if err == .EOF { err = .Unexpected_EOF }
return
}
_decode_from_decoder :: proc(d: Decoder, hdr: Header = Header(0)) -> (v: Value, err: Decode_Error) {
hdr := hdr
r := d.reader
if hdr == Header(0) { hdr = _decode_header(r) or_return }
switch hdr {
case .U8: return _decode_u8 (r)
@@ -105,11 +161,11 @@ decode_reader :: proc(r: io.Reader, hdr: Header = Header(0), allocator := contex
switch maj {
case .Unsigned: return _decode_tiny_u8(add)
case .Negative: return Negative_U8(_decode_tiny_u8(add) or_return), nil
case .Bytes: return _decode_bytes_ptr(r, add)
case .Text: return _decode_text_ptr(r, add)
case .Array: return _decode_array_ptr(r, add)
case .Map: return _decode_map_ptr(r, add)
case .Tag: return _decode_tag_ptr(r, add)
case .Bytes: return _decode_bytes_ptr(d, add)
case .Text: return _decode_text_ptr(d, add)
case .Array: return _decode_array_ptr(d, add)
case .Map: return _decode_map_ptr(d, add)
case .Tag: return _decode_tag_ptr(d, add)
case .Other: return _decode_tiny_simple(add)
case: return nil, .Bad_Major
}
@@ -246,7 +302,7 @@ _encode_u8 :: proc(w: io.Writer, v: u8, major: Major = .Unsigned) -> (err: io.Er
}
_decode_tiny_u8 :: proc(additional: Add) -> (u8, Decode_Data_Error) {
if intrinsics.expect(additional < .One_Byte, true) {
if additional < .One_Byte {
return u8(additional), nil
}
@@ -316,64 +372,53 @@ _encode_u64_exact :: proc(w: io.Writer, v: u64, major: Major = .Unsigned) -> (er
return
}
_decode_bytes_ptr :: proc(r: io.Reader, add: Add, type: Major = .Bytes) -> (v: ^Bytes, err: Decode_Error) {
_decode_bytes_ptr :: proc(d: Decoder, add: Add, type: Major = .Bytes) -> (v: ^Bytes, err: Decode_Error) {
v = new(Bytes) or_return
defer if err != nil { free(v) }
v^ = _decode_bytes(r, add, type) or_return
v^ = _decode_bytes(d, add, type) or_return
return
}
_decode_bytes :: proc(r: io.Reader, add: Add, type: Major = .Bytes) -> (v: Bytes, err: Decode_Error) {
_n_items, length_is_unknown := _decode_container_length(r, add) or_return
_decode_bytes :: proc(d: Decoder, add: Add, type: Major = .Bytes) -> (v: Bytes, err: Decode_Error) {
n, scap := _decode_len_str(d, add) or_return
buf := strings.builder_make(0, scap) or_return
defer if err != nil { strings.builder_destroy(&buf) }
buf_stream := strings.to_stream(&buf)
n_items := _n_items.? or_else INITIAL_STREAMED_BYTES_CAPACITY
if length_is_unknown {
buf: strings.Builder
buf.buf = make([dynamic]byte, 0, n_items) or_return
defer if err != nil { strings.builder_destroy(&buf) }
buf_stream := strings.to_stream(&buf)
for {
header := _decode_header(r) or_return
if n == -1 {
indefinite_loop: for {
header := _decode_header(d.reader) or_return
maj, add := _header_split(header)
#partial switch maj {
case type:
_n_items, length_is_unknown := _decode_container_length(r, add) or_return
if length_is_unknown {
iter_n, iter_cap := _decode_len_str(d, add) or_return
if iter_n == -1 {
return nil, .Nested_Indefinite_Length
}
n_items := i64(_n_items.?)
reserve(&buf.buf, len(buf.buf) + iter_cap) or_return
io.copy_n(buf_stream, d.reader, i64(iter_n)) or_return
copied := io.copy_n(buf_stream, r, n_items) or_return
assert(copied == n_items)
case .Other:
if add != .Break { return nil, .Bad_Argument }
v = buf.buf[:]
// Write zero byte so this can be converted to cstring.
io.write_full(buf_stream, {0}) or_return
shrink(&buf.buf) // Ignoring error, this is not critical to succeed.
return
break indefinite_loop
case:
return nil, .Bad_Major
}
}
} else {
v = make([]byte, n_items + 1) or_return // Space for the bytes and a zero byte.
defer if err != nil { delete(v) }
io.read_full(r, v[:n_items]) or_return
v = v[:n_items] // Take off zero byte.
return
io.copy_n(buf_stream, d.reader, i64(n)) or_return
}
v = buf.buf[:]
// Write zero byte so this can be converted to cstring.
strings.write_byte(&buf, 0)
if .Shrink_Excess in d.flags { shrink(&buf.buf) }
return
}
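As an aside (a sketch, not code from this commit), the kind of input `_decode_bytes` handles chunk by chunk: an indefinite-length byte string is a series of definite-length chunks ended by a break, collected into one contiguous buffer:

package example_streamed_bytes

import "core:encoding/cbor"
import "core:fmt"

main :: proc() {
	// 0x5f starts an indefinite-length byte string, 0x42 and 0x43 are 2- and
	// 3-byte chunks, 0xff is the break marker.
	chunked := "\x5f\x42\x01\x02\x43\x03\x04\x05\xff"

	v, err := cbor.decode(chunked)
	if err != nil {
		fmt.eprintln("decode failed:", err)
		return
	}
	defer cbor.destroy(v)
	fmt.println(v) // the five bytes 1 through 5 in a single buffer
}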
_encode_bytes :: proc(e: Encoder, val: Bytes, major: Major = .Bytes) -> (err: Encode_Error) {
@@ -383,43 +428,41 @@ _encode_bytes :: proc(e: Encoder, val: Bytes, major: Major = .Bytes) -> (err: En
return
}
_decode_text_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Text, err: Decode_Error) {
_decode_text_ptr :: proc(d: Decoder, add: Add) -> (v: ^Text, err: Decode_Error) {
v = new(Text) or_return
defer if err != nil { free(v) }
v^ = _decode_text(r, add) or_return
v^ = _decode_text(d, add) or_return
return
}
_decode_text :: proc(r: io.Reader, add: Add) -> (v: Text, err: Decode_Error) {
return (Text)(_decode_bytes(r, add, .Text) or_return), nil
_decode_text :: proc(d: Decoder, add: Add) -> (v: Text, err: Decode_Error) {
return (Text)(_decode_bytes(d, add, .Text) or_return), nil
}
_encode_text :: proc(e: Encoder, val: Text) -> Encode_Error {
return _encode_bytes(e, transmute([]byte)val, .Text)
}
_decode_array_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Array, err: Decode_Error) {
_decode_array_ptr :: proc(d: Decoder, add: Add) -> (v: ^Array, err: Decode_Error) {
v = new(Array) or_return
defer if err != nil { free(v) }
v^ = _decode_array(r, add) or_return
v^ = _decode_array(d, add) or_return
return
}
_decode_array :: proc(r: io.Reader, add: Add) -> (v: Array, err: Decode_Error) {
_n_items, length_is_unknown := _decode_container_length(r, add) or_return
n_items := _n_items.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
array := make([dynamic]Value, 0, n_items) or_return
_decode_array :: proc(d: Decoder, add: Add) -> (v: Array, err: Decode_Error) {
n, scap := _decode_len_container(d, add) or_return
array := make([dynamic]Value, 0, scap) or_return
defer if err != nil {
for entry in array { destroy(entry) }
delete(array)
}
for i := 0; length_is_unknown || i < n_items; i += 1 {
val, verr := decode(r)
if length_is_unknown && verr == .Break {
for i := 0; n == -1 || i < n; i += 1 {
val, verr := _decode_from_decoder(d)
if n == -1 && verr == .Break {
break
} else if verr != nil {
err = verr
@@ -428,8 +471,9 @@ _decode_array :: proc(r: io.Reader, add: Add) -> (v: Array, err: Decode_Error) {
append(&array, val) or_return
}
if .Shrink_Excess in d.flags { shrink(&array) }
shrink(&array)
v = array[:]
return
}
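A small sketch (not code from this commit) of the two array forms `_decode_array` accepts through the public `decode`; both inputs produce the same value:

package example_arrays

import "core:encoding/cbor"
import "core:fmt"

main :: proc() {
	definite := "\x83\x01\x02\x03"     // array of exactly 3 items: [1, 2, 3]
	streamed := "\x9f\x01\x02\x03\xff" // indefinite-length array ended by a break

	for input in ([]string{definite, streamed}) {
		v, err := cbor.decode(input, cbor.Decoder_Flags{.Shrink_Excess})
		if err != nil {
			fmt.eprintln("decode failed:", err)
			continue
		}
		defer cbor.destroy(v)
		fmt.println(v)
	}
}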
@@ -443,19 +487,17 @@ _encode_array :: proc(e: Encoder, arr: Array) -> Encode_Error {
return nil
}
_decode_map_ptr :: proc(r: io.Reader, add: Add) -> (v: ^Map, err: Decode_Error) {
_decode_map_ptr :: proc(d: Decoder, add: Add) -> (v: ^Map, err: Decode_Error) {
v = new(Map) or_return
defer if err != nil { free(v) }
v^ = _decode_map(r, add) or_return
v^ = _decode_map(d, add) or_return
return
}
_decode_map :: proc(r: io.Reader, add: Add) -> (v: Map, err: Decode_Error) {
_n_items, length_is_unknown := _decode_container_length(r, add) or_return
n_items := _n_items.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
items := make([dynamic]Map_Entry, 0, n_items) or_return
_decode_map :: proc(d: Decoder, add: Add) -> (v: Map, err: Decode_Error) {
n, scap := _decode_len_container(d, add) or_return
items := make([dynamic]Map_Entry, 0, scap) or_return
defer if err != nil {
for entry in items {
destroy(entry.key)
@@ -464,23 +506,24 @@ _decode_map :: proc(r: io.Reader, add: Add) -> (v: Map, err: Decode_Error) {
delete(items)
}
for i := 0; length_is_unknown || i < n_items; i += 1 {
key, kerr := decode(r)
if length_is_unknown && kerr == .Break {
for i := 0; n == -1 || i < n; i += 1 {
key, kerr := _decode_from_decoder(d)
if n == -1 && kerr == .Break {
break
} else if kerr != nil {
return nil, kerr
}
value := decode(r) or_return
value := decode_from_decoder(d) or_return
append(&items, Map_Entry{
key = key,
value = value,
}) or_return
}
if .Shrink_Excess in d.flags { shrink(&items) }
shrink(&items)
v = items[:]
return
}
@@ -537,8 +580,8 @@ _encode_map :: proc(e: Encoder, m: Map) -> (err: Encode_Error) {
return nil
}
_decode_tag_ptr :: proc(r: io.Reader, add: Add) -> (v: Value, err: Decode_Error) {
tag := _decode_tag(r, add) or_return
_decode_tag_ptr :: proc(d: Decoder, add: Add) -> (v: Value, err: Decode_Error) {
tag := _decode_tag(d, add) or_return
if t, ok := tag.?; ok {
defer if err != nil { destroy(t.value) }
tp := new(Tag) or_return
@@ -547,11 +590,11 @@ _decode_tag_ptr :: proc(r: io.Reader, add: Add) -> (v: Value, err: Decode_Error)
}
// no error, no tag, this was the self described CBOR tag, skip it.
return decode(r)
return _decode_from_decoder(d)
}
_decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error) {
num := _decode_tag_nr(r, add) or_return
_decode_tag :: proc(d: Decoder, add: Add) -> (v: Maybe(Tag), err: Decode_Error) {
num := _decode_uint_as_u64(d.reader, add) or_return
// CBOR can be wrapped in a tag that decoders can use to see/check if the binary data is CBOR.
// We can ignore it here.
@@ -561,7 +604,7 @@ _decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error
t := Tag{
number = num,
value = decode(r) or_return,
value = _decode_from_decoder(d) or_return,
}
if nested, ok := t.value.(^Tag); ok {
@@ -572,7 +615,7 @@ _decode_tag :: proc(r: io.Reader, add: Add) -> (v: Maybe(Tag), err: Decode_Error
return t, nil
}
_decode_tag_nr :: proc(r: io.Reader, add: Add) -> (nr: Tag_Number, err: Decode_Error) {
_decode_uint_as_u64 :: proc(r: io.Reader, add: Add) -> (nr: u64, err: Decode_Error) {
#partial switch add {
case .One_Byte: return u64(_decode_u8(r) or_return), nil
case .Two_Bytes: return u64(_decode_u16(r) or_return), nil
@@ -719,30 +762,50 @@ encode_stream_map_entry :: proc(e: Encoder, key: Value, val: Value) -> Encode_Er
return encode(e, val)
}
//
_decode_container_length :: proc(r: io.Reader, add: Add) -> (length: Maybe(int), is_unknown: bool, err: Decode_Error) {
if add == Add.Length_Unknown { return nil, true, nil }
#partial switch add {
case .One_Byte: length = int(_decode_u8(r) or_return)
case .Two_Bytes: length = int(_decode_u16(r) or_return)
case .Four_Bytes:
big_length := _decode_u32(r) or_return
if u64(big_length) > u64(max(int)) {
err = .Length_Too_Big
return
// For `Bytes` and `Text` strings: decodes the number of items the header says follow.
// If the number is not specified, -1 is returned and streaming should be initiated.
// A suitable starting capacity is also returned for the buffer that is allocated further up the call stack.
_decode_len_str :: proc(d: Decoder, add: Add) -> (n: int, scap: int, err: Decode_Error) {
if add == .Length_Unknown {
if .Disallow_Streaming in d.flags {
return -1, -1, .Disallowed_Streaming
}
length = int(big_length)
case .Eight_Bytes:
big_length := _decode_u64(r) or_return
if big_length > u64(max(int)) {
err = .Length_Too_Big
return
}
length = int(big_length)
case:
length = int(_decode_tiny_u8(add) or_return)
return -1, INITIAL_STREAMED_BYTES_CAPACITY, nil
}
_n := _decode_uint_as_u64(d.reader, add) or_return
if _n > u64(max(int)) { return -1, -1, .Length_Too_Big }
n = int(_n)
scap = n + 1 // Space for zero byte.
if .Trusted_Input not_in d.flags {
scap = min(d.max_pre_alloc, scap)
}
return
}
// For `Array` and `Map` types: decodes the number of items the header says follow.
// If the number is not specified, -1 is returned and streaming should be initiated.
// A suitable starting capacity is also returned for the container that is allocated further up the call stack.
_decode_len_container :: proc(d: Decoder, add: Add) -> (n: int, scap: int, err: Decode_Error) {
if add == .Length_Unknown {
if .Disallow_Streaming in d.flags {
return -1, -1, .Disallowed_Streaming
}
return -1, INITIAL_STREAMED_CONTAINER_CAPACITY, nil
}
_n := _decode_uint_as_u64(d.reader, add) or_return
if _n > u64(max(int)) { return -1, -1, .Length_Too_Big }
n = int(_n)
scap = n
if .Trusted_Input not_in d.flags {
// NOTE: if this is a map it will be twice this.
scap = min(d.max_pre_alloc / size_of(Value), scap)
}
return
}
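To make the protection concrete (a sketch, not code from this commit): with these procedures, a header claiming a byte string of hundreds of gigabytes no longer drives the up-front allocation; without `.Trusted_Input`, at most `max_pre_alloc` bytes are reserved, and the short read then surfaces as an EOF-style error:

package example_lying_length

import "core:encoding/cbor"
import "core:fmt"

main :: proc() {
	// 0x5b announces a byte string with an 8-byte length (hundreds of GB here),
	// but no payload follows.
	lying := "\x5b\x00\x00\x00\x42\xfa\x42\xfa\x42"

	_, err := cbor.decode(lying)
	fmt.println(err) // expected: Unexpected_EOF, not an out-of-memory failure
}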

View File

@@ -55,7 +55,7 @@ Tag_Implementation :: struct {
}
// Procedure responsible for unmarshalling the tag out of the reader into the given `any`.
Tag_Unmarshal_Proc :: #type proc(self: ^Tag_Implementation, r: io.Reader, tag_nr: Tag_Number, v: any) -> Unmarshal_Error
Tag_Unmarshal_Proc :: #type proc(self: ^Tag_Implementation, d: Decoder, tag_nr: Tag_Number, v: any) -> Unmarshal_Error
// Procedure responsible for marshalling the tag in the given `any` into the given encoder.
Tag_Marshal_Proc :: #type proc(self: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_Error
@@ -121,30 +121,30 @@ tags_register_defaults :: proc() {
//
// See RFC 8949 section 3.4.2.
@(private)
tag_time_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(r) or_return
tag_time_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(d.reader) or_return
#partial switch hdr {
case .U8, .U16, .U32, .U64, .Neg_U8, .Neg_U16, .Neg_U32, .Neg_U64:
switch &dst in v {
case time.Time:
i: i64
_unmarshal_any_ptr(r, &i, hdr) or_return
_unmarshal_any_ptr(d, &i, hdr) or_return
dst = time.unix(i64(i), 0)
return
case:
return _unmarshal_value(r, v, hdr)
return _unmarshal_value(d, v, hdr)
}
case .F16, .F32, .F64:
switch &dst in v {
case time.Time:
f: f64
_unmarshal_any_ptr(r, &f, hdr) or_return
_unmarshal_any_ptr(d, &f, hdr) or_return
whole, fract := math.modf(f)
dst = time.unix(i64(whole), i64(fract * 1e9))
return
case:
return _unmarshal_value(r, v, hdr)
return _unmarshal_value(d, v, hdr)
}
case:
@@ -182,8 +182,8 @@ tag_time_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_
}
@(private)
tag_big_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, tnr: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(r) or_return
tag_big_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, tnr: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(d.reader) or_return
maj, add := _header_split(hdr)
if maj != .Bytes {
// Only bytes are supported in this tag.
@@ -192,7 +192,7 @@ tag_big_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, tnr: Tag_Number,
switch &dst in v {
case big.Int:
bytes := err_conv(_decode_bytes(r, add)) or_return
bytes := err_conv(_decode_bytes(d, add)) or_return
defer delete(bytes)
if err := big.int_from_bytes_big(&dst, bytes); err != nil {
@@ -246,13 +246,13 @@ tag_big_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_E
}
@(private)
tag_cbor_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> Unmarshal_Error {
hdr := _decode_header(r) or_return
tag_cbor_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> Unmarshal_Error {
hdr := _decode_header(d.reader) or_return
major, add := _header_split(hdr)
#partial switch major {
case .Bytes:
ti := reflect.type_info_base(type_info_of(v.id))
return _unmarshal_bytes(r, v, ti, hdr, add)
return _unmarshal_bytes(d, v, ti, hdr, add)
case: return .Bad_Tag_Value
}
@@ -283,8 +283,8 @@ tag_cbor_marshal :: proc(_: ^Tag_Implementation, e: Encoder, v: any) -> Marshal_
}
@(private)
tag_base64_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(r) or_return
tag_base64_unmarshal :: proc(_: ^Tag_Implementation, d: Decoder, _: Tag_Number, v: any) -> (err: Unmarshal_Error) {
hdr := _decode_header(d.reader) or_return
major, add := _header_split(hdr)
ti := reflect.type_info_base(type_info_of(v.id))
@@ -294,7 +294,7 @@ tag_base64_unmarshal :: proc(_: ^Tag_Implementation, r: io.Reader, _: Tag_Number
bytes: string; {
context.allocator = context.temp_allocator
bytes = string(err_conv(_decode_bytes(r, add)) or_return)
bytes = string(err_conv(_decode_bytes(d, add)) or_return)
}
defer delete(bytes, context.temp_allocator)

View File

@@ -15,25 +15,56 @@ Types that require allocation are allocated using the given allocator.
Some temporary allocations are done on the `context.temp_allocator`, but, if you want to,
this can be set to a "normal" allocator, because the necessary `delete` and `free` calls are still made.
This is helpful when the CBOR size is so big that you don't want to collect all the temporary allocations until the end.
Disable streaming/indeterminate lengths with the `.Disallow_Streaming` flag.
Shrink excess bytes in buffers and containers with the `.Shrink_Excess` flag.
Mark the input as trusted with the `.Trusted_Input` flag; this turns off the safety feature of never
pre-allocating more than `max_pre_alloc` bytes before reading into a buffer. You should only do this
when you own both sides of the encoding and are sure the input can't contain malicious bytes.
*/
unmarshal :: proc {
unmarshal_from_reader,
unmarshal_from_string,
}
// Unmarshals from a reader, see docs on the proc group `Unmarshal` for more info.
unmarshal_from_reader :: proc(r: io.Reader, ptr: ^$T, allocator := context.allocator) -> Unmarshal_Error {
return _unmarshal_any_ptr(r, ptr, allocator=allocator)
unmarshal_from_reader :: proc(r: io.Reader, ptr: ^$T, flags := Decoder_Flags{}, allocator := context.allocator) -> (err: Unmarshal_Error) {
err = unmarshal_from_decoder(Decoder{ DEFAULT_MAX_PRE_ALLOC, flags, r }, ptr, allocator=allocator)
// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
if err == .EOF { err = .Unexpected_EOF }
return
}
// Unmarshals from a string, see docs on the proc group `Unmarshal` for more info.
unmarshal_from_string :: proc(s: string, ptr: ^$T, allocator := context.allocator) -> Unmarshal_Error {
unmarshal_from_string :: proc(s: string, ptr: ^$T, flags := Decoder_Flags{}, allocator := context.allocator) -> (err: Unmarshal_Error) {
sr: strings.Reader
r := strings.to_reader(&sr, s)
return _unmarshal_any_ptr(r, ptr, allocator=allocator)
err = unmarshal_from_reader(r, ptr, flags, allocator)
// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
if err == .EOF { err = .Unexpected_EOF }
return
}
_unmarshal_any_ptr :: proc(r: io.Reader, v: any, hdr: Maybe(Header) = nil, allocator := context.allocator) -> Unmarshal_Error {
unmarshal_from_decoder :: proc(d: Decoder, ptr: ^$T, allocator := context.allocator) -> (err: Unmarshal_Error) {
d := d
if d.max_pre_alloc <= 0 {
d.max_pre_alloc = DEFAULT_MAX_PRE_ALLOC
}
err = _unmarshal_any_ptr(d, ptr, allocator=allocator)
// Normal EOF does not exist here; we try to read the exact amount that is said to be provided.
if err == .EOF { err = .Unexpected_EOF }
return
}
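A usage sketch (not code from this commit) of passing the new flags through `unmarshal`; the `Point` struct and its CBOR bytes are made up for the example:

package example_unmarshal_flags

import "core:encoding/cbor"
import "core:fmt"

Point :: struct {
	x: int,
	y: int,
}

main :: proc() {
	input := "\xa2\x61x\x01\x61y\x02" // CBOR for {"x": 1, "y": 2}

	p: Point
	err := cbor.unmarshal(input, &p, cbor.Decoder_Flags{.Disallow_Streaming, .Shrink_Excess})
	fmt.println(p, err) // expected: x = 1, y = 2 and a nil error
}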
_unmarshal_any_ptr :: proc(d: Decoder, v: any, hdr: Maybe(Header) = nil, allocator := context.allocator) -> Unmarshal_Error {
context.allocator = allocator
v := v
@@ -48,12 +79,13 @@ _unmarshal_any_ptr :: proc(r: io.Reader, v: any, hdr: Maybe(Header) = nil, alloc
}
data := any{(^rawptr)(v.data)^, ti.variant.(reflect.Type_Info_Pointer).elem.id}
return _unmarshal_value(r, data, hdr.? or_else (_decode_header(r) or_return))
return _unmarshal_value(d, data, hdr.? or_else (_decode_header(d.reader) or_return))
}
_unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_Error) {
_unmarshal_value :: proc(d: Decoder, v: any, hdr: Header) -> (err: Unmarshal_Error) {
v := v
ti := reflect.type_info_base(type_info_of(v.id))
r := d.reader
// If it's a union with only one variant, then treat it as that variant
if u, ok := ti.variant.(reflect.Type_Info_Union); ok && len(u.variants) == 1 {
@@ -73,7 +105,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
// Allow generic unmarshal by doing it into a `Value`.
switch &dst in v {
case Value:
dst = err_conv(decode(r, hdr)) or_return
dst = err_conv(_decode_from_decoder(d, hdr)) or_return
return
}
@@ -253,7 +285,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
case .Tag:
switch &dst in v {
case ^Tag:
tval := err_conv(_decode_tag_ptr(r, add)) or_return
tval := err_conv(_decode_tag_ptr(d, add)) or_return
if t, is_tag := tval.(^Tag); is_tag {
dst = t
return
@@ -262,7 +294,7 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
destroy(tval)
return .Bad_Tag_Value
case Tag:
t := err_conv(_decode_tag(r, add)) or_return
t := err_conv(_decode_tag(d, add)) or_return
if t, is_tag := t.?; is_tag {
dst = t
return
@@ -271,33 +303,33 @@ _unmarshal_value :: proc(r: io.Reader, v: any, hdr: Header) -> (err: Unmarshal_E
return .Bad_Tag_Value
}
nr := err_conv(_decode_tag_nr(r, add)) or_return
nr := err_conv(_decode_uint_as_u64(r, add)) or_return
// Custom tag implementations.
if impl, ok := _tag_implementations_nr[nr]; ok {
return impl->unmarshal(r, nr, v)
return impl->unmarshal(d, nr, v)
} else if nr == TAG_OBJECT_TYPE {
return _unmarshal_union(r, v, ti, hdr)
return _unmarshal_union(d, v, ti, hdr)
} else {
// Discard the tag info and unmarshal as its value.
return _unmarshal_value(r, v, _decode_header(r) or_return)
return _unmarshal_value(d, v, _decode_header(r) or_return)
}
return _unsupported(v, hdr, add)
case .Bytes: return _unmarshal_bytes(r, v, ti, hdr, add)
case .Text: return _unmarshal_string(r, v, ti, hdr, add)
case .Array: return _unmarshal_array(r, v, ti, hdr, add)
case .Map: return _unmarshal_map(r, v, ti, hdr, add)
case .Bytes: return _unmarshal_bytes(d, v, ti, hdr, add)
case .Text: return _unmarshal_string(d, v, ti, hdr, add)
case .Array: return _unmarshal_array(d, v, ti, hdr, add)
case .Map: return _unmarshal_map(d, v, ti, hdr, add)
case: return .Bad_Major
}
}
_unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
_unmarshal_bytes :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
#partial switch t in ti.variant {
case reflect.Type_Info_String:
bytes := err_conv(_decode_bytes(r, add)) or_return
bytes := err_conv(_decode_bytes(d, add)) or_return
if t.is_cstring {
raw := (^cstring)(v.data)
@@ -316,7 +348,7 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
if elem_base.id != byte { return _unsupported(v, hdr) }
bytes := err_conv(_decode_bytes(r, add)) or_return
bytes := err_conv(_decode_bytes(d, add)) or_return
raw := (^mem.Raw_Slice)(v.data)
raw^ = transmute(mem.Raw_Slice)bytes
return
@@ -326,7 +358,7 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
if elem_base.id != byte { return _unsupported(v, hdr) }
bytes := err_conv(_decode_bytes(r, add)) or_return
bytes := err_conv(_decode_bytes(d, add)) or_return
raw := (^mem.Raw_Dynamic_Array)(v.data)
raw.data = raw_data(bytes)
raw.len = len(bytes)
@@ -339,11 +371,9 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
if elem_base.id != byte { return _unsupported(v, hdr) }
bytes: []byte; {
context.allocator = context.temp_allocator
bytes = err_conv(_decode_bytes(r, add)) or_return
}
defer delete(bytes, context.temp_allocator)
context.allocator = context.temp_allocator
bytes := err_conv(_decode_bytes(d, add)) or_return
defer delete(bytes)
if len(bytes) > t.count { return _unsupported(v, hdr) }
@@ -357,10 +387,10 @@ _unmarshal_bytes :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
return _unsupported(v, hdr)
}
_unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
_unmarshal_string :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
#partial switch t in ti.variant {
case reflect.Type_Info_String:
text := err_conv(_decode_text(r, add)) or_return
text := err_conv(_decode_text(d, add)) or_return
if t.is_cstring {
raw := (^cstring)(v.data)
@@ -376,8 +406,8 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
// Enum by its variant name.
case reflect.Type_Info_Enum:
context.allocator = context.temp_allocator
text := err_conv(_decode_text(r, add)) or_return
defer delete(text, context.temp_allocator)
text := err_conv(_decode_text(d, add)) or_return
defer delete(text)
for name, i in t.names {
if name == text {
@@ -388,8 +418,8 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
case reflect.Type_Info_Rune:
context.allocator = context.temp_allocator
text := err_conv(_decode_text(r, add)) or_return
defer delete(text, context.temp_allocator)
text := err_conv(_decode_text(d, add)) or_return
defer delete(text)
r := (^rune)(v.data)
dr, n := utf8.decode_rune(text)
@@ -404,21 +434,19 @@ _unmarshal_string :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Hea
return _unsupported(v, hdr)
}
_unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
_unmarshal_array :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
assign_array :: proc(
r: io.Reader,
d: Decoder,
da: ^mem.Raw_Dynamic_Array,
elemt: ^reflect.Type_Info,
_length: Maybe(int),
length: int,
growable := true,
) -> (out_of_space: bool, err: Unmarshal_Error) {
length, has_length := _length.?
for idx: uintptr = 0; !has_length || idx < uintptr(length); idx += 1 {
for idx: uintptr = 0; length == -1 || idx < uintptr(length); idx += 1 {
elem_ptr := rawptr(uintptr(da.data) + idx*uintptr(elemt.size))
elem := any{elem_ptr, elemt.id}
hdr := _decode_header(r) or_return
hdr := _decode_header(d.reader) or_return
// Double size if out of capacity.
if da.cap <= da.len {
@@ -432,8 +460,8 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
if !ok { return false, .Out_Of_Memory }
}
err = _unmarshal_value(r, elem, hdr)
if !has_length && err == .Break { break }
err = _unmarshal_value(d, elem, hdr)
if length == -1 && err == .Break { break }
if err != nil { return }
da.len += 1
@@ -445,26 +473,25 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
// Allow generically storing the values array.
switch &dst in v {
case ^Array:
dst = err_conv(_decode_array_ptr(r, add)) or_return
dst = err_conv(_decode_array_ptr(d, add)) or_return
return
case Array:
dst = err_conv(_decode_array(r, add)) or_return
dst = err_conv(_decode_array(d, add)) or_return
return
}
#partial switch t in ti.variant {
case reflect.Type_Info_Slice:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
length, scap := err_conv(_decode_len_container(d, add)) or_return
data := mem.alloc_bytes_non_zeroed(t.elem.size * length, t.elem.align) or_return
data := mem.alloc_bytes_non_zeroed(t.elem.size * scap, t.elem.align) or_return
defer if err != nil { mem.free_bytes(data) }
da := mem.Raw_Dynamic_Array{raw_data(data), 0, length, context.allocator }
assign_array(r, &da, t.elem, _length) or_return
assign_array(d, &da, t.elem, length) or_return
if da.len < da.cap {
if .Shrink_Excess in d.flags {
// Ignoring an error here, but this is not critical to succeed.
_ = runtime.__dynamic_array_shrink(&da, t.elem.size, t.elem.align, da.len)
}
@@ -475,54 +502,58 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
return
case reflect.Type_Info_Dynamic_Array:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else INITIAL_STREAMED_CONTAINER_CAPACITY
length, scap := err_conv(_decode_len_container(d, add)) or_return
data := mem.alloc_bytes_non_zeroed(t.elem.size * length, t.elem.align) or_return
data := mem.alloc_bytes_non_zeroed(t.elem.size * scap, t.elem.align) or_return
defer if err != nil { mem.free_bytes(data) }
raw := (^mem.Raw_Dynamic_Array)(v.data)
raw.data = raw_data(data)
raw.len = 0
raw.cap = length
raw.allocator = context.allocator
_ = assign_array(r, raw, t.elem, _length) or_return
_ = assign_array(d, raw, t.elem, length) or_return
if .Shrink_Excess in d.flags {
// Ignoring an error here, but this is not critical to succeed.
_ = runtime.__dynamic_array_shrink(raw, t.elem.size, t.elem.align, raw.len)
}
return
case reflect.Type_Info_Array:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else t.count
_length, scap := err_conv(_decode_len_container(d, add)) or_return
length := min(scap, t.count)
if !unknown && length > t.count {
if length > t.count {
return _unsupported(v, hdr)
}
da := mem.Raw_Dynamic_Array{rawptr(v.data), 0, length, context.allocator }
out_of_space := assign_array(r, &da, t.elem, _length, growable=false) or_return
out_of_space := assign_array(d, &da, t.elem, length, growable=false) or_return
if out_of_space { return _unsupported(v, hdr) }
return
case reflect.Type_Info_Enumerated_Array:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else t.count
_length, scap := err_conv(_decode_len_container(d, add)) or_return
length := min(scap, t.count)
if !unknown && length > t.count {
if length > t.count {
return _unsupported(v, hdr)
}
da := mem.Raw_Dynamic_Array{rawptr(v.data), 0, length, context.allocator }
out_of_space := assign_array(r, &da, t.elem, _length, growable=false) or_return
out_of_space := assign_array(d, &da, t.elem, length, growable=false) or_return
if out_of_space { return _unsupported(v, hdr) }
return
case reflect.Type_Info_Complex:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else 2
_length, scap := err_conv(_decode_len_container(d, add)) or_return
length := min(scap, 2)
if !unknown && length > 2 {
if length > 2 {
return _unsupported(v, hdr)
}
@@ -536,15 +567,15 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
case: unreachable()
}
out_of_space := assign_array(r, &da, info, 2, growable=false) or_return
out_of_space := assign_array(d, &da, info, 2, growable=false) or_return
if out_of_space { return _unsupported(v, hdr) }
return
case reflect.Type_Info_Quaternion:
_length, unknown := err_conv(_decode_container_length(r, add)) or_return
length := _length.? or_else 4
_length, scap := err_conv(_decode_len_container(d, add)) or_return
length := min(scap, 4)
if !unknown && length > 4 {
if length > 4 {
return _unsupported(v, hdr)
}
@@ -558,7 +589,7 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
case: unreachable()
}
out_of_space := assign_array(r, &da, info, 4, growable=false) or_return
out_of_space := assign_array(d, &da, info, 4, growable=false) or_return
if out_of_space { return _unsupported(v, hdr) }
return
@@ -566,17 +597,17 @@ _unmarshal_array :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
}
}
_unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
decode_key :: proc(r: io.Reader, v: any) -> (k: string, err: Unmarshal_Error) {
entry_hdr := _decode_header(r) or_return
_unmarshal_map :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header, add: Add) -> (err: Unmarshal_Error) {
r := d.reader
decode_key :: proc(d: Decoder, v: any) -> (k: string, err: Unmarshal_Error) {
entry_hdr := _decode_header(d.reader) or_return
entry_maj, entry_add := _header_split(entry_hdr)
#partial switch entry_maj {
case .Text:
k = err_conv(_decode_text(r, entry_add)) or_return
k = err_conv(_decode_text(d, entry_add)) or_return
return
case .Bytes:
bytes := err_conv(_decode_bytes(r, entry_add)) or_return
bytes := err_conv(_decode_bytes(d, entry_add)) or_return
k = string(bytes)
return
case:
@@ -588,10 +619,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
// Allow generically storing the map array.
switch &dst in v {
case ^Map:
dst = err_conv(_decode_map_ptr(r, add)) or_return
dst = err_conv(_decode_map_ptr(d, add)) or_return
return
case Map:
dst = err_conv(_decode_map(r, add)) or_return
dst = err_conv(_decode_map(d, add)) or_return
return
}
@@ -601,14 +632,15 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
return _unsupported(v, hdr)
}
length, unknown := err_conv(_decode_container_length(r, add)) or_return
length, scap := err_conv(_decode_len_container(d, add)) or_return
unknown := length == -1
fields := reflect.struct_fields_zipped(ti.id)
for idx := 0; unknown || idx < length.?; idx += 1 {
for idx := 0; idx < len(fields) && (unknown || idx < length); idx += 1 {
// Decode key, keys can only be strings.
key: string; {
context.allocator = context.temp_allocator
if keyv, kerr := decode_key(r, v); unknown && kerr == .Break {
if keyv, kerr := decode_key(d, v); unknown && kerr == .Break {
break
} else if kerr != nil {
err = kerr
@@ -641,11 +673,11 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
}
}
field := fields[use_field_idx]
name := field.name
ptr := rawptr(uintptr(v.data) + field.offset)
fany := any{ptr, field.type.id}
_unmarshal_value(r, fany, _decode_header(r) or_return) or_return
_unmarshal_value(d, fany, _decode_header(r) or_return) or_return
}
return
@@ -654,6 +686,8 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
return _unsupported(v, hdr)
}
// TODO: shrink excess.
raw_map := (^mem.Raw_Map)(v.data)
if raw_map.allocator.procedure == nil {
raw_map.allocator = context.allocator
@@ -663,10 +697,11 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
_ = runtime.map_free_dynamic(raw_map^, t.map_info)
}
length, unknown := err_conv(_decode_container_length(r, add)) or_return
length, scap := err_conv(_decode_len_container(d, add)) or_return
unknown := length == -1
if !unknown {
// Reserve space before setting so we can return allocation errors and be efficient on big maps.
new_len := uintptr(runtime.map_len(raw_map^)+length.?)
new_len := uintptr(min(scap, runtime.map_len(raw_map^)+length))
runtime.map_reserve_dynamic(raw_map, t.map_info, new_len) or_return
}
@@ -676,10 +711,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
map_backing_value := any{raw_data(elem_backing), t.value.id}
for idx := 0; unknown || idx < length.?; idx += 1 {
for idx := 0; unknown || idx < length; idx += 1 {
// Decode key, keys can only be strings.
key: string
if keyv, kerr := decode_key(r, v); unknown && kerr == .Break {
if keyv, kerr := decode_key(d, v); unknown && kerr == .Break {
break
} else if kerr != nil {
err = kerr
@@ -688,14 +723,14 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
key = keyv
}
if unknown {
if unknown || idx > scap {
// Reserve space for new element so we can return allocator errors.
new_len := uintptr(runtime.map_len(raw_map^)+1)
runtime.map_reserve_dynamic(raw_map, t.map_info, new_len) or_return
}
mem.zero_slice(elem_backing)
_unmarshal_value(r, map_backing_value, _decode_header(r) or_return) or_return
_unmarshal_value(d, map_backing_value, _decode_header(r) or_return) or_return
key_ptr := rawptr(&key)
key_cstr: cstring
@@ -709,6 +744,10 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
// We already reserved space for it, so this shouldn't fail.
assert(set_ptr != nil)
}
if .Shrink_Excess in d.flags {
_, _ = runtime.map_shrink_dynamic(raw_map, t.map_info)
}
return
case:
@@ -719,7 +758,8 @@ _unmarshal_map :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header
// Unmarshal into a union, based on the `TAG_OBJECT_TYPE` tag of the spec, it denotes a tag which
// contains an array of exactly two elements, the first is a textual representation of the following
// CBOR value's type.
_unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Header) -> (err: Unmarshal_Error) {
_unmarshal_union :: proc(d: Decoder, v: any, ti: ^reflect.Type_Info, hdr: Header) -> (err: Unmarshal_Error) {
r := d.reader
#partial switch t in ti.variant {
case reflect.Type_Info_Union:
idhdr: Header
@@ -731,8 +771,8 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
return .Bad_Tag_Value
}
n_items, unknown := err_conv(_decode_container_length(r, vadd)) or_return
if unknown || n_items != 2 {
n_items, _ := err_conv(_decode_len_container(d, vadd)) or_return
if n_items != 2 {
return .Bad_Tag_Value
}
@@ -743,7 +783,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
}
context.allocator = context.temp_allocator
target_name = err_conv(_decode_text(r, idadd)) or_return
target_name = err_conv(_decode_text(d, idadd)) or_return
}
defer delete(target_name, context.temp_allocator)
@@ -757,7 +797,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
case reflect.Type_Info_Named:
if vti.name == target_name {
reflect.set_union_variant_raw_tag(v, tag)
return _unmarshal_value(r, any{v.data, variant.id}, _decode_header(r) or_return)
return _unmarshal_value(d, any{v.data, variant.id}, _decode_header(r) or_return)
}
case:
@@ -769,7 +809,7 @@ _unmarshal_union :: proc(r: io.Reader, v: any, ti: ^reflect.Type_Info, hdr: Head
if variant_name == target_name {
reflect.set_union_variant_raw_tag(v, tag)
return _unmarshal_value(r, any{v.data, variant.id}, _decode_header(r) or_return)
return _unmarshal_value(d, any{v.data, variant.id}, _decode_header(r) or_return)
}
}
}

View File

@@ -4,6 +4,7 @@ import "core:bytes"
import "core:encoding/cbor"
import "core:fmt"
import "core:intrinsics"
import "core:io"
import "core:math/big"
import "core:mem"
import "core:os"
@@ -61,7 +62,9 @@ main :: proc() {
test_marshalling_maybe(&t)
test_marshalling_nil_maybe(&t)
test_cbor_marshalling_union(&t)
test_marshalling_union(&t)
test_lying_length_array(&t)
test_decode_unsigned(&t)
test_encode_unsigned(&t)
@@ -202,7 +205,7 @@ test_marshalling :: proc(t: ^testing.T) {
ev(t, err, nil)
defer delete(data)
decoded, derr := cbor.decode_string(string(data))
decoded, derr := cbor.decode(string(data))
ev(t, derr, nil)
defer cbor.destroy(decoded)
@@ -398,7 +401,7 @@ test_marshalling_nil_maybe :: proc(t: ^testing.T) {
}
@(test)
test_cbor_marshalling_union :: proc(t: ^testing.T) {
test_marshalling_union :: proc(t: ^testing.T) {
My_Distinct :: distinct string
My_Enum :: enum {
@@ -457,6 +460,14 @@ test_cbor_marshalling_union :: proc(t: ^testing.T) {
}
}
@(test)
test_lying_length_array :: proc(t: ^testing.T) {
// Input claims this is an array with a huge length; decoding it should not allocate that amount up front.
input := []byte{0x9B, 0x00, 0x00, 0x42, 0xFA, 0x42, 0xFA, 0x42, 0xFA, 0x42}
_, err := cbor.decode(string(input))
expect_value(t, err, io.Error.Unexpected_EOF) // .Out_Of_Memory would be bad.
}
@(test)
test_decode_unsigned :: proc(t: ^testing.T) {
expect_decoding(t, "\x00", "0", u8)