diff --git a/core/crypto/chacha20poly1305/chacha20poly1305.odin b/core/crypto/chacha20poly1305/chacha20poly1305.odin new file mode 100644 index 000000000..67d89df56 --- /dev/null +++ b/core/crypto/chacha20poly1305/chacha20poly1305.odin @@ -0,0 +1,146 @@ +package chacha20poly1305 + +import "core:crypto" +import "core:crypto/chacha20" +import "core:crypto/poly1305" +import "core:crypto/util" +import "core:mem" + +KEY_SIZE :: chacha20.KEY_SIZE +NONCE_SIZE :: chacha20.NONCE_SIZE +TAG_SIZE :: poly1305.TAG_SIZE + +_P_MAX :: 64 * 0xffffffff // 64 * (2^32-1) + +_validate_common_slice_sizes :: proc (tag, key, nonce, aad, text: []byte) { + if len(tag) != TAG_SIZE { + panic("crypto/chacha20poly1305: invalid destination tag size") + } + if len(key) != KEY_SIZE { + panic("crypto/chacha20poly1305: invalid key size") + } + if len(nonce) != NONCE_SIZE { + panic("crypto/chacha20poly1305: invalid nonce size") + } + + #assert(size_of(int) == 8 || size_of(int) <= 4) + when size_of(int) == 8 { + // A_MAX = 2^64 - 1 due to the length field limit. + // P_MAX = 64 * (2^32 - 1) due to the IETF ChaCha20 counter limit. + // + // A_MAX is limited by size_of(int), so there is no need to + // enforce it. P_MAX only needs to be checked on 64-bit targets, + // for reasons that should be obvious. + if text_len := len(text); text_len > _P_MAX { + panic("crypto/chacha20poly1305: oversized src data") + } + } +} + +_PAD: [16]byte +_update_mac_pad16 :: #force_inline proc (ctx: ^poly1305.Context, x_len: int) { + if pad_len := 16 - (x_len & (16-1)); pad_len != 16 { + poly1305.update(ctx, _PAD[:pad_len]) + } +} + +encrypt :: proc (ciphertext, tag, key, nonce, aad, plaintext: []byte) { + _validate_common_slice_sizes(tag, key, nonce, aad, plaintext) + if len(ciphertext) != len(plaintext) { + panic("crypto/chacha20poly1305: invalid destination ciphertext size") + } + + stream_ctx: chacha20.Context = --- + chacha20.init(&stream_ctx, key, nonce) + + // otk = poly1305_key_gen(key, nonce) + otk: [poly1305.KEY_SIZE]byte = --- + chacha20.keystream_bytes(&stream_ctx, otk[:]) + mac_ctx: poly1305.Context = --- + poly1305.init(&mac_ctx, otk[:]) + mem.zero_explicit(&otk, size_of(otk)) + + aad_len, ciphertext_len := len(aad), len(ciphertext) + + // There is nothing preventing aad and ciphertext from overlapping + // so auth the AAD before encrypting (slightly different from the + // RFC, since the RFC encrypts into a new buffer). + // + // mac_data = aad | pad16(aad) + poly1305.update(&mac_ctx, aad) + _update_mac_pad16(&mac_ctx, aad_len) + + // ciphertext = chacha20_encrypt(key, 1, nonce, plaintext) + chacha20.seek(&stream_ctx, 1) + chacha20.xor_bytes(&stream_ctx, ciphertext, plaintext) + chacha20.reset(&stream_ctx) // Don't need the stream context anymore. + + // mac_data |= ciphertext | pad16(ciphertext) + poly1305.update(&mac_ctx, ciphertext) + _update_mac_pad16(&mac_ctx, ciphertext_len) + + // mac_data |= num_to_8_le_bytes(aad.length) + // mac_data |= num_to_8_le_bytes(ciphertext.length) + l_buf := otk[0:16] // Reuse the scratch buffer. + util.PUT_U64_LE(l_buf[0:8], u64(aad_len)) + util.PUT_U64_LE(l_buf[8:16], u64(ciphertext_len)) + poly1305.update(&mac_ctx, l_buf) + + // tag = poly1305_mac(mac_data, otk) + poly1305.final(&mac_ctx, tag) // Implicitly sanitizes context. +} + +decrypt :: proc (plaintext, tag, key, nonce, aad, ciphertext: []byte) -> bool { + _validate_common_slice_sizes(tag, key, nonce, aad, ciphertext) + if len(ciphertext) != len(plaintext) { + panic("crypto/chacha20poly1305: invalid destination plaintext size") + } + + // Note: Unlike encrypt, this can fail early, so use defer for + // sanitization rather than assuming control flow reaches certain + // points where needed. + + stream_ctx: chacha20.Context = --- + chacha20.init(&stream_ctx, key, nonce) + + // otk = poly1305_key_gen(key, nonce) + otk: [poly1305.KEY_SIZE]byte = --- + chacha20.keystream_bytes(&stream_ctx, otk[:]) + defer chacha20.reset(&stream_ctx) + + mac_ctx: poly1305.Context = --- + poly1305.init(&mac_ctx, otk[:]) + defer mem.zero_explicit(&otk, size_of(otk)) + + aad_len, ciphertext_len := len(aad), len(ciphertext) + + // mac_data = aad | pad16(aad) + // mac_data |= ciphertext | pad16(ciphertext) + // mac_data |= num_to_8_le_bytes(aad.length) + // mac_data |= num_to_8_le_bytes(ciphertext.length) + poly1305.update(&mac_ctx, aad) + _update_mac_pad16(&mac_ctx, aad_len) + poly1305.update(&mac_ctx, ciphertext) + _update_mac_pad16(&mac_ctx, ciphertext_len) + l_buf := otk[0:16] // Reuse the scratch buffer. + util.PUT_U64_LE(l_buf[0:8], u64(aad_len)) + util.PUT_U64_LE(l_buf[8:16], u64(ciphertext_len)) + poly1305.update(&mac_ctx, l_buf) + + // tag = poly1305_mac(mac_data, otk) + derived_tag := otk[0:poly1305.TAG_SIZE] // Reuse the scratch buffer again. + poly1305.final(&mac_ctx, derived_tag) // Implicitly sanitizes context. + + // Validate the tag in constant time. + if crypto.compare_constant_time(tag, derived_tag) != 1 { + // Zero out the plaintext, as a defense in depth measure. + mem.zero_explicit(raw_data(plaintext), ciphertext_len) + return false + } + + // plaintext = chacha20_decrypt(key, 1, nonce, ciphertext) + chacha20.seek(&stream_ctx, 1) + chacha20.xor_bytes(&stream_ctx, plaintext, ciphertext) + + return true +} diff --git a/tests/core/crypto/test_core_crypto.odin b/tests/core/crypto/test_core_crypto.odin index b73a191ad..731833096 100644 --- a/tests/core/crypto/test_core_crypto.odin +++ b/tests/core/crypto/test_core_crypto.odin @@ -118,6 +118,7 @@ main :: proc() { // "modern" crypto tests test_chacha20(&t) test_poly1305(&t) + test_chacha20poly1305(&t) test_x25519(&t) bench_modern(&t) diff --git a/tests/core/crypto/test_core_crypto_modern.odin b/tests/core/crypto/test_core_crypto_modern.odin index 45ec8b339..b3d9e47fd 100644 --- a/tests/core/crypto/test_core_crypto_modern.odin +++ b/tests/core/crypto/test_core_crypto_modern.odin @@ -6,6 +6,7 @@ import "core:mem" import "core:time" import "core:crypto/chacha20" +import "core:crypto/chacha20poly1305" import "core:crypto/poly1305" import "core:crypto/x25519" @@ -30,13 +31,14 @@ _decode_hex32 :: proc(s: string) -> [32]byte{ return b } +_PLAINTEXT_SUNSCREEN_STR := "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it." + @(test) test_chacha20 :: proc(t: ^testing.T) { log(t, "Testing (X)ChaCha20") // Test cases taken from RFC 8439, and draft-irtf-cfrg-xchacha-03 - plaintext_str := "Ladies and Gentlemen of the class of '99: If I could offer you only one tip for the future, sunscreen would be it." - plaintext := transmute([]byte)(plaintext_str) + plaintext := transmute([]byte)(_PLAINTEXT_SUNSCREEN_STR) key := [chacha20.KEY_SIZE]byte{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, @@ -182,6 +184,80 @@ test_poly1305 :: proc(t: ^testing.T) { expect(t, derived_tag_str == tag_str, fmt.tprintf("Expected %s for init/update/final - incremental, but got %s instead", tag_str, derived_tag_str)) } +@(test) +test_chacha20poly1305 :: proc(t: ^testing.T) { + log(t, "Testing chacha20poly1205") + + plaintext := transmute([]byte)(_PLAINTEXT_SUNSCREEN_STR) + + aad := [12]byte{ + 0x50, 0x51, 0x52, 0x53, 0xc0, 0xc1, 0xc2, 0xc3, + 0xc4, 0xc5, 0xc6, 0xc7, + } + + key := [chacha20poly1305.KEY_SIZE]byte{ + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + } + + nonce := [chacha20poly1305.NONCE_SIZE]byte{ + 0x07, 0x00, 0x00, 0x00, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + } + + ciphertext := [114]byte{ + 0xd3, 0x1a, 0x8d, 0x34, 0x64, 0x8e, 0x60, 0xdb, + 0x7b, 0x86, 0xaf, 0xbc, 0x53, 0xef, 0x7e, 0xc2, + 0xa4, 0xad, 0xed, 0x51, 0x29, 0x6e, 0x08, 0xfe, + 0xa9, 0xe2, 0xb5, 0xa7, 0x36, 0xee, 0x62, 0xd6, + 0x3d, 0xbe, 0xa4, 0x5e, 0x8c, 0xa9, 0x67, 0x12, + 0x82, 0xfa, 0xfb, 0x69, 0xda, 0x92, 0x72, 0x8b, + 0x1a, 0x71, 0xde, 0x0a, 0x9e, 0x06, 0x0b, 0x29, + 0x05, 0xd6, 0xa5, 0xb6, 0x7e, 0xcd, 0x3b, 0x36, + 0x92, 0xdd, 0xbd, 0x7f, 0x2d, 0x77, 0x8b, 0x8c, + 0x98, 0x03, 0xae, 0xe3, 0x28, 0x09, 0x1b, 0x58, + 0xfa, 0xb3, 0x24, 0xe4, 0xfa, 0xd6, 0x75, 0x94, + 0x55, 0x85, 0x80, 0x8b, 0x48, 0x31, 0xd7, 0xbc, + 0x3f, 0xf4, 0xde, 0xf0, 0x8e, 0x4b, 0x7a, 0x9d, + 0xe5, 0x76, 0xd2, 0x65, 0x86, 0xce, 0xc6, 0x4b, + 0x61, 0x16, + } + ciphertext_str := hex_string(ciphertext[:]) + + tag := [chacha20poly1305.TAG_SIZE]byte{ + 0x1a, 0xe1, 0x0b, 0x59, 0x4f, 0x09, 0xe2, 0x6a, + 0x7e, 0x90, 0x2e, 0xcb, 0xd0, 0x60, 0x06, 0x91, + } + tag_str := hex_string(tag[:]) + + derived_tag: [chacha20poly1305.TAG_SIZE]byte + derived_ciphertext: [114]byte + + chacha20poly1305.encrypt(derived_ciphertext[:], derived_tag[:], key[:], nonce[:], aad[:], plaintext) + + derived_ciphertext_str := hex_string(derived_ciphertext[:]) + expect(t, derived_ciphertext_str == ciphertext_str, fmt.tprintf("Expected ciphertext %s for encrypt(aad, plaintext), but got %s instead", ciphertext_str, derived_ciphertext_str)) + + derived_tag_str := hex_string(derived_tag[:]) + expect(t, derived_tag_str == tag_str, fmt.tprintf("Expected tag %s for encrypt(aad, plaintext), but got %s instead", tag_str, derived_tag_str)) + + derived_plaintext: [114]byte + ok := chacha20poly1305.decrypt(derived_plaintext[:], tag[:], key[:], nonce[:], aad[:], ciphertext[:]) + derived_plaintext_str := string(derived_plaintext[:]) + expect(t, ok, "Expected true for decrypt(tag, aad, ciphertext)") + expect(t, derived_plaintext_str == _PLAINTEXT_SUNSCREEN_STR, fmt.tprintf("Expected plaintext %s for decrypt(tag, aad, ciphertext), but got %s instead", _PLAINTEXT_SUNSCREEN_STR, derived_plaintext_str)) + + derived_ciphertext[0] ~= 0xa5 + ok = chacha20poly1305.decrypt(derived_plaintext[:], tag[:], key[:], nonce[:], aad[:], derived_ciphertext[:]) + expect(t, !ok, "Expected false for decrypt(tag, aad, corrupted_ciphertext)") + + aad[0] ~= 0xa5 + ok = chacha20poly1305.decrypt(derived_plaintext[:], tag[:], key[:], nonce[:], aad[:], ciphertext[:]) + expect(t, !ok, "Expected false for decrypt(tag, corrupted_aad, ciphertext)") +} + TestECDH :: struct { scalar: string, point: string, @@ -233,6 +309,7 @@ bench_modern :: proc(t: ^testing.T) { bench_chacha20(t) bench_poly1305(t) + bench_chacha20poly1305(t) bench_x25519(t) } @@ -293,6 +370,29 @@ _benchmark_poly1305 :: proc(options: ^time.Benchmark_Options, allocator := conte return nil } +_benchmark_chacha20poly1305 :: proc(options: ^time.Benchmark_Options, allocator := context.allocator) -> (err: time.Benchmark_Error) { + buf := options.input + key := [chacha20.KEY_SIZE]byte{ + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + 0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef, + } + nonce := [chacha20.NONCE_SIZE]byte{ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + } + + tag: [chacha20poly1305.TAG_SIZE]byte = --- + + for _ in 0..=options.rounds { + chacha20poly1305.encrypt(buf,tag[:], key[:], nonce[:], nil, buf) + } + options.count = options.rounds + options.processed = options.rounds * options.bytes + return nil +} + benchmark_print :: proc(name: string, options: ^time.Benchmark_Options) { fmt.printf("\t[%v] %v rounds, %v bytes processed in %v ns\n\t\t%5.3f rounds/s, %5.3f MiB/s\n", name, @@ -352,6 +452,33 @@ bench_poly1305 :: proc(t: ^testing.T) { benchmark_print(name, options) } +bench_chacha20poly1305 :: proc(t: ^testing.T) { + name := "chacha20poly1305 64 bytes" + options := &time.Benchmark_Options{ + rounds = 1_000, + bytes = 64, + setup = _setup_sized_buf, + bench = _benchmark_chacha20poly1305, + teardown = _teardown_sized_buf, + } + + err := time.benchmark(options, context.allocator) + expect(t, err == nil, name) + benchmark_print(name, options) + + name = "chacha20poly1305 1024 bytes" + options.bytes = 1024 + err = time.benchmark(options, context.allocator) + expect(t, err == nil, name) + benchmark_print(name, options) + + name = "chacha20poly1305 65536 bytes" + options.bytes = 65536 + err = time.benchmark(options, context.allocator) + expect(t, err == nil, name) + benchmark_print(name, options) +} + bench_x25519 :: proc(t: ^testing.T) { point := _decode_hex32("deadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef") scalar := _decode_hex32("cafebabecafebabecafebabecafebabecafebabecafebabecafebabecafebabe")