diff --git a/core/crypto/_mlkem/cbd.odin b/core/crypto/_mlkem/cbd.odin new file mode 100644 index 000000000..301551ff5 --- /dev/null +++ b/core/crypto/_mlkem/cbd.odin @@ -0,0 +1,52 @@ +#+private +package _mlkem + +import "core:encoding/endian" + +unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check { + r := u32(b[0]) + r |= u32(b[1]) << 8 + r |= u32(b[2]) << 16 + return r +} + +cbd3 :: proc "contextless" (r: ^Poly, buf: ^[3*N/4]byte) #no_bounds_check { + for i in 0..>1) & 0x00249249 + d += (t>>2) & 0x00249249 + + for j in uint(0)..<4 { + a := i16((d >> (6*j+0)) & 0x7) + b := i16((d >> (6*j+3)) & 0x7) + r.coeffs[4*i+int(j)] = a - b + } + } +} + +cbd2 :: proc "contextless" (r: ^Poly, buf: ^[2*N/4]byte) #no_bounds_check { + for i in 0..>1) & 0x55555555 + + for j in uint(0)..<8 { + a := i16((d >> (4*j+0)) & 0x3) + b := i16((d >> (4*j+2)) & 0x3) + r.coeffs[8*i+int(j)] = a - b + } + } +} + +poly_cbd_eta1_512 :: proc "contextless" (r: ^Poly, buf: ^[ETA1_512*N/4]byte) { + cbd3(r, buf) +} + +poly_cbd_eta1 :: proc "contextless" (r: ^Poly, buf: ^[ETA1*N/4]byte) { + cbd2(r, buf) +} + +poly_cbd_eta2 :: proc "contextless" (r: ^Poly, buf: ^[ETA2*N/4]byte) { + cbd2(r, buf) +} diff --git a/core/crypto/_mlkem/constants.odin b/core/crypto/_mlkem/constants.odin new file mode 100644 index 000000000..7195d8b61 --- /dev/null +++ b/core/crypto/_mlkem/constants.odin @@ -0,0 +1,53 @@ +package _mlkem + +K_512 :: 2 +K_768 :: 3 +K_1024 :: 4 +K_MAX :: K_1024 + +N :: 256 +Q :: 3329 + +ETA1_512 :: 3 +ETA1 :: 2 +ETA2 :: 2 + +POLYBYTES :: 384 +SYMBYTES :: 32 + +POLYCOMPRESSEDBYTES_512 :: 128 +POLYCOMPRESSEDBYTES_768 :: 128 +POLYCOMPRESSEDBYTES_1024 :: 160 + +POLYVECBYTES_512 :: K_512 * POLYBYTES +POLYVECBYTES_768 :: K_768 * POLYBYTES +POLYVECBYTES_1024 :: K_1024 * POLYBYTES + +POLYVECCOMPRESSEDBYTES_512 :: K_512 * 320 +POLYVECCOMPRESSEDBYTES_768 :: K_768 * 320 +POLYVECCOMPRESSEDBYTES_1024 :: K_1024 * 352 + +INDCPA_MSGBYTES :: SYMBYTES +INDCPA_PUBLICKEYBYTES_512 :: POLYVECBYTES_512 + SYMBYTES +INDCPA_SECRETKEYBYTES_512 :: POLYVECBYTES_512 +INDCPA_PUBLICKEYBYTES_768 :: POLYVECBYTES_768 + SYMBYTES +INDCPA_SECRETKEYBYTES_768 :: POLYVECBYTES_768 +INDCPA_PUBLICKEYBYTES_1024 :: POLYVECBYTES_1024 + SYMBYTES +INDCPA_SECRETKEYBYTES_1024 :: POLYVECBYTES_1024 +INDCPA_PUBLICKEYBYTES_MAX :: INDCPA_PUBLICKEYBYTES_1024 +INDCPA_BYTES_512 :: POLYVECCOMPRESSEDBYTES_512 + POLYCOMPRESSEDBYTES_512 +INDCPA_BYTES_768 :: POLYVECCOMPRESSEDBYTES_768 + POLYCOMPRESSEDBYTES_768 +INDCPA_BYTES_1024 :: POLYVECCOMPRESSEDBYTES_1024 + POLYCOMPRESSEDBYTES_1024 + +ENCAPSKEYBYTES_512 :: INDCPA_PUBLICKEYBYTES_512 +ENCAPSKEYBYTES_768 :: INDCPA_PUBLICKEYBYTES_768 +ENCAPSKEYBYTES_1024 :: INDCPA_PUBLICKEYBYTES_1024 + +DECAPSKEYBYTES_512 :: INDCPA_SECRETKEYBYTES_512 + INDCPA_PUBLICKEYBYTES_512 + 2 * SYMBYTES +DECAPSKEYBYTES_768 :: INDCPA_SECRETKEYBYTES_768 + INDCPA_PUBLICKEYBYTES_768 + 2 * SYMBYTES +DECAPSKEYBYTES_1024 :: INDCPA_SECRETKEYBYTES_1024 + INDCPA_PUBLICKEYBYTES_1024 + 2 * SYMBYTES + +CIPHERTEXTBYTES_512 :: INDCPA_BYTES_512 +CIPHERTEXTBYTES_768 :: INDCPA_BYTES_768 +CIPHERTEXTBYTES_1024 :: INDCPA_BYTES_1024 +CIPHERTEXTBYTES_MAX :: INDCPA_BYTES_1024 diff --git a/core/crypto/_mlkem/k_pke.odin b/core/crypto/_mlkem/k_pke.odin new file mode 100644 index 000000000..d8b87ad7a --- /dev/null +++ b/core/crypto/_mlkem/k_pke.odin @@ -0,0 +1,349 @@ +#+private +package _mlkem + +import "core:crypto" +import "core:crypto/shake" + +@(require_results) +pack_pk :: proc "contextless" (r: []byte, pk: ^Polyvec, seed: []byte, k: int) -> bool { + pk_len := polyvec_byte_size(k) + switch { + case pk_len == 0: + return false + case len(seed) != SYMBYTES || len(r) != pk_len + SYMBYTES: + return false + } + + polyvec_tobytes(r[:pk_len], pk, k) + copy(r[pk_len:], seed) + + return true +} + +@(require_results) +unpack_pk :: proc "contextless" (pk: ^Polyvec, seed, packedpk: []byte) -> bool { + pk_len := len(packedpk) - SYMBYTES + k: int + switch { + case pk_len == POLYVECBYTES_512: + k = K_512 + case pk_len == POLYVECBYTES_768: + k = K_768 + case pk_len == POLYVECBYTES_1024: + k = K_1024 + case len(packedpk) - pk_len != SYMBYTES: + return false + case len(seed) != SYMBYTES: + return false + } + if k == 0 { + return false + } + + ok := polyvec_frombytes(pk, packedpk[:pk_len], k) + copy(seed, packedpk[pk_len:]) + + return ok +} + +@(require_results) +pack_sk :: proc "contextless" (r: []byte, sk: ^Polyvec, k: int) -> bool { + r_len := len(r) + if r_len == 0 || r_len != polyvec_byte_size(k) { + return false + } + + polyvec_tobytes(r, sk, k) + + return true +} + +@(require_results) +unpack_sk :: proc "contextless" (sk: ^Polyvec, packedsk: []byte) -> bool { + k: int + switch len(packedsk) { + case POLYVECBYTES_512: + k = K_512 + case POLYVECBYTES_768: + k = K_768 + case POLYVECBYTES_1024: + k = K_1024 + case: + return false + } + if k == 0 { + return false + } + + return polyvec_frombytes(sk, packedsk, k) +} + +@(require_results) +pack_ciphertext :: proc "contextless" (r: []byte, b: ^Polyvec, v: ^Poly, k: int) -> bool { + b_len := polyvec_compressed_byte_size(k) + if len(r) != b_len + poly_compressed_bytes(k) { + return false + } + + polyvec_compress(r[:b_len], b, k) + poly_compress(r[b_len:], v) + + return true +} + +@(require_results) +unpack_ciphertext :: proc "contextless" (b: ^Polyvec, v: ^Poly, c: []byte) -> int { + b_len: int + k: int + switch len(c) { + case INDCPA_BYTES_512: + b_len = POLYVECCOMPRESSEDBYTES_512 + k = K_512 + case INDCPA_BYTES_768: + b_len = POLYVECCOMPRESSEDBYTES_768 + k = K_768 + case INDCPA_BYTES_1024: + b_len = POLYVECCOMPRESSEDBYTES_1024 + k = K_1024 + case: + return 0 + } + + polyvec_decompress(b, c[:b_len], k) + poly_decompress(v, c[b_len:]) + + return k +} + +@(require_results) +rej_uniform :: proc "contextless" (r: []i16, buf: []byte) -> int { + r_len, b_len := len(r), len(buf) + + ctr, pos: int + for ctr < r_len && pos + 3 <= b_len { + val0 := (u16(buf[pos+0] >> 0) | (u16(buf[pos+1]) << 8)) & 0xFFF + val1 := (u16(buf[pos+1] >> 4) | (u16(buf[pos+2]) << 4)) & 0xFFF + pos += 3 + + if val0 < Q { + r[ctr] = i16(val0) + ctr += 1 + } + if(ctr < r_len && val1 < Q) { + r[ctr] = i16(val1) + ctr += 1 + } + } + + return ctr +} + +gen_matrix :: proc(a: []Polyvec, seed: []byte, transposed: bool, k: int) { + GEN_MATRIX_NBLOCKS :: ((12*N/8*(1 << 12)/Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES) + + buf: [GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]byte = --- + ctx: shake.Context = --- + ctr: int + + defer shake.reset(&ctx) + defer crypto.zero_explicit(&buf, size_of(buf)) + + for i in 0.. bool { + ensure(len(m) == INDCPA_MSGBYTES, "crypto/mlkem: invalid K-PKE m") + ensure(len(r) == SYMBYTES, "crypto/mlkem: invalid K-PKE r") + + k := ek.k + + at_: [K_MAX]Polyvec = --- + sp, ep, b: Polyvec = ---, ---, --- + kay, epp, v: Poly = ---, ---, --- + defer crypto.zero_explicit(&at_, size_of(Polyvec) * k) + defer polyvec_clear(&sp, &ep, &b) + defer poly_clear(&kay, &epp, &v) + + poly_frommsg(&kay, m) + + at := at_[:k] + + gen_matrix(at, ek.p[:], true, k) + + n := byte(0) + for i in 0.. bool { + if len(plaintext) != INDCPA_MSGBYTES { + return false + } + + k := dk.k + + b: Polyvec = --- + v, mp: Poly = ---, --- + defer poly_clear(&v, &mp) + + if unpack_ciphertext(&b, &v, c) != k { + return false + } + + polyvec_ntt(&b, k) + polyvec_basemul_acc_montgomery(&mp, &dk.pv, &b, k) + poly_invntt_tomont(&mp) + + poly_sub(&mp, &v, &mp) + poly_reduce(&mp) + + poly_tomsg(plaintext, &mp) + + return true +} diff --git a/core/crypto/_mlkem/kem_internal.odin b/core/crypto/_mlkem/kem_internal.odin new file mode 100644 index 000000000..b50aa4b58 --- /dev/null +++ b/core/crypto/_mlkem/kem_internal.odin @@ -0,0 +1,195 @@ +package _mlkem + +import "core:crypto" +import subtle "core:crypto/_subtle" + +// This implementation is derived from the PQ-CRYSTALS reference +// implementation [[ https://github.com/pq-crystals/kyber ]], +// primarily for licensing reasons. Arguably mlkem-native is +// a more "up to date" codebase, but the changes to the +// ref code is minor and they slapped an attribution-required +// license on something that was originally CC-0/Apache 2.0. + +// "Private Key" +Decapsulation_Key :: struct { + pke_dk: K_PKE_Decryption_Key, + ek: Encapsulation_Key, + seed: [SYMBYTES*2]byte, // (d, z) +} + +// "Public Key" +Encapsulation_Key :: struct { + pke_ek: K_PKE_Encryption_Key, + raw_bytes: [INDCPA_PUBLICKEYBYTES_MAX]byte, + h: [SYMBYTES]byte, +} + +decapsulation_key_expanded_bytes :: proc( + dk: ^Decapsulation_Key, + dst: []byte, +) { + sk := &dk.pke_dk + pv_len := polyvec_byte_size(sk.k) + ek_len := pv_len + SYMBYTES + + ek_bytes := dk.ek.raw_bytes[:ek_len] + + dst := dst + _ = pack_sk(dst[:pv_len], &sk.pv, sk.k) + + dst = dst[pv_len:] + copy(dst, ek_bytes) + dst = dst[ek_len:] + hash_h(dst[:SYMBYTES], ek_bytes) + dst = dst[SYMBYTES:] + copy(dst, dk.seed[SYMBYTES:]) +} + +@(require_results) +encapsulation_key_set_bytes :: proc( + ek: ^Encapsulation_Key, + k: int, + b: []byte, +) -> bool { + k_len: int + switch k { + case K_512: + k_len = ENCAPSKEYBYTES_512 + case K_768: + k_len = ENCAPSKEYBYTES_768 + case K_1024: + k_len = ENCAPSKEYBYTES_1024 + case: + return false + } + if len(b) != k_len { + return false + } + + pke_ek := &ek.pke_ek + ok := unpack_pk(&pke_ek.pv, pke_ek.p[:], b) + pke_ek.k = k + copy(ek.raw_bytes[:k_len], b) + hash_h(ek.h[:], b) + + // FIPS 203 unlike Kyber requires canonical encoding of + // encapsulation keys (Section 7,2), which is checked in + // unpack_pk. + + if !ok { + crypto.zero_explicit(ek, size_of(Encapsulation_Key)) + } + + return ok +} + +encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) { + dk_ek := &dk.ek.pke_ek + ensure(dk_ek.k == K_512 || dk_ek.k == K_768 || dk_ek.k == K_1024, "crypto/mlkem: invalid decaps k") + + k_pke_encryption_key_set(&ek.pke_ek, dk_ek) + copy(ek.raw_bytes[:], dk.ek.raw_bytes[:]) + copy(ek.h[:], dk.ek.h[:]) +} + +// NIST's version of this also returns an encapsulation key, but our +// internal representation includes it as part of the decapsulation key +// in a more traditional "keypair" approach. +kem_keygen_internal :: proc( + dk: ^Decapsulation_Key, + seed: []byte, // (d, z) + k: int, +) { + ensure(len(seed) == 2 * SYMBYTES, "crypto/mlkem: invalid seed") + + dk_ek := &dk.ek + d, z := seed[:SYMBYTES], seed[SYMBYTES:] + + k_pke_keygen(&dk_ek.pke_ek, &dk.pke_dk, d, k) + + ek_len := polyvec_byte_size(k) + SYMBYTES + ek_bytes := dk_ek.raw_bytes[:ek_len] + ensure( + pack_pk(ek_bytes, &dk_ek.pke_ek.pv, dk_ek.pke_ek.p[:], k), + "crypto/mlkem: failed to pack K-PKE ek", + ) + hash_h(dk_ek.h[:], ek_bytes) + copy(dk.seed[:SYMBYTES], d) + copy(dk.seed[SYMBYTES:], z) +} + +// The `_internal` "de-randomized" versions of ML-KEM.Encaps and +// ML-KEM.Decaps are only ever to be called by the actual non-interal +// implementation or test cases. + +kem_encaps_internal :: proc( + shared_secret: []byte, + ciphertext: []byte, + ek: ^Encapsulation_Key, + randomness: []byte, +) { + ensure(len(shared_secret) == SYMBYTES, "crypto/mlkem: invalid K") + ensure(len(randomness) == SYMBYTES, "crypto/mlkem: invalid m") + ensure( + len(ciphertext) == ct_len_for_k(ek.pke_ek.k), + "crypto/mlkem: invalid ciphertext length", + ) + + buf: [2*SYMBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + hash_g(buf[:], randomness, ek.h[:]) + + // Can't fail, ciphertext length is valid. + _ = k_pke_encrypt(ciphertext, &ek.pke_ek, randomness, buf[SYMBYTES:]) + + copy(shared_secret, buf[:SYMBYTES]) +} + +kem_decaps_internal :: proc( + shared_secret: []byte, + dk: ^Decapsulation_Key, + ciphertext: []byte, +) { + ct_len := ct_len_for_k(dk.pke_dk.k) + ensure( + len(ciphertext) == ct_len, + "crypto/mlkem: invalid ciphertext length", + ) + + m_: [SYMBYTES]byte + defer crypto.zero_explicit(&m_, size_of(m_)) + + // Can't fail, ciphertext length is valid. + _ = k_pke_decrypt(m_[:], &dk.pke_dk, ciphertext) + + buf: [2*SYMBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + ek := &dk.ek + hash_g(buf[:], m_[:], ek.h[:]) + + rkprf(shared_secret, dk.seed[SYMBYTES:], ciphertext) + + ct_buf: [CIPHERTEXTBYTES_MAX]byte = --- + defer crypto.zero_explicit(&ct_buf, size_of(ct_buf)) + ct_ := ct_buf[:ct_len] + _ = k_pke_encrypt(ct_, &ek.pke_ek, m_[:], buf[SYMBYTES:]) + + ok := crypto.compare_constant_time(ciphertext, ct_) + subtle.cmov_bytes(shared_secret, buf[:SYMBYTES], ok) +} + +@(private="file") +ct_len_for_k :: proc(k: int) -> int { + switch k { + case K_512: + return CIPHERTEXTBYTES_512 + case K_768: + return CIPHERTEXTBYTES_768 + case K_1024: + return CIPHERTEXTBYTES_1024 + case: + panic("crypto/mlkem: invalid k for ciphertext length") + } +} diff --git a/core/crypto/_mlkem/ntt.odin b/core/crypto/_mlkem/ntt.odin new file mode 100644 index 000000000..c99c73487 --- /dev/null +++ b/core/crypto/_mlkem/ntt.odin @@ -0,0 +1,75 @@ +#+private +package _mlkem + +@(rodata) +ZETAS := [128]i16 { + -1044, -758, -359, -1517, 1493, 1422, 287, 202, + -171, 622, 1577, 182, 962, -1202, -1474, 1468, + 573, -1325, 264, 383, -829, 1458, -1602, -130, + -681, 1017, 732, 608, -1542, 411, -205, -1571, + 1223, 652, -552, 1015, -1293, 1491, -282, -1544, + 516, -8, -320, -666, -1618, -1162, 126, 1469, + -853, -90, -271, 830, 107, -1421, -247, -951, + -398, 961, -1508, -725, 448, -1065, 677, -1275, + -1103, 430, 555, 843, -1251, 871, 1550, 105, + 422, 587, 177, -235, -291, -460, 1574, 1653, + -246, 778, 1159, -147, -777, 1483, -602, 1119, + -1590, 644, -872, 349, 418, 329, -156, -75, + 817, 1097, 603, 610, 1322, -1285, -1465, 384, + -1215, -136, 1218, -1335, -874, 220, -1187, -1659, + -1185, -1530, -1278, 794, -1510, -854, -870, 478, + -108, -308, 996, 991, 958, -1460, 1522, 1628, +} + +@(require_results) +fqmul :: #force_inline proc "contextless" (a, b: i16) -> i16 { + return montgomery_reduce(i32(a) * i32(b)) +} + +ntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check { + j, k := 0, 1 + for l := 128; l >= 2; l >>= 1 { + for start := 0; start < N; start = j + l { + zeta := ZETAS[k] + k += 1 + for j = start; j < start + l; j += 1 { + t := fqmul(zeta, r[j+l]) + r[j+l] = r[j] - t + r[j] = r[j] + t + } + } + } +} + +invntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check { + F : i16 : 1441 // mont^2/128 + + j, k := 0, 127 + for l := 2; l <= 128; l <<= 1 { + for start := 0; start < 256; start = j+l { + zeta := ZETAS[k] + k -= 1 + for j = start; j < start + l; j += 1 { + t := r[j] + r[j] = barrett_reduce(t + r[j+l]) + r[j+l] = r[j+l] - t + r[j+l] = fqmul(zeta, r[j+l]) + } + } + } + + for v, i in r { + r[i] = fqmul(v, F) + } +} + +@(require_results) +base_case_multiply :: proc "contextless" (a_0, a_1, b_0, b_1, zeta: i16) -> (i16, i16) { + r_0 := fqmul(a_1, b_1) + r_0 = fqmul(r_0, zeta) + r_0 += fqmul(a_0, b_0) + r_1 := fqmul(a_0, b_1) + r_1 += fqmul(a_1, b_0) + + return r_0, r_1 +} diff --git a/core/crypto/_mlkem/poly.odin b/core/crypto/_mlkem/poly.odin new file mode 100644 index 000000000..4a79e9770 --- /dev/null +++ b/core/crypto/_mlkem/poly.odin @@ -0,0 +1,241 @@ +#+private +package _mlkem + +import "core:crypto" +import subtle "core:crypto/_subtle" + +// Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial +// coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1] +Poly :: struct { + coeffs: [N]i16, +} + +poly_compress :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + r := r + switch len(r) { + case POLYCOMPRESSEDBYTES_768: // Also covers _512 + for i in 0..> 15) & Q + // t[j] = ((((uint16_t)u << 4) + Q/2)/Q) & 15 + d0 := u32(u) << 4 + d0 += 1665 + d0 *= 80635 + d0 >>= 28 + t[j] = byte(d0) & 0xf + } + + r[0] = t[0] | (t[1] << 4) + r[1] = t[2] | (t[3] << 4) + r[2] = t[4] | (t[5] << 4) + r[3] = t[6] | (t[7] << 4) + r = r[4:] + } + case POLYCOMPRESSEDBYTES_1024: + for i in 0..> 15) & Q + // t[j] = ((((uint16_t)u << 5) + Q/2)/Q) & 31 + d0 := u32(u) << 5 + d0 += 1664 + d0 *= 40318 + d0 >>= 27 + t[j] = byte(d0) & 0x1f + } + + r[0] = (t[0] >> 0) | (t[1] << 5) + r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7) + r[2] = (t[3] >> 1) | (t[4] << 4) + r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6) + r[4] = (t[6] >> 2) | (t[7] << 3) + r = r[5:] + } + case: + panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES") + } +} + +poly_decompress :: proc "contextless" (r: ^Poly, a: []byte) { + a := a + switch len(a) { + case POLYCOMPRESSEDBYTES_768: // Also covers _512 + for i in 0..> 4) + r.coeffs[2*i+1] = i16(((u16(a[0] >> 4) * Q) + 8) >> 4) + a = a[1:] + } + case POLYCOMPRESSEDBYTES_1024: + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) + t[1] = (a[0] >> 5) | (a[1] << 3) + t[2] = (a[1] >> 2) + t[3] = (a[1] >> 7) | (a[2] << 1) + t[4] = (a[2] >> 4) | (a[3] << 4) + t[5] = (a[3] >> 1) + t[6] = (a[3] >> 6) | (a[4] << 2) + t[7] = (a[4] >> 3) + a = a[5:] + + for j in 0..<8 { + r.coeffs[8*i+j] = i16((u32(t[j] & 31) * Q + 16) >> 5) + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES") + } +} + +poly_tobytes :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + ensure_contextless(len(r) >= POLYBYTES) + + for i in 0..> 15) & Q) + t1 := u16(a.coeffs[2*i+1]) + t1 += u16((i16(t1) >> 15) & Q) + r[3*i+0] = byte(t0 >> 0) + r[3*i+1] = byte(t0 >> 8) | byte(t1 << 4) + r[3*i+2] = byte(t1 >> 4) + } +} + +@(require_results) +poly_frombytes :: proc "contextless" (r: ^Poly, a: []byte) -> bool #no_bounds_check { + ensure_contextless(len(a) >= POLYBYTES) + + ok := true + for i in 0..> 0) | (u16(a[3*i+1]) << 8)) & 0xFFF) + r.coeffs[2*i+1] = i16(((u16(a[3*i+1]) >> 4) | (u16(a[3*i+2]) << 4)) & 0xFFF) + ok &= r.coeffs[2*i] < Q && r.coeffs[2*i+1] < Q + } + + return ok +} + +poly_frommsg :: proc "contextless" (r: ^Poly, msg: []byte) #no_bounds_check { + #assert(INDCPA_MSGBYTES == N/8) + ensure_contextless(len(msg) == INDCPA_MSGBYTES) + + for i in 0..> uint(j))&1) + } + } +} + +poly_tomsg :: proc "contextless" (msg: []byte, a: ^Poly) #no_bounds_check { + ensure_contextless(len(msg) == INDCPA_MSGBYTES) + + for i in 0..> 15) & Q + // t = (((t << 1) + Q/2)/Q) & 1 + t <<= 1 + t += 1665 + t *= 80635 + t >>= 28 + t &= 1 + msg[i] |= byte(t << j) + } + } +} + +poly_getnoise_eta1_512 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA1_512*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta1_512(r, &buf) +} + +poly_getnoise_eta1 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA1*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta1(r, &buf) +} + +poly_getnoise_eta2 :: proc(r: ^Poly, seed: []byte, iv: byte) { + buf: [ETA2*N/4]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + prf(buf[:], seed, iv) + poly_cbd_eta2(r, &buf) +} + +poly_ntt :: proc "contextless" (r: ^Poly) { + ntt(&r.coeffs) + poly_reduce(r) +} + +poly_invntt_tomont :: proc "contextless" (r: ^Poly) { + invntt(&r.coeffs) +} + +poly_basemul_montgomery :: proc "contextless" (r, a, b: ^Poly) #no_bounds_check { + for i in 0.. int { + switch k { + case K_512: + return POLYCOMPRESSEDBYTES_512 + case K_768: + return POLYCOMPRESSEDBYTES_768 + case K_1024: + return POLYCOMPRESSEDBYTES_1024 + case: + panic_contextless("crypto/mlkem: invalid k") + } +} diff --git a/core/crypto/_mlkem/polyvec.odin b/core/crypto/_mlkem/polyvec.odin new file mode 100644 index 000000000..387acade8 --- /dev/null +++ b/core/crypto/_mlkem/polyvec.odin @@ -0,0 +1,224 @@ +#+private +package _mlkem + +import "core:crypto" + +Polyvec :: struct { + vec: [K_MAX]Poly, +} + +polyvec_compress :: proc "contextless" (r: []byte, a: ^Polyvec, kay: int) #no_bounds_check { + d0: u64 + + r := r + switch len(r) { + case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768: + ensure_contextless(kay == K_512 || kay == K_768) + + t: [4]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 15) & Q) + // t[k] = ((((uint32_t)t[k] << 10) + Q/2)/Q) & 0x3ff + d0 = u64(t[k]) + d0 <<= 10 + d0 += 1665 + d0 *= 1290167 + d0 >>= 32 + t[k] = u16(d0 & 0x3ff) + } + + r[0] = byte(t[0] >> 0) + r[1] = byte((t[0] >> 8) | (t[1] << 2)) + r[2] = byte((t[1] >> 6) | (t[2] << 4)) + r[3] = byte((t[2] >> 4) | (t[3] << 6)) + r[4] = byte(t[3] >> 2) + r = r[5:] + } + } + case POLYVECCOMPRESSEDBYTES_1024: + ensure_contextless(kay == K_1024) + + t: [8]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 15) & Q) + // t[k] = ((((uint32_t)t[k] << 11) + Q/2)/Q) & 0x7ff + d0 = u64(t[k]) + d0 <<= 11 + d0 += 1664 + d0 *= 645084 + d0 >>= 31 + t[k] = u16(d0 & 0x7ff) + } + + r[0] = byte(t[0] >> 0) + r[1] = byte((t[0] >> 8) | (t[1] << 3)) + r[2] = byte((t[1] >> 5) | (t[2] << 6)) + r[3] = byte(t[2] >> 2) + r[4] = byte((t[2] >> 10) | (t[3] << 1)) + r[5] = byte((t[3] >> 7) | (t[4] << 4)) + r[6] = byte((t[4] >> 4) | (t[5] << 7)) + r[7] = byte(t[5] >> 1) + r[8] = byte((t[5] >> 9) | (t[6] << 2)) + r[9] = byte((t[6] >> 6) | (t[7] << 5)) + r[10] = byte(t[7] >> 3) + r = r[11:] + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES") + } +} + +polyvec_decompress :: proc "contextless" (r: ^Polyvec, a: []byte, kay: int) #no_bounds_check { + a := a + switch len(a) { + case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768: + ensure_contextless(kay == K_512 || kay == K_768) + + t: [4]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) | (u16(a[1]) << 8) + t[1] = u16(a[1] >> 2) | (u16(a[2]) << 6) + t[2] = u16(a[2] >> 4) | (u16(a[3]) << 4) + t[3] = u16(a[3] >> 6) | (u16(a[4]) << 2) + a = a[5:] + + for k in 0..<4 { + r.vec[i].coeffs[4*j+k] = i16((u32(t[k] & 0x3FF) * Q + 512) >> 10) + } + } + } + case POLYVECCOMPRESSEDBYTES_1024: + t: [8]u16 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 0) | (u16(a[1]) << 8) + t[1] = u16(a[1] >> 3) | (u16(a[2]) << 5) + t[2] = u16(a[2] >> 6) | (u16(a[3]) << 2) | (u16(a[4]) << 10) + t[3] = u16(a[4] >> 1) | (u16(a[5]) << 7) + t[4] = u16(a[5] >> 4) | (u16(a[6]) << 4) + t[5] = u16(a[6] >> 7) | (u16(a[7]) << 1) | (u16(a[8]) << 9) + t[6] = u16(a[8] >> 2) | (u16(a[9]) << 6) + t[7] = u16(a[9] >> 5) | (u16(a[10]) << 3) + a = a[11:] + + for k in 0..<8 { + r.vec[i].coeffs[8*j+k] = i16((u32(t[k] & 0x7FF) * Q + 1024) >> 11) + } + } + } + case: + panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES") + } +} + +polyvec_tobytes :: proc "contextless" (r: []byte, a: ^Polyvec, k: int) #no_bounds_check { + ensure_contextless(len(r) == k * POLYBYTES, "crypto/mlkem: invalid buffer") + + r := r + for i in 0.. bool #no_bounds_check { + switch k { + case K_512, K_768, K_1024: + case: + panic_contextless("crypto/mlkem: invalid POLYVECBYTES") + } + ensure_contextless(len(a) == k * POLYBYTES, "crypto/mlkem: invalid buffer") + + a := a + ok := true + for i in 0.. int { + switch k { + case K_512, K_768, K_1024: + return k * POLYBYTES + case: + return 0 + } +} + +@(require_results) +polyvec_compressed_byte_size :: #force_inline proc "contextless" (k: int) -> int { + switch k { + case K_512: + return POLYVECCOMPRESSEDBYTES_512 + case K_768: + return POLYVECCOMPRESSEDBYTES_768 + case K_1024: + return POLYVECCOMPRESSEDBYTES_1024 + case: + return 0 + } +} + +polyvec_ntt :: proc "contextless" (r: ^Polyvec, k: int) { + for i in 0.. i16 { + QINV :: -3327 // q^-1 mod 2^16 + + t := i16(a) * QINV + return i16((a - i32(t) * Q) >> 16) +} + +@(require_results) +barrett_reduce :: #force_inline proc "contextless" (a: i16) -> i16 { + V : i16 : ((1<<26) + Q / 2) / Q + + t := i16((i32(V)*i32(a) + (1<<25)) >> 26) + t *= Q + return a - t +} diff --git a/core/crypto/_mlkem/symmetric_shake.odin b/core/crypto/_mlkem/symmetric_shake.odin new file mode 100644 index 000000000..bf851f6f6 --- /dev/null +++ b/core/crypto/_mlkem/symmetric_shake.odin @@ -0,0 +1,61 @@ +#+private +package _mlkem + +import "core:crypto" +import "core:crypto/_sha3" +import "core:crypto/sha3" +import "core:crypto/shake" + +XOF_BLOCKBYTES :: _sha3.RATE_128 +#assert(XOF_BLOCKBYTES % 3 == 0) + +prf :: proc(out, key: []byte, iv: byte) { + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + shake.write(&ctx, key) + shake.write(&ctx, []byte{iv}) + shake.read(&ctx, out) +} + +rkprf :: proc(out, key, input: []byte) { + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + shake.write(&ctx, key) + shake.write(&ctx, input) + shake.read(&ctx, out) +} + +xof_absorb :: proc(ctx: ^shake.Context, seed: []byte, x, y: byte) { + shake.init_128(ctx) + + extseed: [SYMBYTES+2]byte = --- + defer crypto.zero_explicit(&extseed, size_of(extseed)) + + copy(extseed[:], seed) + extseed[SYMBYTES+0] = x + extseed[SYMBYTES+1] = y + + shake.write(ctx, extseed[:]) +} + +hash_h :: proc(dst, src: []byte) { + ctx: sha3.Context = --- + + sha3.init_256(&ctx) + sha3.update(&ctx, src) + sha3.final(&ctx, dst) +} + +hash_g :: proc(dst: []byte, srcs: ..[]byte) { + ctx: sha3.Context = --- + + sha3.init_512(&ctx) + for src in srcs { + sha3.update(&ctx, src) + } + sha3.final(&ctx, dst) +} diff --git a/core/crypto/_subtle/subtle.odin b/core/crypto/_subtle/subtle.odin index 454066e4a..01c84cf2a 100644 --- a/core/crypto/_subtle/subtle.odin +++ b/core/crypto/_subtle/subtle.odin @@ -3,6 +3,7 @@ Various useful bit operations in constant time. */ package _subtle +import "core:crypto/_fiat" import "core:math/bits" // byte_eq returns 1 if and only if (⟺) a == b, 0 otherwise. @@ -40,3 +41,41 @@ u64_is_non_zero :: proc "contextless" (a: u64) -> u64 { is_zero := u64_is_zero(a) return (~is_zero) & 1 } + +@(optimization_mode="none") +cmov_bytes :: proc "contextless" (dst, src: []byte, ctrl: int) { + s_len := len(src) + ensure_contextless(s_len == len(dst), "crypto: cmov length mismatch") + + c := -(byte)(ctrl) + for i in 0.. i16 { + c := -(u16)(ctrl) + return a ~ i16(c & u16(a ~ b)) +} + +@(optimization_mode="none") +csel_u16 :: proc "contextless" (a, b: u16, ctrl: int) -> u16 { + c := -(u16)(ctrl) + return a ~ (c & (a ~ b)) +} + +csel_u32 :: proc "contextless" (a, b: u32, ctrl: int) -> u32 { + return _fiat.cmovznz_u32(_fiat.u1(ctrl), a, b) +} + +csel_u64 :: proc "contextless" (a, b: u64, ctrl: int) -> u64 { + return _fiat.cmovznz_u64(_fiat.u1(ctrl), a, b) +} + +csel :: proc { + csel_i16, + csel_u16, + csel_u32, + csel_u64, +} diff --git a/core/crypto/mlkem/api.odin b/core/crypto/mlkem/api.odin new file mode 100644 index 000000000..9c8fd6a2b --- /dev/null +++ b/core/crypto/mlkem/api.odin @@ -0,0 +1,268 @@ +package mlkem + +import "core:crypto" +import "core:crypto/_mlkem" + +// Parameters are the supported ML-KEM parameter sets. +Parameters :: enum { + Invalid, + ML_KEM_512, + ML_KEM_768, + ML_KEM_1024, +} + +// DECAPSULATION_KEY_SEED_SIZE is the size of a Decapsulation key in bytes. +DECAPSULATION_KEY_SEED_SIZE :: 64 // (d, z) in NIST terms. + +// DECAPSULATION_KEY_EXPANDED_SIZES are the per-parameter sizes of the +// decapsulation key in bytes. +DECAPSULATION_KEY_EXPANDED_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.DECAPSKEYBYTES_512, // 1632-bytes + .ML_KEM_768 = _mlkem.DECAPSKEYBYTES_768, // 2400-bytes + .ML_KEM_1024 = _mlkem.DECAPSKEYBYTES_1024, // 3168-bytes +} + +// ENCAPSULATION_KEY_SIZES are the per-parameter sizes of the encapsulation +// key in bytes. +ENCAPSULATION_KEY_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.ENCAPSKEYBYTES_512, // 800-bytes + .ML_KEM_768 = _mlkem.ENCAPSKEYBYTES_768, // 1184-bytes + .ML_KEM_1024 = _mlkem.ENCAPSKEYBYTES_1024, // 1568-bytes +} + +// CIPHERTEXT_SIZES are the per-parameter set sizes of the ciphertext +// in bytes. +CIPHERTEXT_SIZES := [Parameters]int { + .Invalid = 0, + .ML_KEM_512 = _mlkem.CIPHERTEXTBYTES_512, // 768-bytes + .ML_KEM_768 = _mlkem.CIPHERTEXTBYTES_768, // 1088-bytes + .ML_KEM_1024 = _mlkem.CIPHERTEXTBYTES_1024, // 1568-bytes +} + +// SHARED_SECRET_SIZE is the size of the final shared secret in bytes. +SHARED_SECRET_SIZE :: 32 + +// Decapsulation_Key is a ML-KEM decapsulation (aka "private") key. +// This implementation opts to include the encapsulation (aka "public") +// key as well for cases where the decapsulation key is reused (eg: HPKE +// with X-Wing). +Decapsulation_Key :: _mlkem.Decapsulation_Key + +// Encapsulation_Key is a ML-KEM encapsulation (aka "public") key. +Encapsulation_Key :: _mlkem.Encapsulation_Key + +// decapsulation_key_generate uses the system entropy source to generate +// a decapsulation key. This will only fail if and only if (⟺) the system +// entropy source is missing or broken. +@(require_results) +decapsulation_key_generate :: proc(dk: ^Decapsulation_Key, params: Parameters) -> bool { + decapsulation_key_clear(dk) + + if !crypto.HAS_RAND_BYTES { + return false + } + + k := params_to_k(params) + if k == 0 { + panic("crypto/mlkem: invalid parameter set") + } + + seed: [DECAPSULATION_KEY_SEED_SIZE]byte = --- + defer crypto.zero_explicit(&seed, size_of(seed)) + + crypto.rand_bytes(seed[:]) + _mlkem.kem_keygen_internal(dk, seed[:], k) + + return true +} + +// decapsulation_key_set_bytes decodes a byte-encoded decapsulation key +// in (d, z) "seed" format, and returns true if and only if (⟺) the +// operation was successful. +@(require_results) +decapsulation_key_set_bytes :: proc(dk: ^Decapsulation_Key, params: Parameters, seed: []byte) -> bool { + k := params_to_k(params) + if k == 0 { + return false + } + if len(seed) != DECAPSULATION_KEY_SEED_SIZE { + return false + } + + _mlkem.kem_keygen_internal(dk, seed, k) + + return true +} + +// decapsulation_key_bytes sets dst to byte-encoding of dk in the (d, z) +// "seed" format. +decapsulation_key_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key") + ensure(len(dst) == DECAPSULATION_KEY_SEED_SIZE, "crypto/mlkem: invalid destination size") + + copy(dst, dk.seed[:]) +} + +// decapsulation_key_expanded_bytes sets dst to the byte-encoding of dk. +// in the expanded FIPS 203 format. This primarily exists for export +// purposes. +decapsulation_key_expanded_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + dk_len: int + switch dk.pke_dk.k { + case _mlkem.K_512: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_512] + case _mlkem.K_768: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_1024] + case: + panic("crypto/mlkem: uninitialized Decapsulation_Key") + } + ensure(len(dst) == dk_len, "crypto/mlkem: invalid destination size") + + _mlkem.decapsulation_key_expanded_bytes(dk, dst) +} + +// decapsulation_key_encaps_bytes sets dst to the byte-encoding of the +// encasulation key corresponding to dk. +decapsulation_key_encaps_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) { + encapsulation_key_bytes(&dk.ek, dst) +} + +// decapsulation_key_clear clears dk to the uninitialized state. +decapsulation_key_clear :: proc(dk: ^Decapsulation_Key) { + crypto.zero_explicit(dk, size_of(Decapsulation_Key)) +} + +// encapsulation_key_set_bytes decodes a byte-encoded encapsulation key, +// and returns true if and only if (⟺) the operation was successful. +@(require_results) +encapsulation_key_set_bytes :: proc(ek: ^Encapsulation_Key, params: Parameters, b: []byte) -> bool { + k := params_to_k(params) + if k == 0 { + return false + } + if len(b) != ENCAPSULATION_KEY_SIZES[params] { + return false + } + + return _mlkem.encapsulation_key_set_bytes(ek, k, b) +} + +// encapsulation_key_set_decaps sets ek to the encapsulation key corresponding +// to dk. +encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) { + ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key") + _mlkem.encapsulation_key_set_decaps(ek, dk) +} + +// encapsulation_key_encaps_bytes sets dst to the byte-encoding of ek. +encapsulation_key_bytes :: proc(ek: ^Encapsulation_Key, dst: []byte) { + ensure(ek.pke_ek.k != 0, "crypto/mlkem: uninitialized Encapsulation_Key") + + k_len: int + switch ek.pke_ek.k { + case _mlkem.K_512: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_512] + case _mlkem.K_768: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_1024] + case: + panic("crypto/mlkem: invalid destination size") + } + + copy(dst, ek.raw_bytes[:k_len]) +} + +// encapsulation_key_clear clears ek to the uninitialized state. +encapsulation_key_clear :: proc(ek: ^Encapsulation_Key) { + crypto.zero_explicit(ek, size_of(Encapsulation_Key)) +} + +// encaps_raw_ek_bytes uses the byte encoded encapsulation key to generate +// a shared secret and an associated ciphertext. This routine will fail +// if the system entropy source is unavailable, or of the encapsulation key +// is invalid. +@(require_results) +encaps_ek_raw_bytes :: proc(params: Parameters, raw_ek, shared_secret, ciphertext: []byte) -> bool { + ek: Encapsulation_Key = --- + if !encapsulation_key_set_bytes(&ek, params, raw_ek) { + return false + } + defer encapsulation_key_clear(&ek) + + return encaps_ek(&ek, shared_secret, ciphertext) +} + +// encaps_ek uses the encapsulation key to generate a shared secret and an +// associated ciphertext. This routine will fail if the system entropy source +// is unavailable. +@(require_results) +encaps_ek :: proc(ek: ^Encapsulation_Key, shared_secret, ciphertext: []byte) -> bool { + ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size") + + if !crypto.HAS_RAND_BYTES { + return false + } + + m: [_mlkem.SYMBYTES]byte = --- + defer crypto.zero_explicit(&m, size_of(m)) + + crypto.rand_bytes(m[:]) + _mlkem.kem_encaps_internal(shared_secret, ciphertext, ek, m[:]) + + return true +} + +encaps :: proc { + encaps_ek, + encaps_ek_raw_bytes, +} + +// decaps uses the decapsulation key to generate a shared secret from a +// ciphertext. Due to ML-KEM's implicit rejection mechanism, this function +// will only return false if and only if (⟺) the lengths of the inputs +// are invalid or the decapsulation key is uninitialized. +// +// This routine returning true does not guarantee that the shared secret +// matches that generated by the peer. +@(require_results) +decaps :: proc(dk: ^Decapsulation_Key, ciphertext, shared_secret: []byte) -> bool { + ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size") + + ct_len: int + switch dk.pke_dk.k { + case _mlkem.K_512: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_512] + case _mlkem.K_768: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_768] + case _mlkem.K_1024: + ct_len = CIPHERTEXT_SIZES[.ML_KEM_1024] + case: + return false + } + if len(ciphertext) != ct_len { + return false + } + + _mlkem.kem_decaps_internal(shared_secret, dk, ciphertext) + + return true +} + +@(private="file") +params_to_k :: #force_inline proc "contextless" (params: Parameters) -> int { + #partial switch params { + case .ML_KEM_512: + return _mlkem.K_512 + case .ML_KEM_768: + return _mlkem.K_768 + case .ML_KEM_1024: + return _mlkem.K_1024 + } + + return 0 +} diff --git a/core/crypto/mlkem/doc.odin b/core/crypto/mlkem/doc.odin new file mode 100644 index 000000000..ec32def25 --- /dev/null +++ b/core/crypto/mlkem/doc.odin @@ -0,0 +1,7 @@ +/* +ML-KEM Module-Lattice-Based Key-Encapsulation Mechanism. + +See: +- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf ]] +*/ +package mlkem diff --git a/examples/all/all_js.odin b/examples/all/all_js.odin index 74cdb5fd3..f557cecdb 100644 --- a/examples/all/all_js.odin +++ b/examples/all/all_js.odin @@ -43,6 +43,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" @(require) import "core:crypto/poly1305" diff --git a/examples/all/all_main.odin b/examples/all/all_main.odin index 973ee423e..5565d225d 100644 --- a/examples/all/all_main.odin +++ b/examples/all/all_main.odin @@ -48,6 +48,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" @(require) import "core:crypto/poly1305" diff --git a/tests/benchmark/crypto/benchmark_pqc.odin b/tests/benchmark/crypto/benchmark_pqc.odin new file mode 100644 index 000000000..314594ca5 --- /dev/null +++ b/tests/benchmark/crypto/benchmark_pqc.odin @@ -0,0 +1,84 @@ +package benchmark_core_crypto + +import "core:log" +import "core:testing" +import "core:text/table" +import "core:time" + +import "core:crypto" +import "core:crypto/mlkem" + +@(private = "file") +MLKEM_ITERS :: 50000 + +@(test) +benchmark_crypto_mlkem :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.warnf("ML-KEM benchmarks skipped, no system entropy source") + } + + tbl: table.Table + table.init(&tbl) + defer table.destroy(&tbl) + + table.caption(&tbl, "ML-KEM") + table.aligned_header_of_values(&tbl, .Right, "Parameters", "Keygen", "Encaps", "Decaps") + + append_tbl := proc(tbl: ^table.Table, algo_name: string, keygen, encaps, decaps: time.Duration) { + table.aligned_row_of_values( + tbl, + .Right, + algo_name, + table.format(tbl, "%8M", keygen), + table.format(tbl, "%8M", encaps), + table.format(tbl, "%8M", decaps), + ) + } + + for params in mlkem.Parameters { + if params == .Invalid { + continue + } + param_name := MLKEM_PARAMS_NAMES[params] + + decaps_key: mlkem.Decapsulation_Key + start := time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.decapsulation_key_generate(&decaps_key, params) + } + keygen := time.tick_since(start) / MLKEM_ITERS + + encaps_key := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + defer delete(encaps_key) + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + defer delete(ciphertext) + + mlkem.decapsulation_key_encaps_bytes(&decaps_key, encaps_key) + + bob_shared: [mlkem.SHARED_SECRET_SIZE]byte + start = time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.encaps(params, encaps_key, bob_shared[:], ciphertext) + } + encaps := time.tick_since(start) / MLKEM_ITERS + + alice_shared: [mlkem.SHARED_SECRET_SIZE]byte + start = time.tick_now() + for _ in 0 ..< MLKEM_ITERS { + _ = mlkem.decaps(&decaps_key, ciphertext, alice_shared[:]) + } + decaps := time.tick_since(start) / MLKEM_ITERS + + append_tbl(&tbl, param_name, keygen, encaps, decaps) + } + + log_table(&tbl) +} + +@(private="file") +MLKEM_PARAMS_NAMES := [mlkem.Parameters]string { + .Invalid = "invalid", + .ML_KEM_512 = "ML-KEM-512", + .ML_KEM_768 = "ML-KEM-768", + .ML_KEM_1024 = "ML-KEM-1024", +} diff --git a/tests/core/crypto/test_core_crypto_pqc.odin b/tests/core/crypto/test_core_crypto_pqc.odin new file mode 100644 index 000000000..8359053e4 --- /dev/null +++ b/tests/core/crypto/test_core_crypto_pqc.odin @@ -0,0 +1,72 @@ +package test_core_crypto + +import "core:bytes" +import "core:log" +import "core:testing" + +import "core:crypto" +import "core:crypto/mlkem" + +@(test) +test_mlkem :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.info("rand_bytes not supported - skipping") + return + } + + // Test vectors are huge, and are covered by the wycheproof corpus, + // so just test a full key exchange with all supported parameter + // sets. + for params in mlkem.Parameters { + if params == .Invalid { + continue + } + + // Alice + decaps_key: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_generate(&decaps_key, params), + "%v: decapsulation_key_generate", + params, + ) { + continue + } + defer mlkem.decapsulation_key_clear(&decaps_key) + + ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + defer delete(ek_bytes) + mlkem.decapsulation_key_encaps_bytes(&decaps_key, ek_bytes) + + // Bob + bob_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + defer delete(ciphertext) + if !testing.expectf( + t, + mlkem.encaps(params, ek_bytes, bob_shared_secret[:], ciphertext), + "%v: encaps", + params, + ) { + continue + } + + // Alice + alice_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + if !testing.expectf( + t, + mlkem.decaps(&decaps_key, ciphertext, alice_shared_secret[:]), + "%v: decaps", + params, + ) { + continue + } + + testing.expectf( + t, + bytes.equal(alice_shared_secret[:], bob_shared_secret[:]), + "%v: shared secret mismatch", + params, + ) + } +} diff --git a/tests/core/crypto/wycheproof/main.odin b/tests/core/crypto/wycheproof/main.odin index dfdc78267..654ac2a38 100644 --- a/tests/core/crypto/wycheproof/main.odin +++ b/tests/core/crypto/wycheproof/main.odin @@ -35,6 +35,16 @@ import "core:testing" // - crypto/kmac // - kmac128_no_customization_test.json // - kmac256_no_customization_test.json +// - crypto/mlkem +// - mlkem_512_keygen_seed_test.json +// - mlkem_512_encaps_test.json +// - mlkem_512_test.json +// - mlkem_768_keygen_seed_test.json +// - mlkem_768_encaps_test.json +// - mlkem_768_test.json +// - mlkem_1024_keygen_seed_test.json +// - mlkem_1024_encaps_test.json +// - mlkem_1024_test.json // - crypto/pbkdf2 // - pbkdf2_hmacsha1_test.json // - pbkdf2_hmacsha224_test.json diff --git a/tests/core/crypto/wycheproof/pqc.odin b/tests/core/crypto/wycheproof/pqc.odin new file mode 100644 index 000000000..210d7be1d --- /dev/null +++ b/tests/core/crypto/wycheproof/pqc.odin @@ -0,0 +1,401 @@ +package test_wycheproof + +import "core:encoding/hex" +import "core:log" +import "core:mem" +import "core:os" +import "core:testing" + +import "core:crypto/_mlkem" +import "core:crypto/mlkem" + +import "../common" + +@(test) +test_mlkem :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("mlkem: starting") + + files_keygen := []string { + "mlkem_512_keygen_seed_test.json", + "mlkem_768_keygen_seed_test.json", + "mlkem_1024_keygen_seed_test.json", + } + for f in files_keygen { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_keygen(t, &test_vectors), "ML-KEM KeyGen failed") + } + + files_encaps := []string { + "mlkem_512_encaps_test.json", + "mlkem_768_encaps_test.json", + "mlkem_1024_encaps_test.json", + } + for f in files_encaps { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_encaps(t, &test_vectors), "ML-KEM Encaps failed") + } + + files_decaps := []string { + "mlkem_512_test.json", + "mlkem_768_test.json", + "mlkem_1024_test.json", + } + for f in files_decaps { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Kem_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mlkem_decaps(t, &test_vectors), "ML-KEM Decaps failed") + } +} + +test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: KeyGen starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + seed := common.hexbytes_decode(test_vector.seed) + + dk: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_set_bytes(&dk, params, seed), + "%s/KeyGen/%d/%d: failed to set decapsulation key from seed", + params_str, + tg_id, + test_vector.tc_id, + test_vector.seed, + ) { + num_failed *= 1 + continue + } + + ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params]) + mlkem.decapsulation_key_encaps_bytes(&dk, ek_bytes) + + ok := common.hexbytes_compare(test_vector.ek, ek_bytes) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(ek_bytes)) + log.errorf( + "%s/KeyGen/%d/%d: ek: expected %s actual %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.ek, + x, + ) + num_failed += 1 + continue + } + + dk_bytes := make([]byte, mlkem.DECAPSULATION_KEY_EXPANDED_SIZES[params]) + mlkem.decapsulation_key_expanded_bytes(&dk, dk_bytes) + + ok = common.hexbytes_compare(test_vector.dk, dk_bytes) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(dk_bytes)) + log.errorf( + "%s/KeyGen/%d/%d: dk: expected %s actual %s", + tg_id, + params_str, + test_vector.tc_id, + test_vector.dk, + x, + ) + num_failed += 1 + continue + } + + seed_bytes: [mlkem.DECAPSULATION_KEY_SEED_SIZE]byte + mlkem.decapsulation_key_bytes(&dk, seed_bytes[:]) + + ok = common.hexbytes_compare(test_vector.seed, seed_bytes[:]) + if !result_check(test_vector.result, ok) { + x := transmute(string)(hex.encode(seed_bytes[:])) + log.errorf( + "%s/KeyGen/%d/%d: seed: expected %s actual %s", + tg_id, + params_str, + test_vector.tc_id, + test_vector.seed, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/KeyGen: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Encaps starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + ek: mlkem.Encapsulation_Key + ok := mlkem.encapsulation_key_set_bytes( + &ek, + params, + common.hexbytes_decode(test_vector.ek), + ) + + // The current corpus can only fail if the encapsulation key + // is malformed in some way. + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Encaps/%d/%d: unexpected set encapsulation key from bytes: %s (%v != %v)", + params_str, + tg_id, + test_vector.tc_id, + test_vector.ek, + test_vector.result, + ok, + ) + num_failed += 1 + continue + } + if !ok { + num_passed += 1 + continue + } + + shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params]) + + _mlkem.kem_encaps_internal( + shared_secret[:], + ciphertext, + &ek, + common.hexbytes_decode(test_vector.m), + ) + + ok = common.hexbytes_compare(test_vector.c, ciphertext) + if !ok { + x := transmute(string)(hex.encode(ciphertext)) + log.errorf( + "%s/Encaps/%d/%d: ciphertext: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.c, + x, + ) + num_failed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.k, shared_secret[:]) + if !ok { + x := transmute(string)(hex.encode(shared_secret[:])) + log.errorf( + "%s/Encaps/%d/%d: shared_secret: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.k, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Encaps: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { + params_str := test_vectors.test_groups[0].parameter_set + params := parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Decaps starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + for &test_vector in test_group.tests { + num_ran += 1 + + // We do not have an API for decaps with raw seed. + seed := common.hexbytes_decode(test_vector.seed) + switch len(seed) { + case mlkem.DECAPSULATION_KEY_SEED_SIZE: + case: + if testing.expectf( + t, + result_is_invalid(test_vector.result), + "%s/Decaps/%d/%d: test vector expects success with invalid seed", + params_str, + tg_id, + test_vector.tc_id, + ) { + num_passed += 1 + } else { + num_failed += 1 + } + continue + } + + dk: mlkem.Decapsulation_Key + if !testing.expectf( + t, + mlkem.decapsulation_key_set_bytes(&dk, params, seed), + "%s/Decaps/%d/%d: failed to set decapsulation key from seed", + params_str, + tg_id, + test_vector.tc_id, + test_vector.seed, + ) { + num_failed *= 1 + continue + } + + shared_secret: [mlkem.SHARED_SECRET_SIZE]byte + + ok := mlkem.decaps( + &dk, + common.hexbytes_decode(test_vector.c), + shared_secret[:], + ) + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Decaps/%d/%d: unexpected decapsulation failure", + params_str, + tg_id, + test_vector.tc_id, + ) + num_failed += 1 + continue + } + if !ok { + num_passed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.k, shared_secret[:]) + if !ok { + x := transmute(string)(hex.encode(shared_secret[:])) + log.errorf( + "%s/Decaps/%d/%d: shared_secret: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.k, + x, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Decaps: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +@(require_results, private="file") +parameter_set_to_params :: proc(s: string) -> mlkem.Parameters { + switch s { + case "ML-KEM-512": + return .ML_KEM_512 + case "ML-KEM-768": + return .ML_KEM_768 + case "ML-KEM-1024": + return .ML_KEM_1024 + case: + return .Invalid + } +} diff --git a/tests/core/crypto/wycheproof/schemas.odin b/tests/core/crypto/wycheproof/schemas.odin index 645f0f085..d801207a0 100644 --- a/tests/core/crypto/wycheproof/schemas.odin +++ b/tests/core/crypto/wycheproof/schemas.odin @@ -64,6 +64,11 @@ Test_Vectors_Note :: struct { links: []string `json:"links"`, } +Test_Group_Source :: struct { + name: string `json:"name"`, + version: string `json:"version"`, +} + Aead_Test_Group :: struct { iv_size: int `json:"ivSize"`, key_size: int `json:"keySize"`, @@ -198,3 +203,23 @@ Pbkdf_Test_Vector :: struct { result: Result `json:"result"`, flags: []string `json:"flags"`, } + +Kem_Test_Group :: struct { + type: string `json:"type"`, + source: Test_Group_Source `json:"source"`, + parameter_set: string `json:"parameterSet"`, + tests: []Kem_Test_Vector `json:"tests"`, +} + +Kem_Test_Vector :: struct { + tc_id: int `json:"tcId"`, + flags: []string `json:"flags"`, + comment: string `json:"comment"`, + seed: common.Hex_Bytes `json:"seed"`, + m: common.Hex_Bytes `json:"m"`, + ek: common.Hex_Bytes `json:"ek"`, + dk: common.Hex_Bytes `json:"dk"`, + c: common.Hex_Bytes `json:"c"`, + k: common.Hex_Bytes `json:"K"`, + result: Result `json:"result"`, +}