From e3504c94adf4edeabea3883e0e110bb4a9e7e690 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Sat, 28 Mar 2026 11:47:06 +0900 Subject: [PATCH 1/4] core/crypto: Get rid of `set` (only used by legacy) --- core/crypto/crypto.odin | 12 ------------ core/crypto/legacy/md5/md5.odin | 3 ++- core/crypto/legacy/sha1/sha1.odin | 3 ++- tests/benchmark/crypto/benchmark_ecc.odin | 3 ++- 4 files changed, 6 insertions(+), 15 deletions(-) diff --git a/core/crypto/crypto.odin b/core/crypto/crypto.odin index aa5a67b8f..3218b8670 100644 --- a/core/crypto/crypto.odin +++ b/core/crypto/crypto.odin @@ -85,18 +85,6 @@ zero_explicit :: proc "contextless" (data: rawptr, len: int) -> rawptr { return data } -/* -Set each byte of a memory range to a specific value. - -This procedure copies value specified by the `value` parameter into each of the -`len` bytes of a memory range, located at address `data`. - -This procedure returns the pointer to `data`. -*/ -set :: proc "contextless" (data: rawptr, value: byte, len: int) -> rawptr { - return runtime.memset(data, i32(value), len) -} - // rand_bytes fills the dst buffer with cryptographic entropy taken from // the system entropy source. This routine will block if the system entropy // source is not ready yet. All system entropy source failures are treated diff --git a/core/crypto/legacy/md5/md5.odin b/core/crypto/legacy/md5/md5.odin index 4bbc5d32a..ddc795b7d 100644 --- a/core/crypto/legacy/md5/md5.odin +++ b/core/crypto/legacy/md5/md5.odin @@ -18,6 +18,7 @@ package md5 zhibog, dotbmp: Initial implementation. */ +import "base:intrinsics" import "core:crypto" import "core:encoding/endian" import "core:math/bits" @@ -100,7 +101,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) { i += 1 } transform(ctx, ctx.data[:]) - crypto.set(&ctx.data, 0, 56) + intrinsics.mem_zero(&ctx.data, 56) } ctx.bitlen += u64(ctx.datalen * 8) diff --git a/core/crypto/legacy/sha1/sha1.odin b/core/crypto/legacy/sha1/sha1.odin index 892f893a6..bf3ad9602 100644 --- a/core/crypto/legacy/sha1/sha1.odin +++ b/core/crypto/legacy/sha1/sha1.odin @@ -19,6 +19,7 @@ package sha1 zhibog, dotbmp: Initial implementation. */ +import "base:intrinsics" import "core:crypto" import "core:encoding/endian" import "core:math/bits" @@ -107,7 +108,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) { i += 1 } transform(ctx, ctx.data[:]) - crypto.set(&ctx.data, 0, 56) + intrinsics.mem_zero(&ctx.data, 56) } ctx.bitlen += u64(ctx.datalen * 8) diff --git a/tests/benchmark/crypto/benchmark_ecc.odin b/tests/benchmark/crypto/benchmark_ecc.odin index c1809c6ba..95db33ab3 100644 --- a/tests/benchmark/crypto/benchmark_ecc.odin +++ b/tests/benchmark/crypto/benchmark_ecc.odin @@ -2,6 +2,7 @@ package benchmark_core_crypto import "base:runtime" import "core:encoding/hex" +import "core:mem" import "core:log" import "core:testing" import "core:text/table" @@ -161,7 +162,7 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) { @(private="file") bench_ecdsa :: proc(curve: ecdsa.Curve, hash: hash.Algorithm) -> (sk, sig, verif: time.Duration) { priv_bytes := make([]byte, ecdsa.PRIVATE_KEY_SIZES[curve], context.temp_allocator) - crypto.set(raw_data(priv_bytes), 0x69, len(priv_bytes)) + mem.set(raw_data(priv_bytes), 0x69, len(priv_bytes)) priv_key: ecdsa.Private_Key start := time.tick_now() for _ in 0 ..< DSA_ITERS { From a3d7300e5539779be0e5e29c35b4395af5f8a5b1 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Fri, 1 May 2026 23:48:54 +0900 Subject: [PATCH 2/4] core/crypto/ecdh,ecdsa: Add `require_results` annotations --- core/crypto/ecdh/ecdh.odin | 8 ++++++++ core/crypto/ecdsa/ecdsa.odin | 5 +++++ core/crypto/noise/protocol.odin | 13 ++++++++++--- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/core/crypto/ecdh/ecdh.odin b/core/crypto/ecdh/ecdh.odin index f5106d152..d31a69aa8 100644 --- a/core/crypto/ecdh/ecdh.odin +++ b/core/crypto/ecdh/ecdh.odin @@ -106,6 +106,7 @@ Public_Key :: struct { // private_key_generate uses the system entropy source to generate a new // Private_Key. This will only fail if and only if (⟺) the system entropy source is // missing or broken. +@(require_results) private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool { private_key_clear(priv_key) @@ -143,6 +144,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool { // private_key_set_bytes decodes a byte-encoded private key, and returns // true if and only if (⟺) the operation was successful. +@(require_results) private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool { private_key_clear(priv_key) @@ -281,6 +283,7 @@ private_key_bytes :: proc(priv_key: ^Private_Key, dst: []byte) { // private_key_equal returns true if and only if (⟺) the private keys are equal, // in constant time. +@(require_results) private_key_equal :: proc(p, q: ^Private_Key) -> bool { if p._curve != q._curve { return false @@ -311,6 +314,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) { // public_key_set_bytes decodes a byte-encoded public key, and returns // true if and only if (⟺) the operation was successful. +@(require_results) public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool { public_key_clear(pub_key) @@ -411,6 +415,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) { // public_key_equal returns true if and only if (⟺) the public keys are equal, // in constant time. +@(require_results) public_key_equal :: proc(p, q: ^Public_Key) -> bool { if p._curve != q._curve { return false @@ -479,11 +484,13 @@ ecdh :: proc(priv_key: ^Private_Key, pub_key: ^Public_Key, dst: []byte) -> bool } // curve returns the Curve used by a Private_Key or Public_Key instance. +@(require_results) curve :: proc(k: ^$T) -> Curve where(T == Private_Key || T == Public_Key) { return k._curve } // key_size returns the key size of a Private_Key or Public_Key in bytes. +@(require_results) key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) { when T == Private_Key { return PRIVATE_KEY_SIZES[k._curve] @@ -494,6 +501,7 @@ key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) { // shared_secret_size returns the shared secret size of a key exchange // in bytes. +@(require_results) shared_secret_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) { return SHARED_SECRET_SIZES[k._curve] } diff --git a/core/crypto/ecdsa/ecdsa.odin b/core/crypto/ecdsa/ecdsa.odin index 350bab3ec..8bb1748cf 100644 --- a/core/crypto/ecdsa/ecdsa.odin +++ b/core/crypto/ecdsa/ecdsa.odin @@ -81,6 +81,7 @@ Public_Key :: struct { // private_key_generate uses the system entropy source to generate a new // Private_Key. This will only fail if and only if (⟺) the system entropy source is // missing or broken. +@(require_results) private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool { private_key_clear(priv_key) @@ -112,6 +113,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool { // private_key_set_bytes decodes a byte-encoded private key, and returns // true if and only if (⟺) the operation was successful. +@(require_results) private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool { private_key_clear(priv_key) @@ -222,6 +224,7 @@ private_key_set :: proc(priv_key, src: ^Private_Key) { // private_key_equal returns true if and only if (⟺) the private keys are equal, // in constant time. +@(require_results) private_key_equal :: proc(p, q: ^Private_Key) -> bool { if p._curve != q._curve { return false @@ -246,6 +249,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) { // public_key_set_bytes decodes a byte-encoded public key, and returns // true if and only if (⟺) the operation was successful. +@(require_results) public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool { public_key_clear(pub_key) @@ -334,6 +338,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) { // public_key_equal returns true if and only if (⟺) the public keys are equal, // in constant time. +@(require_results) public_key_equal :: proc(p, q: ^Public_Key) -> bool { if p._curve != q._curve { return false diff --git a/core/crypto/noise/protocol.odin b/core/crypto/noise/protocol.odin index 883376a42..29482bd36 100644 --- a/core/crypto/noise/protocol.odin +++ b/core/crypto/noise/protocol.odin @@ -58,7 +58,9 @@ generate_keypair :: proc(protocol: ^Protocol, private_key: ^ecdh.Private_Key) { case: panic("crypto/noise: unsupported DH curve in protocol") } - ecdh.private_key_generate(private_key, protocol.dh) + if !ecdh.private_key_generate(private_key, protocol.dh) { + panic("crypto/noise: entropy source unavailable") + } } // Performs a Diffie-Hellman calculation between the private key in key_pair @@ -837,7 +839,9 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte panic("crypto/noise: re was not empty when processing token 'e' during ReadMessage") } - ecdh.public_key_set_bytes(&self.re, protocol.dh, re) + if !ecdh.public_key_set_bytes(&self.re, protocol.dh, re) { + return nil, .Invalid_Handshake_Message + } symmetricstate_mix_hash(&self.symmetric_state, re) if self.message_pattern.is_psk { symmetricstate_mix_key(&self.symmetric_state, re) @@ -864,7 +868,10 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte panic("crypto/noise: rs was not empty when processing token 's' during ReadMessage") } - ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs) + if !ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs) { + self.status = .Handshake_Failed + return nil, .Invalid_Handshake_Message + } msg = msg[rs_len:] case .ee: From 8f1067f29094b3427eed66993b3cf65a8ee612fa Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Sat, 2 May 2026 14:47:57 +0900 Subject: [PATCH 3/4] tests/core/crypto/wycheproof: Break up into separate files --- tests/core/crypto/wycheproof/aead.odin | 514 +++++++++ tests/core/crypto/wycheproof/ecc.odin | 427 ++++++++ tests/core/crypto/wycheproof/kdf.odin | 238 +++++ tests/core/crypto/wycheproof/mac.odin | 156 +++ tests/core/crypto/wycheproof/main.odin | 1328 ------------------------ 5 files changed, 1335 insertions(+), 1328 deletions(-) create mode 100644 tests/core/crypto/wycheproof/aead.odin create mode 100644 tests/core/crypto/wycheproof/ecc.odin create mode 100644 tests/core/crypto/wycheproof/kdf.odin create mode 100644 tests/core/crypto/wycheproof/mac.odin diff --git a/tests/core/crypto/wycheproof/aead.odin b/tests/core/crypto/wycheproof/aead.odin new file mode 100644 index 000000000..6c2b61d72 --- /dev/null +++ b/tests/core/crypto/wycheproof/aead.odin @@ -0,0 +1,514 @@ +package test_wycheproof + +import "core:encoding/hex" +import "core:log" +import "core:mem" +import "core:os" +import "core:slice" +import "core:testing" + +import chacha_simd128 "core:crypto/_chacha20/simd128" +import chacha_simd256 "core:crypto/_chacha20/simd256" +import "core:crypto/aegis" +import "core:crypto/aes" +import "core:crypto/chacha20" +import "core:crypto/chacha20poly1305" + +import "../common" + +supported_aegis_impls :: proc() -> [dynamic]aes.Implementation { + impls := make([dynamic]aes.Implementation, 0, 2, context.temp_allocator) + append(&impls, aes.Implementation.Portable) + if aegis.is_hardware_accelerated() { + append(&impls, aes.Implementation.Hardware) + } + + return impls +} + +@(test) +test_aead_aegis :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + files := []string { + "aegis128L_test.json", + "aegis256_test.json", + } + + log.debug("aead/aegis: starting") + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Aead_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + for impl in supported_aegis_impls() { + testing.expectf(t, test_aead_aegis_impl(&test_vectors, impl), "impl {} failed", impl) + } + } +} + +test_aead_aegis_impl :: proc( + test_vectors: ^Test_Vectors(Aead_Test_Group), + impl: aes.Implementation, +) -> bool { + log.debug("aead/aegis/%v: starting", impl) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "aead/aegis/%v/%d: %s: %+v", + impl, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("aead/aegis/%v/%d: %+v", + impl, + test_vector.tc_id, + test_vector.flags, + ) + } + + key := common.hexbytes_decode(test_vector.key) + iv := common.hexbytes_decode(test_vector.iv) + aad := common.hexbytes_decode(test_vector.aad) + msg := common.hexbytes_decode(test_vector.msg) + ct := common.hexbytes_decode(test_vector.ct) + tag := common.hexbytes_decode(test_vector.tag) + + if len(iv) == 0 { + log.infof( + "aead/aegis/%v/%d: skipped, invalid IVs panic", + impl, + test_vector.tc_id, + ) + num_skipped += 1 + continue + } + + ctx: aegis.Context + aegis.init(&ctx, key, impl) + + if result_is_valid(test_vector.result) { + ct_ := make([]byte, len(ct)) + tag_ := make([]byte, len(tag)) + aegis.seal(&ctx, ct_, tag_, iv, aad, msg) + + ok := common.hexbytes_compare(test_vector.ct, ct_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(ct_)) + log.errorf( + "aead/aegis/%v/%d: ciphertext: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.ct, + x, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.tag, tag_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(tag_)) + log.errorf( + "aead/aegis/%v/%d: tag: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.tag, + x, + ) + num_failed += 1 + continue + } + } + + msg_ := make([]byte, len(msg)) + ok := aegis.open(&ctx, msg_, iv, aad, ct, tag) + if !result_check(test_vector.result, ok) { + log.errorf("aead/aegis/%v/%d: decrypt failed", impl, test_vector.tc_id) + num_failed += 1 + continue + } + + if ok && !common.hexbytes_compare(test_vector.msg, msg_) { + x := transmute(string)(hex.encode(msg_)) + log.errorf( + "aead/aegis/%v/%d: decrypt msg: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.msg, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "aead/aegis: ran %d, passed %d, failed %d, skipped %d", + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +supported_aes_impls :: proc() -> [dynamic]aes.Implementation { + impls := make([dynamic]aes.Implementation, 0, 2) + append(&impls, aes.Implementation.Portable) + if aes.is_hardware_accelerated() { + append(&impls, aes.Implementation.Hardware) + } + + return impls +} + +@(test) +test_aead_aes_gcm :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + fn, _ := os.join_path([]string{BASE_PATH, "aes_gcm_test.json"}, context.allocator) + + log.debug("aead/aes-gcm: starting") + + test_vectors: Test_Vectors(Aead_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) { + return + } + + for impl in supported_aes_impls() { + testing.expectf(t, test_aead_aes_gcm_impl(&test_vectors, impl), "impl {} failed", impl) + } +} + +test_aead_aes_gcm_impl :: proc( + test_vectors: ^Test_Vectors(Aead_Test_Group), + impl: aes.Implementation, +) -> bool { + log.debug("aead/aes-gcm/%v: starting", impl) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "aead/aes-gcm/%v/%d: %s: %+v", + impl, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("aead/aes-gcm/%v/%d: %+v", + impl, + test_vector.tc_id, + test_vector.flags, + ) + } + + key := common.hexbytes_decode(test_vector.key) + iv := common.hexbytes_decode(test_vector.iv) + aad := common.hexbytes_decode(test_vector.aad) + msg := common.hexbytes_decode(test_vector.msg) + ct := common.hexbytes_decode(test_vector.ct) + tag := common.hexbytes_decode(test_vector.tag) + + if len(iv) == 0 { + log.infof( + "aead/aes-gcm/%v/%d: skipped, invalid IVs panic", + impl, + test_vector.tc_id, + ) + num_skipped += 1 + continue + } + + ctx: aes.Context_GCM + aes.init_gcm(&ctx, key, impl) + + if result_is_valid(test_vector.result) { + ct_ := make([]byte, len(ct)) + tag_ := make([]byte, len(tag)) + aes.seal_gcm(&ctx, ct_, tag_, iv, aad, msg) + + ok := common.hexbytes_compare(test_vector.ct, ct_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(ct_)) + log.errorf( + "aead/aes-gcm/%v/%d: ciphertext: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.ct, + x, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.tag, tag_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(tag_)) + log.errorf( + "aead/aes-gcm/%v/%d: tag: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.tag, + x, + ) + num_failed += 1 + continue + } + } + + msg_ := make([]byte, len(msg)) + ok := aes.open_gcm(&ctx, msg_, iv, aad, ct, tag) + if !result_check(test_vector.result, ok) { + log.errorf("aead/aes-gcm/%v/%d: decrypt failed", impl, test_vector.tc_id) + num_failed += 1 + continue + } + + if ok && !common.hexbytes_compare(test_vector.msg, msg_) { + x := transmute(string)(hex.encode(msg_)) + log.errorf( + "aead/aes-gcm/%v/%d: decrypt msg: expected %s actual %s", + impl, + test_vector.tc_id, + test_vector.msg, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "aead/aes-gcm: ran %d, passed %d, failed %d, skipped %d", + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +supported_chacha_impls :: proc() -> [dynamic]chacha20.Implementation { + impls := make([dynamic]chacha20.Implementation, 0, 3) + append(&impls, chacha20.Implementation.Portable) + if chacha_simd128.is_performant() { + append(&impls, chacha20.Implementation.Simd128) + } + if chacha_simd256.is_performant() { + append(&impls, chacha20.Implementation.Simd256) + } + + return impls +} + +@(test) +test_aead_chacha20_poly1305 :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + files := []string { + "chacha20_poly1305_test.json", + "xchacha20_poly1305_test.json", + } + + log.debug("aead/(x)chacha20poly1305: starting") + + for f, i in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Aead_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + for impl in supported_chacha_impls() { + testing.expectf(t, test_aead_chacha20_poly1305_impl(&test_vectors, i == 1, impl), "impl {} failed", impl) + } + } +} + +test_aead_chacha20_poly1305_impl :: proc( + test_vectors: ^Test_Vectors(Aead_Test_Group), + is_xchacha: bool, + impl: chacha20.Implementation, +) -> bool { + FLAG_INVALID_NONCE_SIZE :: "InvalidNonceSize" + + alg_str := is_xchacha ? "xchacha20poly1305" : "chacha20poly1305" + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "aead/%s/%v/%d: %s: %+v", + alg_str, + impl, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("aead/%s/%v/%d: %+v", + alg_str, + impl, + test_vector.tc_id, + test_vector.flags, + ) + } + + key := common.hexbytes_decode(test_vector.key) + iv := common.hexbytes_decode(test_vector.iv) + aad := common.hexbytes_decode(test_vector.aad) + msg := common.hexbytes_decode(test_vector.msg) + ct := common.hexbytes_decode(test_vector.ct) + tag := common.hexbytes_decode(test_vector.tag) + + if slice.contains(test_vector.flags, FLAG_INVALID_NONCE_SIZE) { + log.infof( + "aead/%s/%v/%d: skipped, invalid nonces panic", + alg_str, + impl, + test_vector.tc_id, + ) + num_skipped += 1 + continue + } + + ctx: chacha20poly1305.Context + switch is_xchacha { + case true: + chacha20poly1305.init_xchacha(&ctx, key, impl) + case false: + chacha20poly1305.init(&ctx, key, impl) + } + + if result_is_valid(test_vector.result) { + ct_ := make([]byte, len(ct)) + tag_ := make([]byte, len(tag)) + chacha20poly1305.seal(&ctx, ct_, tag_, iv, aad, msg) + + ok := common.hexbytes_compare(test_vector.ct, ct_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(ct_)) + log.errorf( + "aead/%s/%v/%d: ciphertext: expected %s actual %s", + alg_str, + impl, + test_vector.tc_id, + test_vector.ct, + x, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.tag, tag_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(tag_)) + log.errorf( + "aead/%s/%v/%d: tag: expected %s actual %s", + alg_str, + impl, + test_vector.tc_id, + test_vector.tag, + x, + ) + num_failed += 1 + continue + } + } + + msg_ := make([]byte, len(msg)) + ok := chacha20poly1305.open(&ctx, msg_, iv, aad, ct, tag) + if !result_check(test_vector.result, ok) { + log.errorf("aead/%s/%v/%d: decrypt failed", + alg_str, + impl, + test_vector.tc_id, + ) + num_failed += 1 + continue + } + + if ok && !common.hexbytes_compare(test_vector.msg, msg_) { + x := transmute(string)(hex.encode(msg_)) + log.errorf( + "aead/%s/%v/%d: decrypt msg: expected %s actual %s", + alg_str, + impl, + test_vector.tc_id, + test_vector.msg, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "aead/%s/%v: ran %d, passed %d, failed %d, skipped %d", + alg_str, + impl, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} diff --git a/tests/core/crypto/wycheproof/ecc.odin b/tests/core/crypto/wycheproof/ecc.odin new file mode 100644 index 000000000..53b63cc08 --- /dev/null +++ b/tests/core/crypto/wycheproof/ecc.odin @@ -0,0 +1,427 @@ +package test_wycheproof + +import "core:encoding/hex" +import "core:log" +import "core:mem" +import "core:os" +import "core:slice" +import "core:strings" +import "core:testing" + +import "core:crypto/hash" +import "core:crypto/ecdh" +import "core:crypto/ecdsa" +import "core:crypto/ed25519" + +import "../common" + +@(test) +test_eddsa_ed25519 :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + fn, _ := os.join_path([]string{BASE_PATH, "ed25519_test.json"}, context.allocator) + + log.debug("eddsa/ed25519: starting") + + test_vectors: Test_Vectors(Eddsa_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) { + return + } + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, i in test_vectors.test_groups { + mem.free_all() // Probably don't need this, but be safe. + pk_bytes := common.hexbytes_decode(test_group.public_key.pk) + + pk: ed25519.Public_Key + pk_ok := ed25519.public_key_set_bytes(&pk, pk_bytes) + if !testing.expectf(t, pk_ok, "eddsa/ed25519/%d: invalid public key: %s", i, test_group.public_key.pk) { + num_failed += len(test_group.tests) + continue + } + + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "eddsa/ed25519/%d: %s: %+v", + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("eddsa/ed25519/%d: %+v", test_vector.tc_id, test_vector.flags) + } + + msg := common.hexbytes_decode(test_vector.msg) + sig := common.hexbytes_decode(test_vector.sig) + + verify_ok := ed25519.verify(&pk, msg, sig) + if !testing.expectf( + t, + result_check(test_vector.result, verify_ok), + "eddsa/ed25519/%d: verify failed: expected %s actual %v", + test_vector.tc_id, + test_vector.result, + verify_ok, + ) { + num_failed += 1 + continue + } + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "eddsa/ed25519: ran %d, passed %d, failed %d, skipped %d", + num_ran, + num_passed, + num_failed, + num_skipped, + ) +} + +@(test) +test_ecdsa :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("ecdsa: starting") + + files := []string { + "ecdsa_secp256r1_sha256_test.json", + "ecdsa_secp256r1_sha512_test.json", + "ecdsa_secp384r1_sha384_test.json", + } + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Ecdsa_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_ecdsa_impl(t, &test_vectors), "ecdsa failed") + } +} + +test_ecdsa_impl :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Ecdsa_Test_Group)) -> bool { + curve_str := test_vectors.test_groups[0].public_key.curve + hash_str := test_vectors.test_groups[0].sha + + curve_alg: ecdsa.Curve + switch curve_str { + case "secp256r1": + curve_alg = .SECP256R1 + case "secp384r1": + curve_alg = .SECP384R1 + case: + log.errorf("ecdsa: unsupported curve: %s", curve_str) + } + + hash_alg: hash.Algorithm + switch hash_str { + case "SHA-256": + hash_alg = .SHA256 + case "SHA-384": + hash_alg = .SHA384 + case "SHA-512": + hash_alg = .SHA512 + case: + log.errorf("ecdsa: unsupported hash: %s", hash_str) + } + + log.debugf("ecdsa/%s/%s: starting", curve_str, hash_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, i in test_vectors.test_groups { + pk_bytes := common.hexbytes_decode(test_group.public_key.uncompressed) + + pk: ecdsa.Public_Key + pk_ok := ecdsa.public_key_set_bytes(&pk, curve_alg, pk_bytes) + if !testing.expectf(t, pk_ok, "ecdsa/%s/%s/%d: invalid public key: %s", curve_str, hash_str, i, test_group.public_key.uncompressed) { + num_failed += len(test_group.tests) + continue + } + + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "ecda/%s/%s/%d: %s: %+v", + curve_str, + hash_str, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("ecdsa/%s/%s/%d: %+v", curve_str, hash_str, test_vector.tc_id, test_vector.flags) + } + + msg := common.hexbytes_decode(test_vector.msg) + sig := common.hexbytes_decode(test_vector.sig) + + verify_ok := ecdsa.verify_asn1(&pk, hash_alg, msg, sig) + if !testing.expectf( + t, + result_check(test_vector.result, verify_ok), + "ecdsa/%s/%s/%d: verify failed: expected %s actual %v", + curve_str, + hash_str, + test_vector.tc_id, + test_vector.result, + verify_ok, + ) { + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "ecdsa/%s/%s: ran %d, passed %d, failed %d, skipped %d", + curve_str, + hash_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +@(test) +test_ecdh :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + PREFIX_TEST_ECDH :: "ecdh_" + SUFFIX_TEST_ECPOINT :: "_ecpoint" + + files := []string { + "ecdh_secp256r1_ecpoint_test.json", + "ecdh_secp384r1_ecpoint_test.json", + "x25519_test.json", + "x448_test.json", + } + + log.debug("ecdh: starting") + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Ecdh_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + alg_str := strings.trim_suffix(f, SUFFIX_TEST_JSON) + alg_str = strings.trim_suffix(alg_str, SUFFIX_TEST_ECPOINT) + alg_str = strings.trim_prefix(alg_str, PREFIX_TEST_ECDH) + testing.expectf(t, test_ecdh_impl(&test_vectors, alg_str), "alg {} failed", alg_str) + } +} + +test_ecdh_impl :: proc( + test_vectors: ^Test_Vectors(Ecdh_Test_Group), + alg_str: string, +) -> bool { + ALG_P256 :: "secp256r1" + ALG_P384 :: "secp384r1" + ALG_X25519 :: "x25519" + ALG_X448 :: "x448" + + // XDH exceptions + FLAG_PUBLIC_KEY_TOO_LONG :: "PublicKeyTooLong" + FLAG_ZERO_SHARED_SECRET :: "ZeroSharedSecret" + + // ECDH exceptions + FLAG_COMPRESSED_POINT :: "CompressedPoint" + FLAG_INVALID_CURVE :: "InvalidCurveAttack" + FLAG_INVALID_ENCODING :: "InvalidEncoding" + + log.debugf("ecdh/%s: starting", alg_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf("ecdh/%s/%d: %s: %+v", alg_str, test_vector.tc_id, comment, test_vector.flags) + } else { + log.debugf("ecdh/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) + } + + raw_pub := common.hexbytes_decode(test_vector.public) + raw_priv := common.hexbytes_decode(test_vector.private) + + curve: ecdh.Curve + priv_key: ecdh.Private_Key + pub_key: ecdh.Public_Key + + is_nist, is_xdh: bool + switch alg_str { + case ALG_P256: + curve = .SECP256R1 + // Ugh, ASN.1 :( + l := len(raw_priv) + if l == 33 { + if raw_priv[0] == 0 { + raw_priv = raw_priv[1:] + } + } else if l < 32 { + // left-pad.odin + tmp := make([]byte, 32) + copy(tmp[32-l:], raw_priv) + raw_priv = tmp + } + is_nist = true + case ALG_P384: + curve = .SECP384R1 + // Ugh, ASN.1 :( + l := len(raw_priv) + if l == 49 { + if raw_priv[0] == 0 { + raw_priv = raw_priv[1:] + } + } else if l < 48 { + // left-pad.odin + tmp := make([]byte, 48) + copy(tmp[48-l:], raw_priv) + raw_priv = tmp + } + is_nist = true + case ALG_X25519: + curve = .X25519 + is_xdh = true + case ALG_X448: + curve = .X448 + is_xdh = true + case: + log.errorf("ecdh: unsupported algorithm: %s", alg_str) + return false + } + + if ok := ecdh.private_key_set_bytes(&priv_key, curve, raw_priv); !ok { + log.errorf( + "ecdh/%s/%d: failed to deserialize private_key: %s %d %x", + alg_str, + test_vector.tc_id, + test_vector.private, + len(raw_priv), + raw_priv, + ) + num_failed += 1 + continue + } + + if ok := ecdh.public_key_set_bytes(&pub_key, curve, raw_pub); !ok { + if is_nist { + if slice.contains(test_vector.flags, FLAG_COMPRESSED_POINT) { + num_passed += 1 + continue + } + if slice.contains(test_vector.flags, FLAG_INVALID_CURVE) { + num_passed += 1 + continue + } + if slice.contains(test_vector.flags, FLAG_INVALID_ENCODING) { + num_passed += 1 + continue + } + } + if slice.contains(test_vector.flags, FLAG_PUBLIC_KEY_TOO_LONG) { + num_passed += 1 + continue + } + + log.errorf( + "ecdh/%s/%d: failed to deserialize public_key: %s", + alg_str, + test_vector.tc_id, + test_vector.public, + ) + num_failed += 1 + continue + } + + shared := make([]byte, ecdh.SHARED_SECRET_SIZES[curve]) + + ok := ecdh.ecdh(&priv_key, &pub_key, shared) + if !ok { + if is_xdh && slice.contains(test_vector.flags, FLAG_ZERO_SHARED_SECRET) { + num_passed += 1 + continue + } + // unused: x := transmute(string)(hex.encode(shared)) + log.errorf( + "ecdh/%s/%d: ecdh failed", + alg_str, + test_vector.tc_id, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.shared, shared) + // "acceptable" results are fine from here because we have + // checked for the all-zero shared secret XDH case already. + if !result_check(test_vector.result, ok, false) { + x := transmute(string)(hex.encode(shared)) + log.errorf( + "ecdh/%s/%d: shared: expected %s actual %s", + alg_str, + test_vector.tc_id, + test_vector.shared, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "ecdh/%s: ran %d, passed %d, failed %d, skipped %d", + alg_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} diff --git a/tests/core/crypto/wycheproof/kdf.odin b/tests/core/crypto/wycheproof/kdf.odin new file mode 100644 index 000000000..5b5ea1f2d --- /dev/null +++ b/tests/core/crypto/wycheproof/kdf.odin @@ -0,0 +1,238 @@ +package test_wycheproof + +import "core:encoding/hex" +import "core:log" +import "core:mem" +import "core:os" +import "core:slice" +import "core:strings" +import "core:testing" + +import "core:crypto/hkdf" +import "core:crypto/pbkdf2" + +import "../common" + +@(test) +test_hkdf :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("hkdf: starting") + + files := []string { + "hkdf_sha1_test.json", + "hkdf_sha256_test.json", + "hkdf_sha384_test.json", + "hkdf_sha512_test.json", + } + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Hkdf_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_hkdf_impl(&test_vectors), "hkdf failed") + } +} + +test_hkdf_impl :: proc(test_vectors: ^Test_Vectors(Hkdf_Test_Group)) -> bool { + PREFIX_HKDF :: "HKDF-" + FLAG_SIZE_TOO_LARGE :: "SizeTooLarge" + + alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_HKDF) + alg, ok := hash_name_to_algorithm(alg_str) + if !ok { + return false + } + alg_str = strings.to_lower(alg_str) + + log.debugf("hkdf/%s: starting", alg_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "hkdf/%s/%d: %s: %+v", + alg_str, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("hkdf/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) + } + + ikm := common.hexbytes_decode(test_vector.ikm) + salt := common.hexbytes_decode(test_vector.salt) + info := common.hexbytes_decode(test_vector.info) + + if slice.contains(test_vector.flags, FLAG_SIZE_TOO_LARGE) { + log.infof( + "hkdf/%s/%d: skipped, oversized outputs panic", + alg_str, + test_vector.tc_id, + ) + num_skipped += 1 + continue + } + + okm_ := make([]byte, test_vector.size) + hkdf.extract_and_expand(alg, salt, ikm, info, okm_) + + ok = common.hexbytes_compare(test_vector.okm, okm_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(okm_)) + log.errorf( + "hkdf/%s/%d: shared: expected %s actual %s", + alg_str, + test_vector.tc_id, + test_vector.okm, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "hkdf/%s: ran %d, passed %d, failed %d, skipped %d", + alg_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +@(test) +test_pbkdf2 :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("pbkdf2: starting") + + files := []string { + "pbkdf2_hmacsha1_test.json", + "pbkdf2_hmacsha224_test.json", + "pbkdf2_hmacsha256_test.json", + "pbkdf2_hmacsha384_test.json", + "pbkdf2_hmacsha512_test.json", + } + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Pbkdf_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_pbkdf2_impl(&test_vectors), "pbkdf2 failed") + } +} + +test_pbkdf2_impl :: proc( + test_vectors: ^Test_Vectors(Pbkdf_Test_Group), +) -> bool { + PREFIX_PBKDF_HMAC :: "PBKDF2-HMAC" + FLAG_LARGE_ITERATION_COUNT :: "LargeIterationCount" + + alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_PBKDF_HMAC) + alg, ok := hash_name_to_algorithm(alg_str) + if !ok { + return false + } + alg_str = strings.to_lower(alg_str) + + log.debugf("pbkdf2/hmac-%s: starting", alg_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "pbkdf2/hmac-%s/%d: %s: %+v", + alg_str, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("pbkdf2/hmac-%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) + } + + if slice.contains(test_vector.flags, FLAG_LARGE_ITERATION_COUNT) { + log.infof( + "pbkdf2/hmac-%s/%d: skipped, takes fucking forever", + alg_str, + test_vector.tc_id, + ) + num_skipped += 1 + continue + } + + password := common.hexbytes_decode(test_vector.password) + salt := common.hexbytes_decode(test_vector.salt) + + dk_ := make([]byte, test_vector.dk_len) + pbkdf2.derive(alg, password, salt, test_vector.iteration_count, dk_) + + ok = common.hexbytes_compare(test_vector.dk, dk_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(dk_)) + log.errorf( + "pbkdf2/hmac-%s/%d: shared: expected %s actual %s", + alg_str, + test_vector.tc_id, + test_vector.dk, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "pbkdf2/%s: ran %d, passed %d, failed %d, skipped %d", + alg_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} diff --git a/tests/core/crypto/wycheproof/mac.odin b/tests/core/crypto/wycheproof/mac.odin new file mode 100644 index 000000000..35dcc1fde --- /dev/null +++ b/tests/core/crypto/wycheproof/mac.odin @@ -0,0 +1,156 @@ +package test_wycheproof + +import "core:encoding/hex" + +import "core:log" +import "core:mem" +import "core:os" +import "core:testing" + +import "core:crypto/hmac" +import "core:crypto/kmac" +import "core:crypto/siphash" + +import "../common" + +@(test) +test_mac :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("mac: starting") + + files := []string { + "hmac_sha1_test.json", + "hmac_sha224_test.json", + "hmac_sha256_test.json", + "hmac_sha3_224_test.json", + "hmac_sha3_256_test.json", + "hmac_sha3_384_test.json", + "hmac_sha3_512_test.json", + "hmac_sha384_test.json", + // "hmac_sha512_224_test.json", + "hmac_sha512_256_test.json", + "hmac_sha512_test.json", + "hmac_sm3_test.json", + "kmac128_no_customization_test.json", + "kmac256_no_customization_test.json", + "siphash_1_3_test.json", + "siphash_2_4_test.json", + "siphash_4_8_test.json", + } + + for f in files { + mem.free_all() // Probably don't need this, but be safe. + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Mac_Test_Group) + if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mac_impl(&test_vectors), "hkdf failed") + } +} + +test_mac_impl :: proc(test_vectors: ^Test_Vectors(Mac_Test_Group)) -> bool { + PREFIX_HMAC :: "HMAC" + PREFIX_KMAC :: "KMAC" + + mac_alg, hmac_alg, alg_str, ok := mac_algorithm(test_vectors.algorithm) + if !ok { + log.errorf("mac: unsupported algorith: %s", test_vectors.algorithm) + return false + } + + log.debugf("%s: starting", alg_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/%d: %s: %+v", + alg_str, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) + } + + key := common.hexbytes_decode(test_vector.key) + msg := common.hexbytes_decode(test_vector.msg) + + tag_ := make([]byte, len(test_vector.tag) / 2) + + #partial switch mac_alg { + case .HMAC: + ctx: hmac.Context + hmac.init(&ctx, hmac_alg, key) + hmac.update(&ctx, msg) + if l := hmac.tag_size(&ctx); l == len(tag_) { + hmac.final(&ctx, tag_) + } else { + // Our hmac package does not support truncation. + tmp := make([]byte, l) + hmac.final(&ctx, tmp) + copy(tag_, tmp) + } + case .KMAC128, .KMAC256: + ctx: kmac.Context + #partial switch mac_alg { + case .KMAC128: + kmac.init_128(&ctx, key, nil) + case .KMAC256: + kmac.init_256(&ctx, key, nil) + } + kmac.update(&ctx, msg) + kmac.final(&ctx, tag_) + case .SIPHASH_1_3: + siphash.sum_1_3(msg, key, tag_) + case .SIPHASH_2_4: + siphash.sum_2_4(msg, key, tag_) + case .SIPHASH_4_8: + siphash.sum_4_8(msg, key, tag_) + } + + ok = common.hexbytes_compare(test_vector.tag, tag_) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(tag_)) + log.errorf( + "%s/%d: tag: expected %s actual %s", + alg_str, + test_vector.tc_id, + test_vector.tag, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s: ran %d, passed %d, failed %d, skipped %d", + alg_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} diff --git a/tests/core/crypto/wycheproof/main.odin b/tests/core/crypto/wycheproof/main.odin index dc8ab9237..dfdc78267 100644 --- a/tests/core/crypto/wycheproof/main.odin +++ b/tests/core/crypto/wycheproof/main.odin @@ -1,32 +1,8 @@ package test_wycheproof -import "core:encoding/hex" import "core:log" -import "core:mem" -import "core:os" -import "core:slice" -import "core:strings" import "core:testing" -import chacha_simd128 "core:crypto/_chacha20/simd128" -import chacha_simd256 "core:crypto/_chacha20/simd256" -import "core:crypto/aegis" -import "core:crypto/aes" -import "core:crypto/chacha20" -import "core:crypto/chacha20poly1305" -import "core:crypto/ecdh" -import "core:crypto/ecdsa" -import "core:crypto/ed25519" -import "core:crypto/hash" -import "core:crypto/hkdf" -import "core:crypto/hmac" -import "core:crypto/kmac" -import "core:crypto/pbkdf2" -import "core:crypto/siphash" -import "core:crypto/deoxysii" - -import "../common" - // Covered: // - crypto/aegis // - aegis128L_test.json @@ -96,1307 +72,3 @@ SUFFIX_TEST_JSON :: "_test.json" print_test_vector_path :: proc(t: ^testing.T) { log.infof("wycheproof path: %s", BASE_PATH) } - -supported_aegis_impls :: proc() -> [dynamic]aes.Implementation { - impls := make([dynamic]aes.Implementation, 0, 2, context.temp_allocator) - append(&impls, aes.Implementation.Portable) - if aegis.is_hardware_accelerated() { - append(&impls, aes.Implementation.Hardware) - } - - return impls -} - -@(test) -test_aead_aegis :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - files := []string { - "aegis128L_test.json", - "aegis256_test.json", - } - - log.debug("aead/aegis: starting") - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Aead_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - for impl in supported_aegis_impls() { - testing.expectf(t, test_aead_aegis_impl(&test_vectors, impl), "impl {} failed", impl) - } - } -} - -test_aead_aegis_impl :: proc( - test_vectors: ^Test_Vectors(Aead_Test_Group), - impl: aes.Implementation, -) -> bool { - log.debug("aead/aegis/%v: starting", impl) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "aead/aegis/%v/%d: %s: %+v", - impl, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("aead/aegis/%v/%d: %+v", - impl, - test_vector.tc_id, - test_vector.flags, - ) - } - - key := common.hexbytes_decode(test_vector.key) - iv := common.hexbytes_decode(test_vector.iv) - aad := common.hexbytes_decode(test_vector.aad) - msg := common.hexbytes_decode(test_vector.msg) - ct := common.hexbytes_decode(test_vector.ct) - tag := common.hexbytes_decode(test_vector.tag) - - if len(iv) == 0 { - log.infof( - "aead/aegis/%v/%d: skipped, invalid IVs panic", - impl, - test_vector.tc_id, - ) - num_skipped += 1 - continue - } - - ctx: aegis.Context - aegis.init(&ctx, key, impl) - - if result_is_valid(test_vector.result) { - ct_ := make([]byte, len(ct)) - tag_ := make([]byte, len(tag)) - aegis.seal(&ctx, ct_, tag_, iv, aad, msg) - - ok := common.hexbytes_compare(test_vector.ct, ct_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(ct_)) - log.errorf( - "aead/aegis/%v/%d: ciphertext: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.ct, - x, - ) - num_failed += 1 - continue - } - - ok = common.hexbytes_compare(test_vector.tag, tag_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(tag_)) - log.errorf( - "aead/aegis/%v/%d: tag: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.tag, - x, - ) - num_failed += 1 - continue - } - } - - msg_ := make([]byte, len(msg)) - ok := aegis.open(&ctx, msg_, iv, aad, ct, tag) - if !result_check(test_vector.result, ok) { - log.errorf("aead/aegis/%v/%d: decrypt failed", impl, test_vector.tc_id) - num_failed += 1 - continue - } - - if ok && !common.hexbytes_compare(test_vector.msg, msg_) { - x := transmute(string)(hex.encode(msg_)) - log.errorf( - "aead/aegis/%v/%d: decrypt msg: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.msg, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "aead/aegis: ran %d, passed %d, failed %d, skipped %d", - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -supported_aes_impls :: proc() -> [dynamic]aes.Implementation { - impls := make([dynamic]aes.Implementation, 0, 2) - append(&impls, aes.Implementation.Portable) - if aes.is_hardware_accelerated() { - append(&impls, aes.Implementation.Hardware) - } - - return impls -} - -@(test) -test_aead_aes_gcm :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - fn, _ := os.join_path([]string{BASE_PATH, "aes_gcm_test.json"}, context.allocator) - - log.debug("aead/aes-gcm: starting") - - test_vectors: Test_Vectors(Aead_Test_Group) - assert(load(&test_vectors, fn)) - - for impl in supported_aes_impls() { - testing.expectf(t, test_aead_aes_gcm_impl(&test_vectors, impl), "impl {} failed", impl) - } -} - -test_aead_aes_gcm_impl :: proc( - test_vectors: ^Test_Vectors(Aead_Test_Group), - impl: aes.Implementation, -) -> bool { - log.debug("aead/aes-gcm/%v: starting", impl) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "aead/aes-gcm/%v/%d: %s: %+v", - impl, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("aead/aes-gcm/%v/%d: %+v", - impl, - test_vector.tc_id, - test_vector.flags, - ) - } - - key := common.hexbytes_decode(test_vector.key) - iv := common.hexbytes_decode(test_vector.iv) - aad := common.hexbytes_decode(test_vector.aad) - msg := common.hexbytes_decode(test_vector.msg) - ct := common.hexbytes_decode(test_vector.ct) - tag := common.hexbytes_decode(test_vector.tag) - - if len(iv) == 0 { - log.infof( - "aead/aes-gcm/%v/%d: skipped, invalid IVs panic", - impl, - test_vector.tc_id, - ) - num_skipped += 1 - continue - } - - ctx: aes.Context_GCM - aes.init_gcm(&ctx, key, impl) - - if result_is_valid(test_vector.result) { - ct_ := make([]byte, len(ct)) - tag_ := make([]byte, len(tag)) - aes.seal_gcm(&ctx, ct_, tag_, iv, aad, msg) - - ok := common.hexbytes_compare(test_vector.ct, ct_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(ct_)) - log.errorf( - "aead/aes-gcm/%v/%d: ciphertext: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.ct, - x, - ) - num_failed += 1 - continue - } - - ok = common.hexbytes_compare(test_vector.tag, tag_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(tag_)) - log.errorf( - "aead/aes-gcm/%v/%d: tag: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.tag, - x, - ) - num_failed += 1 - continue - } - } - - msg_ := make([]byte, len(msg)) - ok := aes.open_gcm(&ctx, msg_, iv, aad, ct, tag) - if !result_check(test_vector.result, ok) { - log.errorf("aead/aes-gcm/%v/%d: decrypt failed", impl, test_vector.tc_id) - num_failed += 1 - continue - } - - if ok && !common.hexbytes_compare(test_vector.msg, msg_) { - x := transmute(string)(hex.encode(msg_)) - log.errorf( - "aead/aes-gcm/%v/%d: decrypt msg: expected %s actual %s", - impl, - test_vector.tc_id, - test_vector.msg, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "aead/aes-gcm: ran %d, passed %d, failed %d, skipped %d", - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -supported_chacha_impls :: proc() -> [dynamic]chacha20.Implementation { - impls := make([dynamic]chacha20.Implementation, 0, 3) - append(&impls, chacha20.Implementation.Portable) - if chacha_simd128.is_performant() { - append(&impls, chacha20.Implementation.Simd128) - } - if chacha_simd256.is_performant() { - append(&impls, chacha20.Implementation.Simd256) - } - - return impls -} - -@(test) -test_aead_chacha20_poly1305 :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - files := []string { - "chacha20_poly1305_test.json", - "xchacha20_poly1305_test.json", - } - - log.debug("aead/(x)chacha20poly1305: starting") - - for f, i in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Aead_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - for impl in supported_chacha_impls() { - testing.expectf(t, test_aead_chacha20_poly1305_impl(&test_vectors, i == 1, impl), "impl {} failed", impl) - } - } -} - -test_aead_chacha20_poly1305_impl :: proc( - test_vectors: ^Test_Vectors(Aead_Test_Group), - is_xchacha: bool, - impl: chacha20.Implementation, -) -> bool { - FLAG_INVALID_NONCE_SIZE :: "InvalidNonceSize" - - alg_str := is_xchacha ? "xchacha20poly1305" : "chacha20poly1305" - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "aead/%s/%v/%d: %s: %+v", - alg_str, - impl, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("aead/%s/%v/%d: %+v", - alg_str, - impl, - test_vector.tc_id, - test_vector.flags, - ) - } - - key := common.hexbytes_decode(test_vector.key) - iv := common.hexbytes_decode(test_vector.iv) - aad := common.hexbytes_decode(test_vector.aad) - msg := common.hexbytes_decode(test_vector.msg) - ct := common.hexbytes_decode(test_vector.ct) - tag := common.hexbytes_decode(test_vector.tag) - - if slice.contains(test_vector.flags, FLAG_INVALID_NONCE_SIZE) { - log.infof( - "aead/%s/%v/%d: skipped, invalid nonces panic", - alg_str, - impl, - test_vector.tc_id, - ) - num_skipped += 1 - continue - } - - ctx: chacha20poly1305.Context - switch is_xchacha { - case true: - chacha20poly1305.init_xchacha(&ctx, key, impl) - case false: - chacha20poly1305.init(&ctx, key, impl) - } - - if result_is_valid(test_vector.result) { - ct_ := make([]byte, len(ct)) - tag_ := make([]byte, len(tag)) - chacha20poly1305.seal(&ctx, ct_, tag_, iv, aad, msg) - - ok := common.hexbytes_compare(test_vector.ct, ct_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(ct_)) - log.errorf( - "aead/%s/%v/%d: ciphertext: expected %s actual %s", - alg_str, - impl, - test_vector.tc_id, - test_vector.ct, - x, - ) - num_failed += 1 - continue - } - - ok = common.hexbytes_compare(test_vector.tag, tag_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(tag_)) - log.errorf( - "aead/%s/%v/%d: tag: expected %s actual %s", - alg_str, - impl, - test_vector.tc_id, - test_vector.tag, - x, - ) - num_failed += 1 - continue - } - } - - msg_ := make([]byte, len(msg)) - ok := chacha20poly1305.open(&ctx, msg_, iv, aad, ct, tag) - if !result_check(test_vector.result, ok) { - log.errorf("aead/%s/%v/%d: decrypt failed", - alg_str, - impl, - test_vector.tc_id, - ) - num_failed += 1 - continue - } - - if ok && !common.hexbytes_compare(test_vector.msg, msg_) { - x := transmute(string)(hex.encode(msg_)) - log.errorf( - "aead/%s/%v/%d: decrypt msg: expected %s actual %s", - alg_str, - impl, - test_vector.tc_id, - test_vector.msg, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "aead/%s/%v: ran %d, passed %d, failed %d, skipped %d", - alg_str, - impl, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -@(test) -test_aead_deoxysii :: proc(t: ^testing.T) { - ctx: deoxysii.Context - - key: [deoxysii.KEY_SIZE]byte - iv: [deoxysii.IV_SIZE]byte - tag: [deoxysii.TAG_SIZE]byte - buf: [4096]byte - - deoxysii.init(&ctx, key[:]) - deoxysii.seal(&ctx, buf[:], tag[:], iv[:], nil, buf[:]) - assert(deoxysii.open(&ctx, buf[:], iv[:], nil, buf[:], tag[:])) -} - -@(test) -test_eddsa_ed25519 :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - fn_, _ := os.join_path([]string{BASE_PATH, "ed25519_test.json"}, context.allocator) - - log.debug("eddsa/ed25519: starting") - - test_vectors: Test_Vectors(Eddsa_Test_Group) - assert(load(&test_vectors, fn_)) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group, i in test_vectors.test_groups { - mem.free_all() // Probably don't need this, but be safe. - pk_bytes := common.hexbytes_decode(test_group.public_key.pk) - - pk: ed25519.Public_Key - pk_ok := ed25519.public_key_set_bytes(&pk, pk_bytes) - testing.expectf(t, pk_ok, "eddsa/ed25519/%d: invalid public key: %s", i, test_group.public_key.pk) - if !pk_ok { - num_failed += len(test_group.tests) - continue - } - - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "eddsa/ed25519/%d: %s: %+v", - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("eddsa/ed25519/%d: %+v", test_vector.tc_id, test_vector.flags) - } - - msg := common.hexbytes_decode(test_vector.msg) - sig := common.hexbytes_decode(test_vector.sig) - - verify_ok := ed25519.verify(&pk, msg, sig) - result_ok := result_check(test_vector.result, verify_ok) - testing.expectf( - t, - result_ok, - "eddsa/ed25519/%d: verify failed: expected %s actual %v", - test_vector.tc_id, - test_vector.result, - verify_ok, - ) - if !result_ok { - num_failed += 1 - continue - } - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "eddsa/ed25519: ran %d, passed %d, failed %d, skipped %d", - num_ran, - num_passed, - num_failed, - num_skipped, - ) -} - -@(test) -test_ecdsa :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - log.debug("ecdsa: starting") - - files := []string { - "ecdsa_secp256r1_sha256_test.json", - "ecdsa_secp256r1_sha512_test.json", - "ecdsa_secp384r1_sha384_test.json", - } - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Ecdsa_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - testing.expectf(t, test_ecdsa_impl(t, &test_vectors), "ecdsa failed") - } -} - -test_ecdsa_impl :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Ecdsa_Test_Group)) -> bool { - curve_str := test_vectors.test_groups[0].public_key.curve - hash_str := test_vectors.test_groups[0].sha - - curve_alg: ecdsa.Curve - switch curve_str { - case "secp256r1": - curve_alg = .SECP256R1 - case "secp384r1": - curve_alg = .SECP384R1 - case: - log.errorf("ecdsa: unsupported curve: %s", curve_str) - } - - hash_alg: hash.Algorithm - switch hash_str { - case "SHA-256": - hash_alg = .SHA256 - case "SHA-384": - hash_alg = .SHA384 - case "SHA-512": - hash_alg = .SHA512 - case: - log.errorf("ecdsa: unsupported hash: %s", hash_str) - } - - log.debugf("ecdsa/%s/%s: starting", curve_str, hash_str) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group, i in test_vectors.test_groups { - pk_bytes := common.hexbytes_decode(test_group.public_key.uncompressed) - - pk: ecdsa.Public_Key - pk_ok := ecdsa.public_key_set_bytes(&pk, curve_alg, pk_bytes) - testing.expectf(t, pk_ok, "ecdsa/%s/%s/%d: invalid public key: %s", curve_str, hash_str, i, test_group.public_key.uncompressed) - if !pk_ok { - num_failed += len(test_group.tests) - continue - } - - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "ecda/%s/%s/%d: %s: %+v", - curve_str, - hash_str, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("ecdsa/%s/%s/%d: %+v", curve_str, hash_str, test_vector.tc_id, test_vector.flags) - } - - msg := common.hexbytes_decode(test_vector.msg) - sig := common.hexbytes_decode(test_vector.sig) - - verify_ok := ecdsa.verify_asn1(&pk, hash_alg, msg, sig) - result_ok := result_check(test_vector.result, verify_ok) - testing.expectf( - t, - result_ok, - "ecdsa/%s/%s/%d: verify failed: expected %s actual %v", - curve_str, - hash_str, - test_vector.tc_id, - test_vector.result, - verify_ok, - ) - if !result_ok { - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "ecdsa/%s/%s: ran %d, passed %d, failed %d, skipped %d", - curve_str, - hash_str, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -@(test) -test_hkdf :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - log.debug("hkdf: starting") - - files := []string { - "hkdf_sha1_test.json", - "hkdf_sha256_test.json", - "hkdf_sha384_test.json", - "hkdf_sha512_test.json", - } - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Hkdf_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - testing.expectf(t, test_hkdf_impl(&test_vectors), "hkdf failed") - } -} - -test_hkdf_impl :: proc(test_vectors: ^Test_Vectors(Hkdf_Test_Group)) -> bool { - PREFIX_HKDF :: "HKDF-" - FLAG_SIZE_TOO_LARGE :: "SizeTooLarge" - - alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_HKDF) - alg, ok := hash_name_to_algorithm(alg_str) - if !ok { - return false - } - alg_str = strings.to_lower(alg_str) - - log.debugf("hkdf/%s: starting", alg_str) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "hkdf/%s/%d: %s: %+v", - alg_str, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("hkdf/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) - } - - ikm := common.hexbytes_decode(test_vector.ikm) - salt := common.hexbytes_decode(test_vector.salt) - info := common.hexbytes_decode(test_vector.info) - - if slice.contains(test_vector.flags, FLAG_SIZE_TOO_LARGE) { - log.infof( - "hkdf/%s/%d: skipped, oversized outputs panic", - alg_str, - test_vector.tc_id, - ) - num_skipped += 1 - continue - } - - okm_ := make([]byte, test_vector.size) - hkdf.extract_and_expand(alg, salt, ikm, info, okm_) - - ok = common.hexbytes_compare(test_vector.okm, okm_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(okm_)) - log.errorf( - "hkdf/%s/%d: shared: expected %s actual %s", - alg_str, - test_vector.tc_id, - test_vector.okm, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "hkdf/%s: ran %d, passed %d, failed %d, skipped %d", - alg_str, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -@(test) -test_mac :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - log.debug("mac: starting") - - files := []string { - "hmac_sha1_test.json", - "hmac_sha224_test.json", - "hmac_sha256_test.json", - "hmac_sha3_224_test.json", - "hmac_sha3_256_test.json", - "hmac_sha3_384_test.json", - "hmac_sha3_512_test.json", - "hmac_sha384_test.json", - // "hmac_sha512_224_test.json", - "hmac_sha512_256_test.json", - "hmac_sha512_test.json", - "hmac_sm3_test.json", - "kmac128_no_customization_test.json", - "kmac256_no_customization_test.json", - "siphash_1_3_test.json", - "siphash_2_4_test.json", - "siphash_4_8_test.json", - } - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Mac_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - testing.expectf(t, test_mac_impl(&test_vectors), "hkdf failed") - } -} - -test_mac_impl :: proc(test_vectors: ^Test_Vectors(Mac_Test_Group)) -> bool { - PREFIX_HMAC :: "HMAC" - PREFIX_KMAC :: "KMAC" - - mac_alg, hmac_alg, alg_str, ok := mac_algorithm(test_vectors.algorithm) - if !ok { - log.errorf("mac: unsupported algorith: %s", test_vectors.algorithm) - return false - } - - log.debugf("%s: starting", alg_str) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "%s/%d: %s: %+v", - alg_str, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) - } - - key := common.hexbytes_decode(test_vector.key) - msg := common.hexbytes_decode(test_vector.msg) - - tag_ := make([]byte, len(test_vector.tag) / 2) - - #partial switch mac_alg { - case .HMAC: - ctx: hmac.Context - hmac.init(&ctx, hmac_alg, key) - hmac.update(&ctx, msg) - if l := hmac.tag_size(&ctx); l == len(tag_) { - hmac.final(&ctx, tag_) - } else { - // Our hmac package does not support truncation. - tmp := make([]byte, l) - hmac.final(&ctx, tmp) - copy(tag_, tmp) - } - case .KMAC128, .KMAC256: - ctx: kmac.Context - #partial switch mac_alg { - case .KMAC128: - kmac.init_128(&ctx, key, nil) - case .KMAC256: - kmac.init_256(&ctx, key, nil) - } - kmac.update(&ctx, msg) - kmac.final(&ctx, tag_) - case .SIPHASH_1_3: - siphash.sum_1_3(msg, key, tag_) - case .SIPHASH_2_4: - siphash.sum_2_4(msg, key, tag_) - case .SIPHASH_4_8: - siphash.sum_4_8(msg, key, tag_) - } - - ok = common.hexbytes_compare(test_vector.tag, tag_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(tag_)) - log.errorf( - "%s/%d: tag: expected %s actual %s", - alg_str, - test_vector.tc_id, - test_vector.tag, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "%s: ran %d, passed %d, failed %d, skipped %d", - alg_str, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -@(test) -test_pbkdf2 :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - log.debug("pbkdf2: starting") - - files := []string { - "pbkdf2_hmacsha1_test.json", - "pbkdf2_hmacsha224_test.json", - "pbkdf2_hmacsha256_test.json", - "pbkdf2_hmacsha384_test.json", - "pbkdf2_hmacsha512_test.json", - } - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Pbkdf_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - testing.expectf(t, test_pbkdf2_impl(&test_vectors), "pbkdf2 failed") - } -} - -test_pbkdf2_impl :: proc( - test_vectors: ^Test_Vectors(Pbkdf_Test_Group), -) -> bool { - PREFIX_PBKDF_HMAC :: "PBKDF2-HMAC" - FLAG_LARGE_ITERATION_COUNT :: "LargeIterationCount" - - alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_PBKDF_HMAC) - alg, ok := hash_name_to_algorithm(alg_str) - if !ok { - return false - } - alg_str = strings.to_lower(alg_str) - - log.debugf("pbkdf2/hmac-%s: starting", alg_str) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf( - "pbkdf2/hmac-%s/%d: %s: %+v", - alg_str, - test_vector.tc_id, - comment, - test_vector.flags, - ) - } else { - log.debugf("pbkdf2/hmac-%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) - } - - if slice.contains(test_vector.flags, FLAG_LARGE_ITERATION_COUNT) { - log.infof( - "pbkdf2/hmac-%s/%d: skipped, takes fucking forever", - alg_str, - test_vector.tc_id, - ) - num_skipped += 1 - continue - } - - password := common.hexbytes_decode(test_vector.password) - salt := common.hexbytes_decode(test_vector.salt) - - dk_ := make([]byte, test_vector.dk_len) - pbkdf2.derive(alg, password, salt, test_vector.iteration_count, dk_) - - ok = common.hexbytes_compare(test_vector.dk, dk_) - if !result_check(test_vector.result, ok) { - x := transmute(string)(hex.encode(dk_)) - log.errorf( - "pbkdf2/hmac-%s/%d: shared: expected %s actual %s", - alg_str, - test_vector.tc_id, - test_vector.dk, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "pbkdf2/%s: ran %d, passed %d, failed %d, skipped %d", - alg_str, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} - -@(test) -test_ecdh :: proc(t: ^testing.T) { - arena: mem.Arena - arena_backing := make([]byte, ARENA_SIZE) - defer delete(arena_backing) - mem.arena_init(&arena, arena_backing) - context.allocator = mem.arena_allocator(&arena) - - PREFIX_TEST_ECDH :: "ecdh_" - SUFFIX_TEST_ECPOINT :: "_ecpoint" - - files := []string { - "ecdh_secp256r1_ecpoint_test.json", - "ecdh_secp384r1_ecpoint_test.json", - "x25519_test.json", - "x448_test.json", - } - - log.debug("ecdh: starting") - - for f in files { - mem.free_all() // Probably don't need this, but be safe. - - fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) - - test_vectors: Test_Vectors(Ecdh_Test_Group) - load_ok := load(&test_vectors, fn) - testing.expectf(t, load_ok, "Unable to load {}", f) - if !load_ok { - continue - } - - alg_str := strings.trim_suffix(f, SUFFIX_TEST_JSON) - alg_str = strings.trim_suffix(alg_str, SUFFIX_TEST_ECPOINT) - alg_str = strings.trim_prefix(alg_str, PREFIX_TEST_ECDH) - testing.expectf(t, test_ecdh_impl(&test_vectors, alg_str), "alg {} failed", alg_str) - } -} - -test_ecdh_impl :: proc( - test_vectors: ^Test_Vectors(Ecdh_Test_Group), - alg_str: string, -) -> bool { - ALG_P256 :: "secp256r1" - ALG_P384 :: "secp384r1" - ALG_X25519 :: "x25519" - ALG_X448 :: "x448" - - // XDH exceptions - FLAG_PUBLIC_KEY_TOO_LONG :: "PublicKeyTooLong" - FLAG_ZERO_SHARED_SECRET :: "ZeroSharedSecret" - - // ECDH exceptions - FLAG_COMPRESSED_POINT :: "CompressedPoint" - FLAG_INVALID_CURVE :: "InvalidCurveAttack" - FLAG_INVALID_ENCODING :: "InvalidEncoding" - - log.debugf("ecdh/%s: starting", alg_str) - - num_ran, num_passed, num_failed, num_skipped: int - for &test_group in test_vectors.test_groups { - for &test_vector in test_group.tests { - num_ran += 1 - - if comment := test_vector.comment; comment != "" { - log.debugf("ecdh/%s/%d: %s: %+v", alg_str, test_vector.tc_id, comment, test_vector.flags) - } else { - log.debugf("ecdh/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags) - } - - raw_pub := common.hexbytes_decode(test_vector.public) - raw_priv := common.hexbytes_decode(test_vector.private) - - curve: ecdh.Curve - priv_key: ecdh.Private_Key - pub_key: ecdh.Public_Key - - is_nist, is_xdh: bool - switch alg_str { - case ALG_P256: - curve = .SECP256R1 - // Ugh, ASN.1 :( - l := len(raw_priv) - if l == 33 { - if raw_priv[0] == 0 { - raw_priv = raw_priv[1:] - } - } else if l < 32 { - // left-pad.odin - tmp := make([]byte, 32) - copy(tmp[32-l:], raw_priv) - raw_priv = tmp - } - is_nist = true - case ALG_P384: - curve = .SECP384R1 - // Ugh, ASN.1 :( - l := len(raw_priv) - if l == 49 { - if raw_priv[0] == 0 { - raw_priv = raw_priv[1:] - } - } else if l < 48 { - // left-pad.odin - tmp := make([]byte, 48) - copy(tmp[48-l:], raw_priv) - raw_priv = tmp - } - is_nist = true - case ALG_X25519: - curve = .X25519 - is_xdh = true - case ALG_X448: - curve = .X448 - is_xdh = true - case: - log.errorf("ecdh: unsupported algorithm: %s", alg_str) - return false - } - - if ok := ecdh.private_key_set_bytes(&priv_key, curve, raw_priv); !ok { - log.errorf( - "ecdh/%s/%d: failed to deserialize private_key: %s %d %x", - alg_str, - test_vector.tc_id, - test_vector.private, - len(raw_priv), - raw_priv, - ) - num_failed += 1 - continue - } - - if ok := ecdh.public_key_set_bytes(&pub_key, curve, raw_pub); !ok { - if is_nist { - if slice.contains(test_vector.flags, FLAG_COMPRESSED_POINT) { - num_passed += 1 - continue - } - if slice.contains(test_vector.flags, FLAG_INVALID_CURVE) { - num_passed += 1 - continue - } - if slice.contains(test_vector.flags, FLAG_INVALID_ENCODING) { - num_passed += 1 - continue - } - } - if slice.contains(test_vector.flags, FLAG_PUBLIC_KEY_TOO_LONG) { - num_passed += 1 - continue - } - - log.errorf( - "ecdh/%s/%d: failed to deserialize public_key: %s", - alg_str, - test_vector.tc_id, - test_vector.public, - ) - num_failed += 1 - continue - } - - shared := make([]byte, ecdh.SHARED_SECRET_SIZES[curve]) - - ok := ecdh.ecdh(&priv_key, &pub_key, shared) - if !ok { - if is_xdh && slice.contains(test_vector.flags, FLAG_ZERO_SHARED_SECRET) { - num_passed += 1 - continue - } - // unused: x := transmute(string)(hex.encode(shared)) - log.errorf( - "ecdh/%s/%d: ecdh failed", - alg_str, - test_vector.tc_id, - ) - num_failed += 1 - continue - } - - ok = common.hexbytes_compare(test_vector.shared, shared) - // "acceptable" results are fine from here because we have - // checked for the all-zero shared secret XDH case already. - if !result_check(test_vector.result, ok, false) { - x := transmute(string)(hex.encode(shared)) - log.errorf( - "ecdh/%s/%d: shared: expected %s actual %s", - alg_str, - test_vector.tc_id, - test_vector.shared, - x, - ) - num_failed += 1 - continue - } - - num_passed += 1 - } - } - - assert(num_ran == test_vectors.number_of_tests) - assert(num_passed + num_failed + num_skipped == num_ran) - - log.infof( - "ecdh/%s: ran %d, passed %d, failed %d, skipped %d", - alg_str, - num_ran, - num_passed, - num_failed, - num_skipped, - ) - - return num_failed == 0 -} From d2c29c025ef3ba17b704cd4eb293b5192e77af03 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Wed, 1 Apr 2026 04:21:36 +0900 Subject: [PATCH 4/4] core/crypto/mlkem: Initial import --- core/crypto/_mlkem/cbd.odin | 52 +++ core/crypto/_mlkem/constants.odin | 53 +++ core/crypto/_mlkem/k_pke.odin | 349 +++++++++++++++++ core/crypto/_mlkem/kem_internal.odin | 195 ++++++++++ core/crypto/_mlkem/ntt.odin | 75 ++++ core/crypto/_mlkem/poly.odin | 241 ++++++++++++ core/crypto/_mlkem/polyvec.odin | 224 +++++++++++ core/crypto/_mlkem/reduce.odin | 19 + core/crypto/_mlkem/symmetric_shake.odin | 61 +++ core/crypto/_subtle/subtle.odin | 39 ++ core/crypto/mlkem/api.odin | 268 +++++++++++++ core/crypto/mlkem/doc.odin | 7 + examples/all/all_js.odin | 1 + examples/all/all_main.odin | 1 + tests/benchmark/crypto/benchmark_pqc.odin | 84 ++++ tests/core/crypto/test_core_crypto_pqc.odin | 72 ++++ tests/core/crypto/wycheproof/main.odin | 10 + tests/core/crypto/wycheproof/pqc.odin | 401 ++++++++++++++++++++ tests/core/crypto/wycheproof/schemas.odin | 25 ++ 19 files changed, 2177 insertions(+) create mode 100644 core/crypto/_mlkem/cbd.odin create mode 100644 core/crypto/_mlkem/constants.odin create mode 100644 core/crypto/_mlkem/k_pke.odin create mode 100644 core/crypto/_mlkem/kem_internal.odin create mode 100644 core/crypto/_mlkem/ntt.odin create mode 100644 core/crypto/_mlkem/poly.odin create mode 100644 core/crypto/_mlkem/polyvec.odin create mode 100644 core/crypto/_mlkem/reduce.odin create mode 100644 core/crypto/_mlkem/symmetric_shake.odin create mode 100644 core/crypto/mlkem/api.odin create mode 100644 core/crypto/mlkem/doc.odin create mode 100644 tests/benchmark/crypto/benchmark_pqc.odin create mode 100644 tests/core/crypto/test_core_crypto_pqc.odin create mode 100644 tests/core/crypto/wycheproof/pqc.odin diff --git a/core/crypto/_mlkem/cbd.odin b/core/crypto/_mlkem/cbd.odin new file mode 100644 index 000000000..301551ff5 --- /dev/null +++ b/core/crypto/_mlkem/cbd.odin @@ -0,0 +1,52 @@ +#+private +package _mlkem + +import "core:encoding/endian" + +unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check { + r := u32(b[0]) + r |= u32(b[1]) << 8 + r |= u32(b[2]) << 16 + return r +} + +cbd3 :: proc "contextless" (r: ^Poly, buf: ^[3*N/4]byte) #no_bounds_check { + for i in 0..>1) & 0x00249249 + d += (t>>2) & 0x00249249 + + for j in uint(0)..<4 { + a := i16((d >> (6*j+0)) & 0x7) + b := i16((d >> (6*j+3)) & 0x7) + r.coeffs[4*i+int(j)] = a - b + } + } +} + +cbd2 :: proc "contextless" (r: ^Poly, buf: ^[2*N/4]byte) #no_bounds_check { + for i in 0..>1) & 0x55555555 + + for j in uint(0)..<8 { + a := i16((d >> (4*j+0)) & 0x3) + b := i16((d >> (4*j+2)) & 0x3) + r.coeffs[8*i+int(j)] = a - b + } + } +} + +poly_cbd_eta1_512 :: proc "contextless" (r: ^Poly, buf: ^[ETA1_512*N/4]byte) { + cbd3(r, buf) +} + +poly_cbd_eta1 :: proc "contextless" (r: ^Poly, buf: ^[ETA1*N/4]byte) { + cbd2(r, buf) +} + +poly_cbd_eta2 :: proc "contextless" (r: ^Poly, buf: ^[ETA2*N/4]byte) { + cbd2(r, buf) +} diff --git a/core/crypto/_mlkem/constants.odin b/core/crypto/_mlkem/constants.odin new file mode 100644 index 000000000..7195d8b61 --- /dev/null +++ b/core/crypto/_mlkem/constants.odin @@ -0,0 +1,53 @@ +package _mlkem + +K_512 :: 2 +K_768 :: 3 +K_1024 :: 4 +K_MAX :: K_1024 + +N :: 256 +Q :: 3329 + +ETA1_512 :: 3 +ETA1 :: 2 +ETA2 :: 2 + +POLYBYTES :: 384 +SYMBYTES :: 32 + +POLYCOMPRESSEDBYTES_512 :: 128 +POLYCOMPRESSEDBYTES_768 :: 128 +POLYCOMPRESSEDBYTES_1024 :: 160 + +POLYVECBYTES_512 :: K_512 * POLYBYTES +POLYVECBYTES_768 :: K_768 * POLYBYTES +POLYVECBYTES_1024 :: K_1024 * POLYBYTES + +POLYVECCOMPRESSEDBYTES_512 :: K_512 * 320 +POLYVECCOMPRESSEDBYTES_768 :: K_768 * 320 +POLYVECCOMPRESSEDBYTES_1024 :: K_1024 * 352 + +INDCPA_MSGBYTES :: SYMBYTES +INDCPA_PUBLICKEYBYTES_512 :: POLYVECBYTES_512 + SYMBYTES +INDCPA_SECRETKEYBYTES_512 :: POLYVECBYTES_512 +INDCPA_PUBLICKEYBYTES_768 :: POLYVECBYTES_768 + SYMBYTES +INDCPA_SECRETKEYBYTES_768 :: POLYVECBYTES_768 +INDCPA_PUBLICKEYBYTES_1024 :: POLYVECBYTES_1024 + SYMBYTES +INDCPA_SECRETKEYBYTES_1024 :: POLYVECBYTES_1024 +INDCPA_PUBLICKEYBYTES_MAX :: INDCPA_PUBLICKEYBYTES_1024 +INDCPA_BYTES_512 :: POLYVECCOMPRESSEDBYTES_512 + POLYCOMPRESSEDBYTES_512 +INDCPA_BYTES_768 :: POLYVECCOMPRESSEDBYTES_768 + POLYCOMPRESSEDBYTES_768 +INDCPA_BYTES_1024 :: POLYVECCOMPRESSEDBYTES_1024 + POLYCOMPRESSEDBYTES_1024 + +ENCAPSKEYBYTES_512 :: INDCPA_PUBLICKEYBYTES_512 +ENCAPSKEYBYTES_768 :: INDCPA_PUBLICKEYBYTES_768 +ENCAPSKEYBYTES_1024 :: INDCPA_PUBLICKEYBYTES_1024 + +DECAPSKEYBYTES_512 :: INDCPA_SECRETKEYBYTES_512 + INDCPA_PUBLICKEYBYTES_512 + 2 * SYMBYTES +DECAPSKEYBYTES_768 :: INDCPA_SECRETKEYBYTES_768 + INDCPA_PUBLICKEYBYTES_768 + 2 * SYMBYTES +DECAPSKEYBYTES_1024 :: INDCPA_SECRETKEYBYTES_1024 + INDCPA_PUBLICKEYBYTES_1024 + 2 * SYMBYTES + +CIPHERTEXTBYTES_512 :: INDCPA_BYTES_512 +CIPHERTEXTBYTES_768 :: INDCPA_BYTES_768 +CIPHERTEXTBYTES_1024 :: INDCPA_BYTES_1024 +CIPHERTEXTBYTES_MAX :: INDCPA_BYTES_1024 diff --git a/core/crypto/_mlkem/k_pke.odin b/core/crypto/_mlkem/k_pke.odin new file mode 100644 index 000000000..d8b87ad7a --- /dev/null +++ b/core/crypto/_mlkem/k_pke.odin @@ -0,0 +1,349 @@ +#+private +package _mlkem + +import "core:crypto" +import "core:crypto/shake" + +@(require_results) +pack_pk :: proc "contextless" (r: []byte, pk: ^Polyvec, seed: []byte, k: int) -> bool { + pk_len := polyvec_byte_size(k) + switch { + case pk_len == 0: + return false + case len(seed) != SYMBYTES || len(r) != pk_len + SYMBYTES: + return false + } + + polyvec_tobytes(r[:pk_len], pk, k) + copy(r[pk_len:], seed) + + return true +} + +@(require_results) +unpack_pk :: proc "contextless" (pk: ^Polyvec, seed, packedpk: []byte) -> bool { + pk_len := len(packedpk) - SYMBYTES + k: int + switch { + case pk_len == POLYVECBYTES_512: + k = K_512 + case pk_len == POLYVECBYTES_768: + k = K_768 + case pk_len == POLYVECBYTES_1024: + k = K_1024 + case len(packedpk) - pk_len != SYMBYTES: + return false + case len(seed) != SYMBYTES: + return false + } + if k == 0 { + return false + } + + ok := polyvec_frombytes(pk, packedpk[:pk_len], k) + copy(seed, packedpk[pk_len:]) + + return ok +} + +@(require_results) +pack_sk :: proc "contextless" (r: []byte, sk: ^Polyvec, k: int) -> bool { + r_len := len(r) + if r_len == 0 || r_len != polyvec_byte_size(k) { + return false + } + + polyvec_tobytes(r, sk, k) + + return true +} + +@(require_results) +unpack_sk :: proc "contextless" (sk: ^Polyvec, packedsk: []byte) -> bool { + k: int + switch len(packedsk) { + case POLYVECBYTES_512: + k = K_512 + case POLYVECBYTES_768: + k = K_768 + case POLYVECBYTES_1024: + k = K_1024 + case: + return false + } + if k == 0 { + return false + } + + return polyvec_frombytes(sk, packedsk, k) +} + +@(require_results) +pack_ciphertext :: proc "contextless" (r: []byte, b: ^Polyvec, v: ^Poly, k: int) -> bool { + b_len := polyvec_compressed_byte_size(k) + if len(r) != b_len + poly_compressed_bytes(k) { + return false + } + + polyvec_compress(r[:b_len], b, k) + poly_compress(r[b_len:], v) + + return true +} + +@(require_results) +unpack_ciphertext :: proc "contextless" (b: ^Polyvec, v: ^Poly, c: []byte) -> int { + b_len: int + k: int + switch len(c) { + case INDCPA_BYTES_512: + b_len = POLYVECCOMPRESSEDBYTES_512 + k = K_512 + case INDCPA_BYTES_768: + b_len = POLYVECCOMPRESSEDBYTES_768 + k = K_768 + case INDCPA_BYTES_1024: + b_len = POLYVECCOMPRESSEDBYTES_1024 + k = K_1024 + case: + return 0 + } + + polyvec_decompress(b, c[:b_len], k) + poly_decompress(v, c[b_len:]) + + return k +} + +@(require_results) +rej_uniform :: proc "contextless" (r: []i16, buf: []byte) -> int { + r_len, b_len := len(r), len(buf) + + ctr, pos: int + for ctr < r_len && pos + 3 <= b_len { + val0 := (u16(buf[pos+0] >> 0) | (u16(buf[pos+1]) << 8)) & 0xFFF + val1 := (u16(buf[pos+1] >> 4) | (u16(buf[pos+2]) << 4)) & 0xFFF + pos += 3 + + if val0 < Q { + r[ctr] = i16(val0) + ctr += 1 + } + if(ctr < r_len && val1 < Q) { + r[ctr] = i16(val1) + ctr += 1 + } + } + + return ctr +} + +gen_matrix :: proc(a: []Polyvec, seed: []byte, transposed: bool, k: int) { + GEN_MATRIX_NBLOCKS :: ((12*N/8*(1 << 12)/Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) + + buf: [GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]byte = --- + ctx: shake.Context = --- + ctr: int + + defer shake.reset(&ctx) + defer crypto.zero_explicit(&buf, size_of(buf)) + + for i in 0.. bool { + ensure(len(m) == INDCPA_MSGBYTES, "crypto/mlkem: invalid K-PKE m") + ensure(len(r) == SYMBYTES, "crypto/mlkem: invalid K-PKE r") + + k := ek.k + + at_: [K_MAX]Polyvec = --- + sp, ep, b: Polyvec = ---, ---, --- + kay, epp, v: Poly = ---, ---, --- + defer crypto.zero_explicit(&at_, size_of(Polyvec) * k) + defer polyvec_clear(&sp, &ep, &b) + defer poly_clear(&kay, &epp, &v) + + poly_frommsg(&kay, m) + + at := at_[:k] + + gen_matrix(at, ek.p[:], true, k) + + n := byte(0) + for i in 0.. bool { + if len(plaintext) != INDCPA_MSGBYTES { + return false + } + + k := dk.k + + b: Polyvec = --- + v, mp: Poly = ---, --- + defer poly_clear(&v, &mp) + + if unpack_ciphertext(&b, &v, c) != k { + return false + } + + polyvec_ntt(&b, k) + polyvec_basemul_acc_montgomery(&mp, &dk.pv, &b, k) + poly_invntt_tomont(&mp) + + poly_sub(&mp, &v, &mp) + poly_reduce(&mp) + + poly_tomsg(plaintext, &mp) + + return true +} diff --git a/core/crypto/_mlkem/kem_internal.odin b/core/crypto/_mlkem/kem_internal.odin new file mode 100644 index 000000000..b50aa4b58 --- /dev/null +++ b/core/crypto/_mlkem/kem_internal.odin @@ -0,0 +1,195 @@ +package _mlkem + +import "core:crypto" +import subtle "core:crypto/_subtle" + +// This implementation is derived from the PQ-CRYSTALS reference +// implementation [[ https://github.com/pq-crystals/kyber ]], +// primarily for licensing reasons. Arguably mlkem-native is +// a more "up to date" codebase, but the changes to the +// ref code is minor and they slapped an attribution-required +// license on something that was originally CC-0/Apache 2.0. + +// "Private Key" +Decapsulation_Key :: struct { + pke_dk: K_PKE_Decryption_Key, + ek: Encapsulation_Key, + seed: [SYMBYTES*2]byte, // (d, z) +} + +// "Public Key" +Encapsulation_Key :: struct { + pke_ek: K_PKE_Encryption_Key, + raw_bytes: [INDCPA_PUBLICKEYBYTES_MAX]byte, + h: [SYMBYTES]byte, +} + +decapsulation_key_expanded_bytes :: proc( + dk: ^Decapsulation_Key, + dst: []byte, +) { + sk := &dk.pke_dk + pv_len := polyvec_byte_size(sk.k) + ek_len := pv_len + SYMBYTES + + ek_bytes := dk.ek.raw_bytes[:ek_len] + + dst := dst + _ = pack_sk(dst[:pv_len], &sk.pv, sk.k) + + dst = dst[pv_len:] + copy(dst, ek_bytes) + dst = dst[ek_len:] + hash_h(dst[:SYMBYTES], ek_bytes) + dst = dst[SYMBYTES:] + copy(dst, dk.seed[SYMBYTES:]) +} + +@(require_results) +encapsulation_key_set_bytes :: proc( + ek: ^Encapsulation_Key, + k: int, + b: []byte, +) -> bool { + k_len: int + switch k { + case K_512: + k_len = ENCAPSKEYBYTES_512 + case K_768: + k_len = ENCAPSKEYBYTES_768 + case K_1024: + k_len = ENCAPSKEYBYTES_1024 + case: + return false + } + if len(b) != k_len { + return false + } + + pke_ek := &ek.pke_ek + ok := unpack_pk(&pke_ek.pv, pke_ek.p[:], b) + pke_ek.k = k + copy(ek.raw_bytes[:k_len], b) + hash_h(ek.h[:], b) + + // FIPS 203 unlike Kyber requires canonical encoding of + // encapsulation keys (Section 7,2), which is checked in + // unpack_pk. + + if !ok { + crypto.zero_explicit(ek, size_of(Encapsulation_Key)) + } + + return ok +} + +encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) { + dk_ek := &dk.ek.pke_ek + ensure(dk_ek.k == K_512 || dk_ek.k == K_768 || dk_ek.k == K_1024, "crypto/mlkem: invalid decaps k") + + k_pke_encryption_key_set(&ek.pke_ek, dk_ek) + copy(ek.raw_bytes[:], dk.ek.raw_bytes[:]) + copy(ek.h[:], dk.ek.h[:]) +} + +// NIST's version of this also returns an encapsulation key, but our +// internal representation includes it as part of the decapsulation key +// in a more traditional "keypair" approach. +kem_keygen_internal :: proc( + dk: ^Decapsulation_Key, + seed: []byte, // (d, z) + k: int, +) { + ensure(len(seed) == 2 * SYMBYTES, "crypto/mlkem: invalid seed") + + dk_ek := &dk.ek + d, z := seed[:SYMBYTES], seed[SYMBYTES:] + + k_pke_keygen(&dk_ek.pke_ek, &dk.pke_dk, d, k) + + ek_len := polyvec_byte_size(k) + SYMBYTES + ek_bytes := dk_ek.raw_bytes[:ek_len] + ensure( + pack_pk(ek_bytes, &dk_ek.pke_ek.pv, dk_ek.pke_ek.p[:], k), + "crypto/mlkem: failed to pack K-PKE ek", + ) + hash_h(dk_ek.h[:], ek_bytes) + copy(dk.seed[:SYMBYTES], d) + copy(dk.seed[SYMBYTES:], z) +} + +// The `_internal` "de-randomized" versions of ML-KEM.Encaps and +// ML-KEM.Decaps are only ever to be called by the actual non-interal +// implementation or test cases. + +kem_encaps_internal :: proc( + shared_secret: []byte, + ciphertext: []byte, + ek: ^Encapsulation_Key, + randomness: []byte, +) { + ensure(len(shared_secret) == SYMBYTES, "crypto/mlkem: invalid K") + ensure(len(randomness) == SYMBYTES, "crypto/mlkem: invalid m") + ensure( + len(ciphertext) == ct_len_for_k(ek.pke_ek.k), + "crypto/mlkem: invalid ciphertext length", + ) + + buf: [2*SYMBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + hash_g(buf[:], randomness, ek.h[:]) + + // Can't fail, ciphertext length is valid. + _ = k_pke_encrypt(ciphertext, &ek.pke_ek, randomness, buf[SYMBYTES:]) + + copy(shared_secret, buf[:SYMBYTES]) +} + +kem_decaps_internal :: proc( + shared_secret: []byte, + dk: ^Decapsulation_Key, + ciphertext: []byte, +) { + ct_len := ct_len_for_k(dk.pke_dk.k) + ensure( + len(ciphertext) == ct_len, + "crypto/mlkem: invalid ciphertext length", + ) + + m_: [SYMBYTES]byte + defer crypto.zero_explicit(&m_, size_of(m_)) + + // Can't fail, ciphertext length is valid. + _ = k_pke_decrypt(m_[:], &dk.pke_dk, ciphertext) + + buf: [2*SYMBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + ek := &dk.ek + hash_g(buf[:], m_[:], ek.h[:]) + + rkprf(shared_secret, dk.seed[SYMBYTES:], ciphertext) + + ct_buf: [CIPHERTEXTBYTES_MAX]byte = --- + defer crypto.zero_explicit(&ct_buf, size_of(ct_buf)) + ct_ := ct_buf[:ct_len] + _ = k_pke_encrypt(ct_, &ek.pke_ek, m_[:], buf[SYMBYTES:]) + + ok := crypto.compare_constant_time(ciphertext, ct_) + subtle.cmov_bytes(shared_secret, buf[:SYMBYTES], ok) +} + +@(private="file") +ct_len_for_k :: proc(k: int) -> int { + switch k { + case K_512: + return CIPHERTEXTBYTES_512 + case K_768: + return CIPHERTEXTBYTES_768 + case K_1024: + return CIPHERTEXTBYTES_1024 + case: + panic("crypto/mlkem: invalid k for ciphertext length") + } +} diff --git a/core/crypto/_mlkem/ntt.odin b/core/crypto/_mlkem/ntt.odin new file mode 100644 index 000000000..c99c73487 --- /dev/null +++ b/core/crypto/_mlkem/ntt.odin @@ -0,0 +1,75 @@ +#+private +package _mlkem + +@(rodata) +ZETAS := [128]i16 { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, + -171, 622, 1577, 182, 962, -1202, -1474, 1468, + 573, -1325, 264, 383, -829, 1458, -1602, -130, + -681, 1017, 732, 608, -1542, 411, -205, -1571, + 1223, 652, -552, 1015, -1293, 1491, -282, -1544, + 516, -8, -320, -666, -1618, -1162, 126, 1469, + -853, -90, -271, 830, 107, -1421, -247, -951, + -398, 961, -1508, -725, 448, -1065, 677, -1275, + -1103, 430, 555, 843, -1251, 871, 1550, 105, + 422, 587, 177, -235, -291, -460, 1574, 1653, + -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, + 817, 1097, 603, 610, 1322, -1285, -1465, 384, + -1215, -136, 1218, -1335, -874, 220, -1187, -1659, + -1185, -1530, -1278, 794, -1510, -854, -870, 478, + -108, -308, 996, 991, 958, -1460, 1522, 1628, +} + +@(require_results) +fqmul :: #force_inline proc "contextless" (a, b: i16) -> i16 { + return montgomery_reduce(i32(a) * i32(b)) +} + +ntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check { + j, k := 0, 1 + for l := 128; l >= 2; l >>= 1 { + for start := 0; start < N; start = j + l { + zeta := ZETAS[k] + k += 1 + for j = start; j < start + l; j += 1 { + t := fqmul(zeta, r[j+l]) + r[j+l] = r[j] - t + r[j] = r[j] + t + } + } + } +} + +invntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check { + F : i16 : 1441 // mont^2/128 + + j, k := 0, 127 + for l := 2; l <= 128; l <<= 1 { + for start := 0; start < 256; start = j+l { + zeta := ZETAS[k] + k -= 1 + for j = start; j < start + l; j += 1 { + t := r[j] + r[j] = barrett_reduce(t + r[j+l]) + r[j+l] = r[j+l] - t + r[j+l] = fqmul(zeta, r[j+l]) + } + } + } + + for v, i in r { + r[i] = fqmul(v, F) + } +} + +@(require_results) +base_case_multiply :: proc "contextless" (a_0, a_1, b_0, b_1, zeta: i16) -> (i16, i16) { + r_0 := fqmul(a_1, b_1) + r_0 = fqmul(r_0, zeta) + r_0 += fqmul(a_0, b_0) + r_1 := fqmul(a_0, b_1) + r_1 += fqmul(a_1, b_0) + + return r_0, r_1 +} diff --git a/core/crypto/_mlkem/poly.odin b/core/crypto/_mlkem/poly.odin new file mode 100644 index 000000000..4a79e9770 --- /dev/null +++ b/core/crypto/_mlkem/poly.odin @@ -0,0 +1,241 @@ +#+private +package _mlkem + +import "core:crypto" +import subtle "core:crypto/_subtle" + +// Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial +// coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] +Poly :: struct { + coeffs: [N]i16, +} + +poly_compress :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + r := r + switch len(r) { + case POLYCOMPRESSEDBYTES_768: // Also covers _512 + for i in 0..> 15) & Q + // t[j] = ((((uint16_t)u << 4) + Q/2)/Q) & 15 + d0 := u32(u) << 4 + d0 += 1665 + d0 *= 80635 + d0 >>= 28 + t[j] = byte(d0) & 0xf + } + + r[0] = t[0] | (t[1] << 4) + r[1] = t[2] | (t[3] << 4) + r[2] = t[4] | (t[5] << 4) + r[3] = t[6] | (t[7] << 4) + r = r[4:] + } + case POLYCOMPRESSEDBYTES_1024: + for i in 0..> 15) & Q + // t[j] = ((((uint16_t)u << 5) + Q/2)/Q) & 31 + d0 := u32(u) << 5 + d0 += 1664 + d0 *= 40318 + d0 >>= 27 + t[j] = byte(d0) & 0x1f + } + + r[0] = (t[0] >> 0) | (t[1] << 5) + r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7) + r[2] = (t[3] >> 1) | (t[4] << 4) + r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6) + r[4] = (t[6] >> 2) | (t[7] << 3) + r = r[5:] + } + case: + panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES") + } +} + +poly_decompress :: proc "contextless" (r: ^Poly, a: []byte) { + a := a + switch len(a) { + case POLYCOMPRESSEDBYTES_768: // Also covers _512 + for i in 0..> 4) + r.coeffs[2*i+1] = i16(((u16(a[0] >> 4) * Q) + 8) >> 4) + a = a[1:] + } + case POLYCOMPRESSEDBYTES_1024: + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) + t[1] = (a[0] >> 5) | (a[1] << 3) + t[2] = (a[1] >> 2) + t[3] = (a[1] >> 7) | (a[2] << 1) + t[4] = (a[2] >> 4) | (a[3] << 4) + t[5] = (a[3] >> 1) + t[6] = (a[3] >> 6) | (a[4] << 2) + t[7] = (a[4] >> 3) + a = a[5:] + + for j in 0..<8 { + r.coeffs[8*i+j] = i16((u32(t[j] & 31) * Q + 16) >> 5) + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES") + } +} + +poly_tobytes :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + ensure_contextless(len(r) >= POLYBYTES) + + for i in 0..> 15) & Q) + t1 := u16(a.coeffs[2*i+1]) + t1 += u16((i16(t1) >> 15) & Q) + r[3*i+0] = byte(t0 >> 0) + r[3*i+1] = byte(t0 >> 8) | byte(t1 << 4) + r[3*i+2] = byte(t1 >> 4) + } +} + +@(require_results) +poly_frombytes :: proc "contextless" (r: ^Poly, a: []byte) -> bool #no_bounds_check { + ensure_contextless(len(a) >= POLYBYTES) + + ok := true + for i in 0..> 0) | (u16(a[3*i+1]) << 8)) & 0xFFF) + r.coeffs[2*i+1] = i16(((u16(a[3*i+1]) >> 4) | (u16(a[3*i+2]) << 4)) & 0xFFF) + ok &= r.coeffs[2*i] < Q && r.coeffs[2*i+1] < Q + } + + return ok +} + +poly_frommsg :: proc "contextless" (r: ^Poly, msg: []byte) #no_bounds_check { + #assert(INDCPA_MSGBYTES == N/8) + ensure_contextless(len(msg) == INDCPA_MSGBYTES) + + for i in 0..> uint(j))&1) + } + } +} + +poly_tomsg :: proc "contextless" (msg: []byte, a: ^Poly) #no_bounds_check { + ensure_contextless(len(msg) == INDCPA_MSGBYTES) + + for i in 0..> 15) & Q + // t = (((t << 1) + Q/2)/Q) & 1 + t <<= 1 + t += 1665 + t *= 80635 + t >>= 28 + t &= 1 + msg[i] |= byte(t << j) + } + } +} + +poly_getnoise_eta1_512 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA1_512*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta1_512(r, &buf) +} + +poly_getnoise_eta1 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA1*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta1(r, &buf) +} + +poly_getnoise_eta2 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA2*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta2(r, &buf) +} + +poly_ntt :: proc "contextless" (r: ^Poly) { + ntt(&r.coeffs) + poly_reduce(r) +} + +poly_invntt_tomont :: proc "contextless" (r: ^Poly) { + invntt(&r.coeffs) +} + +poly_basemul_montgomery :: proc "contextless" (r, a, b: ^Poly) #no_bounds_check { + for i in 0.. int { + switch k { + case K_512: + return POLYCOMPRESSEDBYTES_512 + case K_768: + return POLYCOMPRESSEDBYTES_768 + case K_1024: + return POLYCOMPRESSEDBYTES_1024 + case: + panic_contextless("crypto/mlkem: invalid k") + } +} diff --git a/core/crypto/_mlkem/polyvec.odin b/core/crypto/_mlkem/polyvec.odin new file mode 100644 index 000000000..387acade8 --- /dev/null +++ b/core/crypto/_mlkem/polyvec.odin @@ -0,0 +1,224 @@ +#+private +package _mlkem + +import "core:crypto" + +Polyvec :: struct { + vec: [K_MAX]Poly, +} + +polyvec_compress :: proc "contextless" (r: []byte, a: ^Polyvec, kay: int) #no_bounds_check { + d0: u64 + + r := r + switch len(r) { + case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768: + ensure_contextless(kay == K_512 || kay == K_768) + + t: [4]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 15) & Q) + // t[k] = ((((uint32_t)t[k] << 10) + Q/2)/Q) & 0x3ff + d0 = u64(t[k]) + d0 <<= 10 + d0 += 1665 + d0 *= 1290167 + d0 >>= 32 + t[k] = u16(d0 & 0x3ff) + } + + r[0] = byte(t[0] >> 0) + r[1] = byte((t[0] >> 8) | (t[1] << 2)) + r[2] = byte((t[1] >> 6) | (t[2] << 4)) + r[3] = byte((t[2] >> 4) | (t[3] << 6)) + r[4] = byte(t[3] >> 2) + r = r[5:] + } + } + case POLYVECCOMPRESSEDBYTES_1024: + ensure_contextless(kay == K_1024) + + t: [8]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 15) & Q) + // t[k] = ((((uint32_t)t[k] << 11) + Q/2)/Q) & 0x7ff + d0 = u64(t[k]) + d0 <<= 11 + d0 += 1664 + d0 *= 645084 + d0 >>= 31 + t[k] = u16(d0 & 0x7ff) + } + + r[0] = byte(t[0] >> 0) + r[1] = byte((t[0] >> 8) | (t[1] << 3)) + r[2] = byte((t[1] >> 5) | (t[2] << 6)) + r[3] = byte(t[2] >> 2) + r[4] = byte((t[2] >> 10) | (t[3] << 1)) + r[5] = byte((t[3] >> 7) | (t[4] << 4)) + r[6] = byte((t[4] >> 4) | (t[5] << 7)) + r[7] = byte(t[5] >> 1) + r[8] = byte((t[5] >> 9) | (t[6] << 2)) + r[9] = byte((t[6] >> 6) | (t[7] << 5)) + r[10] = byte(t[7] >> 3) + r = r[11:] + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES") + } +} + +polyvec_decompress :: proc "contextless" (r: ^Polyvec, a: []byte, kay: int) #no_bounds_check { + a := a + switch len(a) { + case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768: + ensure_contextless(kay == K_512 || kay == K_768) + + t: [4]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) | (u16(a[1]) << 8) + t[1] = u16(a[1] >> 2) | (u16(a[2]) << 6) + t[2] = u16(a[2] >> 4) | (u16(a[3]) << 4) + t[3] = u16(a[3] >> 6) | (u16(a[4]) << 2) + a = a[5:] + + for k in 0..<4 { + r.vec[i].coeffs[4*j+k] = i16((u32(t[k] & 0x3FF) * Q + 512) >> 10) + } + } + } + case POLYVECCOMPRESSEDBYTES_1024: + t: [8]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) | (u16(a[1]) << 8) + t[1] = u16(a[1] >> 3) | (u16(a[2]) << 5) + t[2] = u16(a[2] >> 6) | (u16(a[3]) << 2) | (u16(a[4]) << 10) + t[3] = u16(a[4] >> 1) | (u16(a[5]) << 7) + t[4] = u16(a[5] >> 4) | (u16(a[6]) << 4) + t[5] = u16(a[6] >> 7) | (u16(a[7]) << 1) | (u16(a[8]) << 9) + t[6] = u16(a[8] >> 2) | (u16(a[9]) << 6) + t[7] = u16(a[9] >> 5) | (u16(a[10]) << 3) + a = a[11:] + + for k in 0..<8 { + r.vec[i].coeffs[8*j+k] = i16((u32(t[k] & 0x7FF) * Q + 1024) >> 11) + } + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES") + } +} + +polyvec_tobytes :: proc "contextless" (r: []byte, a: ^Polyvec, k: int) #no_bounds_check { + ensure_contextless(len(r) == k * POLYBYTES, "crypto/mlkem: invalid buffer") + + r := r + for i in 0.. bool #no_bounds_check { + switch k { + case K_512, K_768, K_1024: + case: + panic_contextless("crypto/mlkem: invalid POLYVECBYTES") + } + ensure_contextless(len(a) == k * POLYBYTES, "crypto/mlkem: invalid buffer") + + a := a + ok := true + for i in 0.. int { + switch k { + case K_512, K_768, K_1024: + return k * POLYBYTES + case: + return 0 + } +} + +@(require_results) +polyvec_compressed_byte_size :: #force_inline proc "contextless" (k: int) -> int { + switch k { + case K_512: + return POLYVECCOMPRESSEDBYTES_512 + case K_768: + return POLYVECCOMPRESSEDBYTES_768 + case K_1024: + return POLYVECCOMPRESSEDBYTES_1024 + case: + return 0 + } +} + +polyvec_ntt :: proc "contextless" (r: ^Polyvec, k: int) { + for i in 0.. i16 { + QINV :: -3327 // q^-1 mod 2^16 + + t := i16(a) * QINV + return i16((a - i32(t) * Q) >> 16) +} + +@(require_results) +barrett_reduce :: #force_inline proc "contextless" (a: i16) -> i16 { + V : i16 : ((1<<26) + Q / 2) / Q + + t := i16((i32(V)*i32(a) + (1<<25)) >> 26) + t *= Q + return a - t +} diff --git a/core/crypto/_mlkem/symmetric_shake.odin b/core/crypto/_mlkem/symmetric_shake.odin new file mode 100644 index 000000000..bf851f6f6 --- /dev/null +++ b/core/crypto/_mlkem/symmetric_shake.odin @@ -0,0 +1,61 @@ +#+private +package _mlkem + +import "core:crypto" +import "core:crypto/_sha3" +import "core:crypto/sha3" +import "core:crypto/shake" + +XOF_BLOCKBYTES :: _sha3.RATE_128 +#assert(XOF_BLOCKBYTES % 3 == 0) + +prf :: proc(out, key: []byte, iv: byte) { + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + shake.write(&ctx, key) + shake.write(&ctx, []byte{iv}) + shake.read(&ctx, out) +} + +rkprf :: proc(out, key, input: []byte) { + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + shake.write(&ctx, key) + shake.write(&ctx, input) + shake.read(&ctx, out) +} + +xof_absorb :: proc(ctx: ^shake.Context, seed: []byte, x, y: byte) { + shake.init_128(ctx) + + extseed: [SYMBYTES+2]byte = --- + defer crypto.zero_explicit(&extseed, size_of(extseed)) + + copy(extseed[:], seed) + extseed[SYMBYTES+0] = x + extseed[SYMBYTES+1] = y + + shake.write(ctx, extseed[:]) +} + +hash_h :: proc(dst, src: []byte) { + ctx: sha3.Context = --- + + sha3.init_256(&ctx) + sha3.update(&ctx, src) + sha3.final(&ctx, dst) +} + +hash_g :: proc(dst: []byte, srcs: ..[]byte) { + ctx: sha3.Context = --- + + sha3.init_512(&ctx) + for src in srcs { + sha3.update(&ctx, src) + } + sha3.final(&ctx, dst) +} diff --git a/core/crypto/_subtle/subtle.odin b/core/crypto/_subtle/subtle.odin index 454066e4a..01c84cf2a 100644 --- a/core/crypto/_subtle/subtle.odin +++ b/core/crypto/_subtle/subtle.odin @@ -3,6 +3,7 @@ Various useful bit operations in constant time. */ package _subtle +import "core:crypto/_fiat" import "core:math/bits" // byte_eq returns 1 if and only if (⟺) a == b, 0 otherwise. @@ -40,3 +41,41 @@ u64_is_non_zero :: proc "contextless" (a: u64) -> u64 { is_zero := u64_is_zero(a) return (~is_zero) & 1 } + +@(optimization_mode="none") +cmov_bytes :: proc "contextless" (dst, src: []byte, ctrl: int) { + s_len := len(src) + ensure_contextless(s_len == len(dst), "crypto: cmov length mismatch") + + c := -(byte)(ctrl) + for i in 0.. i16 { + c := -(u16)(ctrl) + return a ~ i16(c & u16(a ~ b)) +} + +@(optimization_mode="none") +csel_u16 :: proc "contextless" (a, b: u16, ctrl: int) -> u16 { + c := -(u16)(ctrl) + return a ~ (c & (a ~ b)) +} + +csel_u32 :: proc "contextless" (a, b: u32, ctrl: int) -> u32 { + return _fiat.cmovznz_u32(_fiat.u1(ctrl), a, b) +} + +csel_u64 :: proc "contextless" (a, b: u64, ctrl: int) -> u64 { + return _fiat.cmovznz_u64(_fiat.u1(ctrl), a, b) +} + +csel :: proc { + csel_i16, + csel_u16, + csel_u32, + csel_u64, +} diff --git a/core/crypto/mlkem/api.odin b/core/crypto/mlkem/api.odin new file mode 100644 index 000000000..9c8fd6a2b --- /dev/null +++ b/core/crypto/mlkem/api.odin @@ -0,0 +1,268 @@ +package mlkem + +import "core:crypto" +import "core:crypto/_mlkem" + +// Parameters are the supported ML-KEM parameter sets. +Parameters :: enum { + Invalid, + ML_KEM_512, + ML_KEM_768, + ML_KEM_1024, +} + +// DECAPSULATION_KEY_SEED_SIZE is the size of a Decapsulation key in bytes. +DECAPSULATION_KEY_SEED_SIZE :: 64 // (d, z) in NIST terms. + +// DECAPSULATION_KEY_EXPANDED_SIZES are the per-parameter sizes of the +// decapsulation key in bytes. +DECAPSULATION_KEY_EXPANDED_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.DECAPSKEYBYTES_512, // 1632-bytes + .ML_KEM_768 = _mlkem.DECAPSKEYBYTES_768, // 2400-bytes + .ML_KEM_1024 = _mlkem.DECAPSKEYBYTES_1024, // 3168-bytes +} + +// ENCAPSULATION_KEY_SIZES are the per-parameter sizes of the encapsulation +// key in bytes. +ENCAPSULATION_KEY_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.ENCAPSKEYBYTES_512, // 800-bytes + .ML_KEM_768 = _mlkem.ENCAPSKEYBYTES_768, // 1184-bytes + .ML_KEM_1024 = _mlkem.ENCAPSKEYBYTES_1024, // 1568-bytes +} + +// CIPHERTEXT_SIZES are the per-parameter set sizes of the ciphertext +// in bytes. +CIPHERTEXT_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.CIPHERTEXTBYTES_512, // 768-bytes + .ML_KEM_768 = _mlkem.CIPHERTEXTBYTES_768, // 1088-bytes + .ML_KEM_1024 = _mlkem.CIPHERTEXTBYTES_1024, // 1568-bytes +} + +// SHARED_SECRET_SIZE is the size of the final shared secret in bytes. +SHARED_SECRET_SIZE :: 32 + +// Decapsulation_Key is a ML-KEM decapsulation (aka "private") key. +// This implementation opts to include the encapsulation (aka "public") +// key as well for cases where the decapsulation key is reused (eg: HPKE +// with X-Wing). +Decapsulation_Key :: _mlkem.Decapsulation_Key + +// Encapsulation_Key is a ML-KEM encapsulation (aka "public") key. +Encapsulation_Key :: _mlkem.Encapsulation_Key + +// decapsulation_key_generate uses the system entropy source to generate +// a decapsulation key. This will only fail if and only if (⟺) the system +// entropy source is missing or broken. +@(require_results) +decapsulation_key_generate :: proc(dk: ^Decapsulation_Key, params: Parameters) -> bool { + decapsulation_key_clear(dk) + + if !crypto.HAS_RAND_BYTES { + return false + } + + k := params_to_k(params) + if k == 0 { + panic("crypto/mlkem: invalid parameter set") + } + + seed: [DECAPSULATION_KEY_SEED_SIZE]byte = --- + defer crypto.zero_explicit(&seed, size_of(seed)) + + crypto.rand_bytes(seed[:]) + _mlkem.kem_keygen_internal(dk, seed[:], k) + + return true +} + +// decapsulation_key_set_bytes decodes a byte-encoded decapsulation key +// in (d, z) "seed" format, and returns true if and only if (⟺) the +// operation was successful. +@(require_results) +decapsulation_key_set_bytes :: proc(dk: ^Decapsulation_Key, params: Parameters, seed: []byte) -> bool { + k := params_to_k(params) + if k == 0 { + return false + } + if len(seed) != DECAPSULATION_KEY_SEED_SIZE { + return false + } + + _mlkem.kem_keygen_internal(dk, seed, k) + + return true +} + +// decapsulation_key_bytes sets dst to byte-encoding of dk in the (d, z) +// "seed" format. +decapsulation_key_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key") + ensure(len(dst) == DECAPSULATION_KEY_SEED_SIZE, "crypto/mlkem: invalid destination size") + + copy(dst, dk.seed[:]) +} + +// decapsulation_key_expanded_bytes sets dst to the byte-encoding of dk. +// in the expanded FIPS 203 format. This primarily exists for export +// purposes. +decapsulation_key_expanded_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + dk_len: int + switch dk.pke_dk.k { + case _mlkem.K_512: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_512] + case _mlkem.K_768: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_1024] + case: + panic("crypto/mlkem: uninitialized Decapsulation_Key") + } + ensure(len(dst) == dk_len, "crypto/mlkem: invalid destination size") + + _mlkem.decapsulation_key_expanded_bytes(dk, dst) +} + +// decapsulation_key_encaps_bytes sets dst to the byte-encoding of the +// encasulation key corresponding to dk. +decapsulation_key_encaps_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + encapsulation_key_bytes(&dk.ek, dst) +} + +// decapsulation_key_clear clears dk to the uninitialized state. +decapsulation_key_clear :: proc(dk: ^Decapsulation_Key) { + crypto.zero_explicit(dk, size_of(Decapsulation_Key)) +} + +// encapsulation_key_set_bytes decodes a byte-encoded encapsulation key, +// and returns true if and only if (⟺) the operation was successful. +@(require_results) +encapsulation_key_set_bytes :: proc(ek: ^Encapsulation_Key, params: Parameters, b: []byte) -> bool { + k := params_to_k(params) + if k == 0 { + return false + } + if len(b) != ENCAPSULATION_KEY_SIZES[params] { + return false + } + + return _mlkem.encapsulation_key_set_bytes(ek, k, b) +} + +// encapsulation_key_set_decaps sets ek to the encapsulation key corresponding +// to dk. +encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) { + ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key") + _mlkem.encapsulation_key_set_decaps(ek, dk) +} + +// encapsulation_key_encaps_bytes sets dst to the byte-encoding of ek. +encapsulation_key_bytes :: proc(ek: ^Encapsulation_Key, dst: []byte) { + ensure(ek.pke_ek.k != 0, "crypto/mlkem: uninitialized Encapsulation_Key") + + k_len: int + switch ek.pke_ek.k { + case _mlkem.K_512: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_512] + case _mlkem.K_768: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_1024] + case: + panic("crypto/mlkem: invalid destination size") + } + + copy(dst, ek.raw_bytes[:k_len]) +} + +// encapsulation_key_clear clears ek to the uninitialized state. +encapsulation_key_clear :: proc(ek: ^Encapsulation_Key) { + crypto.zero_explicit(ek, size_of(Encapsulation_Key)) +} + +// encaps_raw_ek_bytes uses the byte encoded encapsulation key to generate +// a shared secret and an associated ciphertext. This routine will fail +// if the system entropy source is unavailable, or of the encapsulation key +// is invalid. +@(require_results) +encaps_ek_raw_bytes :: proc(params: Parameters, raw_ek, shared_secret, ciphertext: []byte) -> bool { + ek: Encapsulation_Key = --- + if !encapsulation_key_set_bytes(&ek, params, raw_ek) { + return false + } + defer encapsulation_key_clear(&ek) + + return encaps_ek(&ek, shared_secret, ciphertext) +} + +// encaps_ek uses the encapsulation key to generate a shared secret and an +// associated ciphertext. This routine will fail if the system entropy source +// is unavailable. +@(require_results) +encaps_ek :: proc(ek: ^Encapsulation_Key, shared_secret, ciphertext: []byte) -> bool { + ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size") + + if !crypto.HAS_RAND_BYTES { + return false + } + + m: [_mlkem.SYMBYTES]byte = --- + defer crypto.zero_explicit(&m, size_of(m)) + + crypto.rand_bytes(m[:]) + _mlkem.kem_encaps_internal(shared_secret, ciphertext, ek, m[:]) + + return true +} + +encaps :: proc { + encaps_ek, + encaps_ek_raw_bytes, +} + +// decaps uses the decapsulation key to generate a shared secret from a +// ciphertext. Due to ML-KEM's implicit rejection mechanism, this function +// will only return false if and only if (⟺) the lengths of the inputs +// are invalid or the decapsulation key is uninitialized. +// +// This routine returning true does not guarantee that the shared secret +// matches that generated by the peer. +@(require_results) +decaps :: proc(dk: ^Decapsulation_Key, ciphertext, shared_secret: []byte) -> bool { + ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size") + + ct_len: int + switch dk.pke_dk.k { + case _mlkem.K_512: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_512] + case _mlkem.K_768: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_1024] + case: + return false + } + if len(ciphertext) != ct_len { + return false + } + + _mlkem.kem_decaps_internal(shared_secret, dk, ciphertext) + + return true +} + +@(private="file") +params_to_k :: #force_inline proc "contextless" (params: Parameters) -> int { + #partial switch params { + case .ML_KEM_512: + return _mlkem.K_512 + case .ML_KEM_768: + return _mlkem.K_768 + case .ML_KEM_1024: + return _mlkem.K_1024 + } + + return 0 +} diff --git a/core/crypto/mlkem/doc.odin b/core/crypto/mlkem/doc.odin new file mode 100644 index 000000000..ec32def25 --- /dev/null +++ b/core/crypto/mlkem/doc.odin @@ -0,0 +1,7 @@ +/* +ML-KEM Module-Lattice-Based Key-Encapsulation Mechanism. + +See: +- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf ]] +*/ +package mlkem diff --git a/examples/all/all_js.odin b/examples/all/all_js.odin index 74cdb5fd3..f557cecdb 100644 --- a/examples/all/all_js.odin +++ b/examples/all/all_js.odin @@ -43,6 +43,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" @(require) import "core:crypto/poly1305" diff --git a/examples/all/all_main.odin b/examples/all/all_main.odin index 973ee423e..5565d225d 100644 --- a/examples/all/all_main.odin +++ b/examples/all/all_main.odin @@ -48,6 +48,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" @(require) import "core:crypto/poly1305" diff --git a/tests/benchmark/crypto/benchmark_pqc.odin b/tests/benchmark/crypto/benchmark_pqc.odin new file mode 100644 index 000000000..314594ca5 --- /dev/null +++ b/tests/benchmark/crypto/benchmark_pqc.odin @@ -0,0 +1,84 @@ +package benchmark_core_crypto + +import "core:log" +import "core:testing" +import "core:text/table" +import "core:time" + +import "core:crypto" +import "core:crypto/mlkem" + +@(private = "file") +MLKEM_ITERS :: 50000 + +@(test) +benchmark_crypto_mlkem :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.warnf("ML-KEM benchmarks skipped, no system entropy source") + } + + tbl: table.Table + table.init(&tbl) + defer table.destroy(&tbl) + + table.caption(&tbl, "ML-KEM") + table.aligned_header_of_values(&tbl, .Right, "Parameters", "Keygen", "Encaps", "Decaps") + + append_tbl := proc(tbl: ^table.Table, algo_name: string, keygen, encaps, decaps: time.Duration) { + table.aligned_row_of_values( + tbl, + .Right, + algo_name, + table.format(tbl, "%8M", keygen), + table.format(tbl, "%8M", encaps), + table.format(tbl, "%8M", decaps), + ) + } + + for params in mlkem.Parameters { + if params == .Invalid { + continue + } + param_name := MLKEM_PARAMS_NAMES[params] + + decaps_key: mlkem.Decapsulation_Key + start := time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.decapsulation_key_generate(&decaps_key, params) + } + keygen := time.tick_since(start) / MLKEM_ITERS + + encaps_key := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + defer delete(encaps_key) + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + defer delete(ciphertext) + + mlkem.decapsulation_key_encaps_bytes(&decaps_key, encaps_key) + + bob_shared: [mlkem.SHARED_SECRET_SIZE]byte + start = time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.encaps(params, encaps_key, bob_shared[:], ciphertext) + } + encaps := time.tick_since(start) / MLKEM_ITERS + + alice_shared: [mlkem.SHARED_SECRET_SIZE]byte + start = time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.decaps(&decaps_key, ciphertext, alice_shared[:]) + } + decaps := time.tick_since(start) / MLKEM_ITERS + + append_tbl(&tbl, param_name, keygen, encaps, decaps) + } + + log_table(&tbl) +} + +@(private="file") +MLKEM_PARAMS_NAMES := [mlkem.Parameters]string { + .Invalid = "invalid", + .ML_KEM_512 = "ML-KEM-512", + .ML_KEM_768 = "ML-KEM-768", + .ML_KEM_1024 = "ML-KEM-1024", +} diff --git a/tests/core/crypto/test_core_crypto_pqc.odin b/tests/core/crypto/test_core_crypto_pqc.odin new file mode 100644 index 000000000..8359053e4 --- /dev/null +++ b/tests/core/crypto/test_core_crypto_pqc.odin @@ -0,0 +1,72 @@ +package test_core_crypto + +import "core:bytes" +import "core:log" +import "core:testing" + +import "core:crypto" +import "core:crypto/mlkem" + +@(test) +test_mlkem :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.info("rand_bytes not supported - skipping") + return + } + + // Test vectors are huge, and are covered by the wycheproof corpus, + // so just test a full key exchange with all supported parameter + // sets. + for params in mlkem.Parameters { + if params == .Invalid { + continue + } + + // Alice + decaps_key: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_generate(&decaps_key, params), + "%v: decapsulation_key_generate", + params, + ) { + continue + } + defer mlkem.decapsulation_key_clear(&decaps_key) + + ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + defer delete(ek_bytes) + mlkem.decapsulation_key_encaps_bytes(&decaps_key, ek_bytes) + + // Bob + bob_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + defer delete(ciphertext) + if !testing.expectf( + t, + mlkem.encaps(params, ek_bytes, bob_shared_secret[:], ciphertext), + "%v: encaps", + params, + ) { + continue + } + + // Alice + alice_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + if !testing.expectf( + t, + mlkem.decaps(&decaps_key, ciphertext, alice_shared_secret[:]), + "%v: decaps", + params, + ) { + continue + } + + testing.expectf( + t, + bytes.equal(alice_shared_secret[:], bob_shared_secret[:]), + "%v: shared secret mismatch", + params, + ) + } +} diff --git a/tests/core/crypto/wycheproof/main.odin b/tests/core/crypto/wycheproof/main.odin index dfdc78267..654ac2a38 100644 --- a/tests/core/crypto/wycheproof/main.odin +++ b/tests/core/crypto/wycheproof/main.odin @@ -35,6 +35,16 @@ import "core:testing" // - crypto/kmac // - kmac128_no_customization_test.json // - kmac256_no_customization_test.json +// - crypto/mlkem +// - mlkem_512_keygen_seed_test.json +// - mlkem_512_encaps_test.json +// - mlkem_512_test.json +// - mlkem_768_keygen_seed_test.json +// - mlkem_768_encaps_test.json +// - mlkem_768_test.json +// - mlkem_1024_keygen_seed_test.json +// - mlkem_1024_encaps_test.json +// - mlkem_1024_test.json // - crypto/pbkdf2 // - pbkdf2_hmacsha1_test.json // - pbkdf2_hmacsha224_test.json diff --git a/tests/core/crypto/wycheproof/pqc.odin b/tests/core/crypto/wycheproof/pqc.odin new file mode 100644 index 000000000..210d7be1d --- /dev/null +++ b/tests/core/crypto/wycheproof/pqc.odin @@ -0,0 +1,401 @@ +package test_wycheproof + +import "core:encoding/hex" +import "core:log" +import "core:mem" +import "core:os" +import "core:testing" + +import "core:crypto/_mlkem" +import "core:crypto/mlkem" + +import "../common" + +@(test) +test_mlkem :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("mlkem: starting") + + files_keygen := []string { + "mlkem_512_keygen_seed_test.json", + "mlkem_768_keygen_seed_test.json", + "mlkem_1024_keygen_seed_test.json", + } + for f in files_keygen { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_keygen(t, &test_vectors), "ML-KEM KeyGen failed") + } + + files_encaps := []string { + "mlkem_512_encaps_test.json", + "mlkem_768_encaps_test.json", + "mlkem_1024_encaps_test.json", + } + for f in files_encaps { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_encaps(t, &test_vectors), "ML-KEM Encaps failed") + } + + files_decaps := []string { + "mlkem_512_test.json", + "mlkem_768_test.json", + "mlkem_1024_test.json", + } + for f in files_decaps { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_decaps(t, &test_vectors), "ML-KEM Decaps failed") + } +} + +test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: KeyGen starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + seed := common.hexbytes_decode(test_vector.seed) + + dk: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_set_bytes(&dk, params, seed), + "%s/KeyGen/%d/%d: failed to set decapsulation key from seed", + params_str, + tg_id, + test_vector.tc_id, + test_vector.seed, + ) { + num_failed *= 1 + continue + } + + ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + mlkem.decapsulation_key_encaps_bytes(&dk, ek_bytes) + + ok := common.hexbytes_compare(test_vector.ek, ek_bytes) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(ek_bytes)) + log.errorf( + "%s/KeyGen/%d/%d: ek: expected %s actual %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.ek, + x, + ) + num_failed += 1 + continue + } + + dk_bytes := make([]byte, mlkem.DECAPSULATION_KEY_EXPANDED_SIZES[params]) + mlkem.decapsulation_key_expanded_bytes(&dk, dk_bytes) + + ok = common.hexbytes_compare(test_vector.dk, dk_bytes) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(dk_bytes)) + log.errorf( + "%s/KeyGen/%d/%d: dk: expected %s actual %s", + tg_id, + params_str, + test_vector.tc_id, + test_vector.dk, + x, + ) + num_failed += 1 + continue + } + + seed_bytes: [mlkem.DECAPSULATION_KEY_SEED_SIZE]byte + mlkem.decapsulation_key_bytes(&dk, seed_bytes[:]) + + ok = common.hexbytes_compare(test_vector.seed, seed_bytes[:]) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(seed_bytes[:])) + log.errorf( + "%s/KeyGen/%d/%d: seed: expected %s actual %s", + tg_id, + params_str, + test_vector.tc_id, + test_vector.seed, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/KeyGen: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Encaps starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + ek: mlkem.Encapsulation_Key + ok := mlkem.encapsulation_key_set_bytes( + &ek, + params, + common.hexbytes_decode(test_vector.ek), + ) + + // The current corpus can only fail if the encapsulation key + // is malformed in some way. + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Encaps/%d/%d: unexpected set encapsulation key from bytes: %s (%v != %v)", + params_str, + tg_id, + test_vector.tc_id, + test_vector.ek, + test_vector.result, + ok, + ) + num_failed += 1 + continue + } + if !ok { + num_passed += 1 + continue + } + + shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + + _mlkem.kem_encaps_internal( + shared_secret[:], + ciphertext, + &ek, + common.hexbytes_decode(test_vector.m), + ) + + ok = common.hexbytes_compare(test_vector.c, ciphertext) + if !ok { + x := transmute(string)(hex.encode(ciphertext)) + log.errorf( + "%s/Encaps/%d/%d: ciphertext: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.c, + x, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.k, shared_secret[:]) + if !ok { + x := transmute(string)(hex.encode(shared_secret[:])) + log.errorf( + "%s/Encaps/%d/%d: shared_secret: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.k, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Encaps: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Decaps starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + // We do not have an API for decaps with raw seed. + seed := common.hexbytes_decode(test_vector.seed) + switch len(seed) { + case mlkem.DECAPSULATION_KEY_SEED_SIZE: + case: + if testing.expectf( + t, + result_is_invalid(test_vector.result), + "%s/Decaps/%d/%d: test vector expects success with invalid seed", + params_str, + tg_id, + test_vector.tc_id, + ) { + num_passed += 1 + } else { + num_failed += 1 + } + continue + } + + dk: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_set_bytes(&dk, params, seed), + "%s/Decaps/%d/%d: failed to set decapsulation key from seed", + params_str, + tg_id, + test_vector.tc_id, + test_vector.seed, + ) { + num_failed *= 1 + continue + } + + shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + + ok := mlkem.decaps( + &dk, + common.hexbytes_decode(test_vector.c), + shared_secret[:], + ) + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Decaps/%d/%d: unexpected decapsulation failure", + params_str, + tg_id, + test_vector.tc_id, + ) + num_failed += 1 + continue + } + if !ok { + num_passed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.k, shared_secret[:]) + if !ok { + x := transmute(string)(hex.encode(shared_secret[:])) + log.errorf( + "%s/Decaps/%d/%d: shared_secret: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.k, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Decaps: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +@(require_results, private="file") +parameter_set_to_params :: proc(s: string) -> mlkem.Parameters { + switch s { + case "ML-KEM-512": + return .ML_KEM_512 + case "ML-KEM-768": + return .ML_KEM_768 + case "ML-KEM-1024": + return .ML_KEM_1024 + case: + return .Invalid + } +} diff --git a/tests/core/crypto/wycheproof/schemas.odin b/tests/core/crypto/wycheproof/schemas.odin index 645f0f085..d801207a0 100644 --- a/tests/core/crypto/wycheproof/schemas.odin +++ b/tests/core/crypto/wycheproof/schemas.odin @@ -64,6 +64,11 @@ Test_Vectors_Note :: struct { links: []string `json:"links"`, } +Test_Group_Source :: struct { + name: string `json:"name"`, + version: string `json:"version"`, +} + Aead_Test_Group :: struct { iv_size: int `json:"ivSize"`, key_size: int `json:"keySize"`, @@ -198,3 +203,23 @@ Pbkdf_Test_Vector :: struct { result: Result `json:"result"`, flags: []string `json:"flags"`, } + +Kem_Test_Group :: struct { + type: string `json:"type"`, + source: Test_Group_Source `json:"source"`, + parameter_set: string `json:"parameterSet"`, + tests: []Kem_Test_Vector `json:"tests"`, +} + +Kem_Test_Vector :: struct { + tc_id: int `json:"tcId"`, + flags: []string `json:"flags"`, + comment: string `json:"comment"`, + seed: common.Hex_Bytes `json:"seed"`, + m: common.Hex_Bytes `json:"m"`, + ek: common.Hex_Bytes `json:"ek"`, + dk: common.Hex_Bytes `json:"dk"`, + c: common.Hex_Bytes `json:"c"`, + k: common.Hex_Bytes `json:"K"`, + result: Result `json:"result"`, +}