diff --git a/core/crypto/_mldsa/constants.odin b/core/crypto/_mldsa/constants.odin new file mode 100644 index 000000000..6a9b6e096 --- /dev/null +++ b/core/crypto/_mldsa/constants.odin @@ -0,0 +1,72 @@ +#+private +package _mldsa + +CRHBYTES :: 64 +TRBYTES :: 64 + +N :: 256 +Q :: 8380417 +D :: 13 + +K_MAX :: 8 +L_MAX :: 7 + +POLYZ_PACKEDBYTES_MAX :: 640 + +POLYT1_PACKEDBYTES :: 320 +POLYT0_PACKEDBYTES :: 416 + +POLYVECT1_PACKEDBYTES_MAX :: K_MAX * POLYT1_PACKEDBYTES +POLYW1_PACKEDBYTES_MAX :: 192 + +CTILDBYTES_MAX :: 64 + +@(require_results) +polyeta_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int { + POLYETA_PACKEDBYTES_2 :: 96 + POLYETA_PACKEDBYTES_4 :: 128 + + switch params.eta { + case 2: + return POLYETA_PACKEDBYTES_2 + case 4: + return POLYETA_PACKEDBYTES_4 + case: + unreachable() + } +} + +@(require_results) +polyz_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int { + POLYZ_PACKEDBYTES_GAMMA1_17 :: 576 + POLYZ_PACKEDBYTES_GAMMA1_19 :: 640 + + switch params.gamma1 { + case 1 << 17: + return POLYZ_PACKEDBYTES_GAMMA1_17 + case 1 << 19: + return POLYZ_PACKEDBYTES_GAMMA1_19 + case: + unreachable() + } +} + +@(require_results) +polyw1_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int { + POLYW1_PACKEDBYTES_GAMMA2_95232 :: 192 + POLYW1_PACKEDBYTES_GAMMA2_261888 :: 128 + + switch params.gamma2 { + case (Q-1)/88: + return POLYW1_PACKEDBYTES_GAMMA2_95232 + case (Q-1)/32: + return POLYW1_PACKEDBYTES_GAMMA2_261888 + case: + unreachable() + } +} + +@(require_results) +polyvech_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int { + return params.omega + params.k +} diff --git a/core/crypto/_mldsa/dsa_internal.odin b/core/crypto/_mldsa/dsa_internal.odin new file mode 100644 index 000000000..bfb6b7508 --- /dev/null +++ b/core/crypto/_mldsa/dsa_internal.odin @@ -0,0 +1,394 @@ +package _mldsa + +import "core:crypto" +import "core:crypto/shake" + +// This implementation is derived from the PQ-CRYSTALS reference +// implementation [[ https://github.com/pq-crystals/dilithium ]], +// primarily for licensing reasons. Arguably mldsa-native is +// a more "up to date" codebase, but the changes to the +// ref code is minor and they slapped an attribution-required +// license on something that was originally CC-0/Apache 2.0. + +SEEDBYTES :: 32 +RNDBYTES :: 32 +CTXBYTES_MAX :: 255 + +Params :: struct { + k: int, + l: int, + eta: i32, + tau: int, + beta: i32, + gamma1: i32, + gamma2: i32, + omega: int, + ctild_bytes: int, +} + +@(rodata) +Params_44 := Params{ + k = 4, + l = 4, + eta = 2, + tau = 39, + beta = 78, + gamma1 = 1 << 17, + gamma2 = (Q-1)/88, + omega = 80, + ctild_bytes = 32, +} + +@(rodata) +Params_65 := Params{ + k = 6, + l = 5, + eta = 4, + tau = 49, + beta = 196, + gamma1 = 1 << 19, + gamma2 = (Q-1)/32, + omega = 55, + ctild_bytes = 48, +} + +@(rodata) +Params_87 := Params{ + k = 8, + l = 7, + eta = 2, + tau = 60, + beta = 120, + gamma1 = 1 << 19, + gamma2 = (Q-1)/32, + omega = 75, + ctild_bytes = 64, +} + +Private_Key :: struct { + params: ^Params, + + rho: [SEEDBYTES]byte, + tr: [TRBYTES]byte, + key: [SEEDBYTES]byte, + t0: Polyvec_K, + s1: Polyvec_L, + s2: Polyvec_K, + + pub_key: Public_Key, + seed: [SEEDBYTES]byte, +} + +Public_Key :: struct { + params: ^Params, + + t1: Polyvec_K, + rho: [SEEDBYTES]byte, + mu: [TRBYTES]byte, +} + +@(private) +Signature :: struct { + params: ^Params, + + c: [CTILDBYTES_MAX]byte, + z: Polyvec_L, + h: Polyvec_K, +} + +dsa_keygen_internal :: proc( + priv_key: ^Private_Key, + seed: []byte, + params: ^Params, +) { + ensure(len(seed) == SEEDBYTES, "crypto/mldsa: invalid seed") + + pub_key := &priv_key.pub_key + pub_key.params = params + priv_key.params = params + + copy(priv_key.seed[:], seed) + + seedbuf: [2*SEEDBYTES + CRHBYTES]byte = --- + mat_: [K_MAX]Polyvec_L = --- + defer crypto.zero_explicit(&seedbuf, size_of(seedbuf)) + defer crypto.zero_explicit(&mat_, size_of(mat_)) + + // Expand randomness for rho, rhoprime and key + copy(seedbuf[:], seed) + seedbuf[SEEDBYTES] = byte(params.k) + seedbuf[SEEDBYTES+1] = byte(params.l) + shake256(seedbuf[:], seedbuf[:SEEDBYTES+2]) + copy(priv_key.rho[:], seedbuf[:SEEDBYTES]) + rhoprime := seedbuf[SEEDBYTES:SEEDBYTES+CRHBYTES] + copy(priv_key.key[:], seedbuf[SEEDBYTES+CRHBYTES:]) + + // Expand matrix + mat := mat_[:params.k] + polyvec_matrix_expand(mat, priv_key.rho[:], params) + + // Sample short vectors s1 and s2 + polyvec_l_uniform_eta(&priv_key.s1, rhoprime, 0, params) + polyvec_k_uniform_eta(&priv_key.s2, rhoprime, u16(params.l), params) + + // Matrix-vector multiplication + s1hat: Polyvec_L = --- + defer crypto.zero_explicit(&s1hat, size_of(Polyvec_L)) + polyvec_copy(&s1hat, &priv_key.s1, params) + polyvec_l_ntt(&s1hat, params) + polyvec_matrix_pointwise_montgomery(&pub_key.t1, mat, &s1hat, params) + polyvec_k_reduce(&pub_key.t1, params) + polyvec_k_invntt_tomont(&pub_key.t1, params) + + // Add error vector s2 + polyvec_k_add(&pub_key.t1, &pub_key.t1, &priv_key.s2, params) + + // Extract t1 and write public key + pk_bytes_: [SEEDBYTES+POLYVECT1_PACKEDBYTES_MAX]byte = --- + pk_bytes := pk_bytes_[:public_key_size(params)] + polyvec_k_caddq(&pub_key.t1, params) + polyvec_k_power2round(&pub_key.t1, &priv_key.t0, &pub_key.t1, params) + copy(pub_key.rho[:], priv_key.rho[:]) + _ = pack_pk(pk_bytes, pub_key) + + // Compute H(rho, t1) and write secret key + shake256(pub_key.mu[:], pk_bytes) + copy(priv_key.tr[:], pub_key.mu[:]) +} + +dsa_sign_internal :: proc( + sig_bytes: []byte, + m: []byte, + ctx: []byte, + rnd: []byte, + priv_key: ^Private_Key, + external_mu: []byte = nil +) -> bool { + params := priv_key.params + switch params { + case &Params_44, &Params_65, &Params_87: + case: + return false + } + if len(sig_bytes) != signature_size(params) { + return false + } + ensure(len(ctx) <= CTXBYTES_MAX, "crypto/mlkem: invalid contxt size") + ensure(len(rnd) == RNDBYTES, "crypto/mlkem: invalid rnd size") + + mu, rhoprime: [CRHBYTES]byte = ---, --- + mat_: [K_MAX]Polyvec_L + w1_bytes_: [SEEDBYTES+POLYVECT1_PACKEDBYTES_MAX]byte = --- + s1, y: Polyvec_L = ---, --- + t0, s2, w1, w0: Polyvec_K = ---, ---, ---, --- + cp: Poly + + polyvec_copy(&s1, &priv_key.s1, params) + polyvec_copy(&s2, &priv_key.s2, params) + polyvec_copy(&t0, &priv_key.t0, params) + + defer crypto.zero_explicit(&mu, size_of(mu)) + defer crypto.zero_explicit(&rhoprime, size_of(rhoprime)) + defer crypto.zero_explicit(&mat_, size_of(mat_)) + defer crypto.zero_explicit(&w1_bytes_, size_of(w1_bytes_)) + defer polyvec_clear([]^Polyvec_L{&s1, &y}) + defer polyvec_clear([]^Polyvec_K{&t0, &s2, &w1, &w0}) + defer crypto.zero_explicit(&cp, size_of(cp)) + + sig: Signature = --- + sig.params = params + h := &sig.h + z := &sig.z + c := sig.c[:params.ctild_bytes] + + w1_bytes := w1_bytes_[:params.k*polyw1_packedbytes(params)] + + // Compute mu = CRH(tr, pre, msg) + if len(external_mu) == 0 { + // The FIPS publication handles the shake prefix + // in the public sign operation, but doing it + // here makes more sense. + ctx_buf: [2]byte + shake_ctx: shake.Context = --- + defer shake.reset(&shake_ctx) + + ctx_len := len(ctx) + + shake.init_256(&shake_ctx) + shake.write(&shake_ctx, priv_key.tr[:]) + if ctx_len > 0 { + ctx_buf[1] = byte(ctx_len) + } + shake.write(&shake_ctx, ctx_buf[:]) + if ctx_len > 0 { + shake.write(&shake_ctx, ctx) + } + shake.write(&shake_ctx, m) + shake.read(&shake_ctx, mu[:]) + } else { + ensure(len(external_mu) == CRHBYTES, "crypto/mlkem: invalid external mu") + copy(mu[:], external_mu) + } + + // Compute rhoprime = CRH(key, rnd, mu) + shake256(rhoprime[:], priv_key.key[:], rnd, mu[:]) + + // Expand matrix and transform vectors + mat := mat_[:params.k] + polyvec_matrix_expand(mat, priv_key.rho[:], params) + polyvec_l_ntt(&s1, params) + polyvec_k_ntt(&s2, params) + polyvec_k_ntt(&t0, params) + + // Rejection-sampling loop + iv: u32 // ref uses u16, but ML-DSA-87 will reuse the IV at p = ~2^{-23400} + for { + // Sample intermediate vector y + polyvec_l_uniform_gamma1(&y, rhoprime[:], iv, params) + iv += 1 + + // Matrix-vector multiplication + polyvec_copy(z, &y, params) + polyvec_l_ntt(z, params) + polyvec_matrix_pointwise_montgomery(&w1, mat, z, params) + polyvec_k_reduce(&w1, params) + polyvec_k_invntt_tomont(&w1, params) + + // Decompose w and call the random oracle + polyvec_k_caddq(&w1, params) + polyvec_k_decompose(&w1, &w0, &w1, params) + polyvec_k_pack_w1(w1_bytes, &w1, params) + + shake256(c, mu[:], w1_bytes) + poly_challenge(&cp, c, params) + poly_ntt(&cp) + + // Compute z, reject if it reveals secret + polyvec_l_pointwise_poly_montgomery(z, &cp, &s1, params) + polyvec_l_invntt_tomont(z, params) + polyvec_l_add(z, z, &y, params) + polyvec_l_reduce(z, params) + if polyvec_l_chknorm(z, params.gamma1 - params.beta, params) { + continue + } + + // Check that subtracting cs2 does not change high bits of w + // and low bits do not reveal secret information + polyvec_k_pointwise_poly_montgomery(h, &cp, &s2, params) + polyvec_k_invntt_tomont(h, params) + polyvec_k_sub(&w0, &w0, h, params) + polyvec_k_reduce(&w0, params) + if polyvec_k_chknorm(&w0, params.gamma2 - params.beta, params) { + continue + } + + // Compute hints for w1 + polyvec_k_pointwise_poly_montgomery(h, &cp, &t0, params) + polyvec_k_invntt_tomont(h, params) + polyvec_k_reduce(h, params) + if polyvec_k_chknorm(h, params.gamma2, params) { + continue + } + + polyvec_k_add(&w0, &w0, h, params) + n := polyvec_k_make_hint(h, &w0, &w1, params) + if n <= uint(params.omega) { + break + } + } + + // Write signature + return pack_sig(sig_bytes, &sig) +} + +dsa_verify_internal :: proc( + sig_bytes: []byte, + m: []byte, + ctx: []byte, + pub_key: ^Public_Key, +) -> bool { + ensure(len(ctx) <= CTXBYTES_MAX, "crypto/mlkem: invalid contxt size") + + params := pub_key.params + switch params { + case &Params_44, &Params_65, &Params_87: + case: + return false + } + + sig: Signature = --- + if !unpack_sig(&sig, sig_bytes, params) { + return false + } + if polyvec_l_chknorm(&sig.z, params.gamma1 - params.beta, params) { + return false + } + c := sig.c[:params.ctild_bytes] + z := &sig.z + h := &sig.h + + t1: Polyvec_K = --- + polyvec_copy(&t1, &pub_key.t1, params) + rho := pub_key.rho[:] + + // Compute CRH(H(rho, t1), pre, msg) + mu: [CRHBYTES]byte + { + // The FIPS publication handles the shake prefix + // in the public sign operation, but doing it + // here makes more sense. + ctx_buf: [2]byte + shake_ctx: shake.Context = --- + defer shake.reset(&shake_ctx) + + ctx_len := len(ctx) + + shake.init_256(&shake_ctx) + shake.write(&shake_ctx, pub_key.mu[:]) + if ctx_len > 0 { + ctx_buf[1] = byte(ctx_len) + } + shake.write(&shake_ctx, ctx_buf[:]) + if ctx_len > 0 { + shake.write(&shake_ctx, ctx) + } + shake.write(&shake_ctx, m) + shake.read(&shake_ctx, mu[:]) + } + + // Matrix-vector multiplication; compute Az - c2^dt1 + mat_: [K_MAX]Polyvec_L + w1: Polyvec_K = --- + cp: Poly = --- + mat := mat_[:params.l] + + poly_challenge(&cp, c, params) + polyvec_matrix_expand(mat, rho, params) + + polyvec_l_ntt(z, params) + polyvec_matrix_pointwise_montgomery(&w1, mat, z, params) + + poly_ntt(&cp) + polyvec_k_shiftl(&t1, params) + polyvec_k_ntt(&t1, params) + polyvec_k_pointwise_poly_montgomery(&t1, &cp, &t1, params) + + polyvec_k_sub(&w1, &w1, &t1, params) + polyvec_k_reduce(&w1, params) + polyvec_k_invntt_tomont(&w1, params) + + // Reconstruct w1 + buf_: [K_MAX*POLYW1_PACKEDBYTES_MAX]byte = --- + buf := buf_[:params.k*polyw1_packedbytes(params)] + polyvec_k_caddq(&w1, params) + polyvec_k_use_hint(&w1, &w1, h, params) + polyvec_k_pack_w1(buf, &w1, params) + + // Call random oracle and verify challenge + c2_: [CTILDBYTES_MAX]byte + c2 := c2_[:params.ctild_bytes] + shake256(c2, mu[:], buf) + + // Note/perf: Can be vartime + return crypto.compare_constant_time(c, c2) == 1 +} diff --git a/core/crypto/_mldsa/ntt.odin b/core/crypto/_mldsa/ntt.odin new file mode 100644 index 000000000..c36f88327 --- /dev/null +++ b/core/crypto/_mldsa/ntt.odin @@ -0,0 +1,75 @@ +#+private +package _mldsa + +@(rodata) +ZETAS := [N]i32 { + 0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468, + 1826347, 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103, + 2725464, 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549, + -2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005, + 2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439, + -3861115, -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299, + -1699267, -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596, + 811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779, + -3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221, + -1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922, + 3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047, + -671102, -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430, + -3343383, 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618, + -3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856, + 189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330, + 1285669, -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961, + 2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462, + 266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378, + 900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500, + -655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838, + 342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044, + 2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974, + -3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970, + -1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642, + -1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031, + -542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993, + -2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385, + -3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107, + -3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078, + -426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893, + -2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687, + -554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782, +} + +ntt :: proc "contextless" (a: ^[N]i32) #no_bounds_check { + j, k := 0, 1 + for l := 128; l > 0; l >>= 1 { + for start := 0; start < N; start = j + l { + zeta := ZETAS[k] + k += 1 + for j = start; j < start + l; j += 1 { + t := montgomery_reduce(i64(zeta) * i64(a[j + l])) + a[j + l] = a[j] - t + a[j] = a[j] + t + } + } + } +} + +invntt_tomont :: proc "contextless" (a: ^[N]i32) #no_bounds_check { + F :: 41978 // mont^2/256 + + j, k := 0, 255 + for l := 1; l < N; l <<= 1 { + for start := 0; start < N; start = j + l { + zeta := -ZETAS[k] + k -= 1 + for j = start; j < start + l; j += 1 { + t := a[j] + a[j] = t + a[j + l] + a[j + l] = t - a[j + l] + a[j + l] = montgomery_reduce(i64(zeta) * i64(a[j + l])) + } + } + } + + for i in 0.. bool { + if len(pk_bytes) != public_key_size(pub_key.params) { + return false + } + + seed_bytes, t1_bytes := pk_bytes[:SEEDBYTES], pk_bytes[SEEDBYTES:] + + copy(seed_bytes, pub_key.rho[:]) + + for i in 0.. bool { + if len(pk_bytes) != public_key_size(params) { + return false + } + + seed_bytes, t1_bytes := pk_bytes[:SEEDBYTES], pk_bytes[SEEDBYTES:] + + pub_key.params = params + + copy(pub_key.rho[:], seed_bytes) + + for i in 0.. bool { + params := priv_key.params + if len(sk_bytes) != private_key_size(params) { + return false + } + + sk_bytes := sk_bytes + polyeta_len := polyeta_packedbytes(params) + + copy(sk_bytes, priv_key.rho[:]) + sk_bytes = sk_bytes[SEEDBYTES:] + + copy(sk_bytes, priv_key.key[:]) + sk_bytes = sk_bytes[SEEDBYTES:] + + copy(sk_bytes, priv_key.tr[:]) + sk_bytes = sk_bytes[TRBYTES:] + + for i in 0.. bool { + if len(sig_bytes) != signature_size(sig.params) { + return false + } + + sig_bytes := sig_bytes + polyz_len := polyz_packedbytes(sig.params) + + copy(sig_bytes, sig.c[:sig.params.ctild_bytes]) + sig_bytes = sig_bytes[sig.params.ctild_bytes:] + + for i in 0.. bool { + if len(sig_bytes) != signature_size(params) { + return false + } + + intrinsics.mem_zero(sig, size_of(Signature)) + + sig_bytes := sig_bytes + polyz_len := polyz_packedbytes(params) + omega := params.omega + + copy(sig.c[:], sig_bytes[:params.ctild_bytes]) + sig_bytes = sig_bytes[params.ctild_bytes:] + + for i in 0.. byte(omega) { + return false + } + + for j := k; j < int(sig_bytes[omega + i]); j += 1 { + // Coefficients are ordered for strong unforgeability + if j > k && sig_bytes[j] <= sig_bytes[j-1] { + return false + } + sig.h.vec[i].coeffs[sig_bytes[j]] = 1 + } + + k = int(sig_bytes[omega + i]) + } + + // Extra indices are zero for strong unforgeability + for j := k; j < omega; j += 1 { + if sig_bytes[j] != 0 { + return false + } + } + + sig.params = params + + return true +} + +@(private,require_results) +public_key_size :: #force_inline proc "contextless" (params: ^Params) -> int { + return SEEDBYTES + params.k * POLYT1_PACKEDBYTES +} + +@(private,require_results) +private_key_size :: #force_inline proc "contextless" (params: ^Params) -> int { + return 2*SEEDBYTES + TRBYTES + (params.l + params.k) * polyeta_packedbytes(params) + params.k * POLYT0_PACKEDBYTES +} + +@(private,require_results) +signature_size :: #force_inline proc "contextless" (params: ^Params) -> int { + return params.ctild_bytes + params.l * polyz_packedbytes(params) + polyvech_packedbytes(params) +} diff --git a/core/crypto/_mldsa/poly.odin b/core/crypto/_mldsa/poly.odin new file mode 100644 index 000000000..985a807d2 --- /dev/null +++ b/core/crypto/_mldsa/poly.odin @@ -0,0 +1,564 @@ +#+private +package _mldsa + +import "base:intrinsics" +import "core:crypto" +import "core:crypto/shake" + +Poly :: struct { + coeffs: [N]i32, +} + +poly_reduce :: proc "contextless" (a: ^Poly) { + for v, i in a.coeffs { + a.coeffs[i] = reduce32(v) + } +} + +poly_caddq :: proc "contextless" (a: ^Poly) { + for v, i in a.coeffs { + a.coeffs[i] = caddq(v) + } +} + +poly_add :: proc "contextless" (c, a, b: ^Poly) #no_bounds_check { + for i in 0.. uint #no_bounds_check { + s: uint + + for i in 0.. bool #no_bounds_check { + // It is ok to leak which coefficient violates the bound since + // the probability for each coefficient is independent of secret + // data but we must not leak the sign of the centralized + // representative. + for i in 0..> 31 + t = a.coeffs[i] - (t & 2 * a.coeffs[i]) + + if t >= bound { + return true + } + } + + return false +} + +unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check { + r := u32(b[0]) + r |= u32(b[1]) << 8 + r |= u32(b[2]) << 16 + return r +} + +rej_uniform :: proc "contextless" (a: []i32, buf: []byte) -> int #no_bounds_check { + ctr, pos: int + + a_len, b_len := len(a), len(buf) + for ctr < a_len && pos + 3 <= b_len { + t := unchecked_get_u24le(buf[pos:]) + t &= 0x7FFFFF + pos += 3 + + if t < Q { + a[ctr] = i32(t) + ctr += 1 + } + } + + return ctr +} + +poly_uniform :: proc(a: ^Poly, seed: []byte, iv: u16) #no_bounds_check { + // Note/yawning: The dilithium reference code does something + // inexplicably more complicated, but this is identical in + // behavior, and simpler. + #assert(STREAM128_BLOCKBYTES % 3 == 0) + POLY_UNIFORM_NBLOCKS :: ((768 + STREAM128_BLOCKBYTES - 1)/STREAM128_BLOCKBYTES) + + buf: [POLY_UNIFORM_NBLOCKS*STREAM128_BLOCKBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + ctx: shake.Context = --- + defer shake.reset(&ctx) + stream128_init(&ctx, seed, iv) + + shake.read(&ctx, buf[:]) + ctr := rej_uniform(a.coeffs[:], buf[:]) + + b := buf[:STREAM128_BLOCKBYTES] + for ctr < N { + shake.read(&ctx, b) + ctr += rej_uniform(a.coeffs[ctr:], b) + } +} + +rej_eta :: proc "contextless" (a: []i32, buf: []byte, params: ^Params) -> int { + ctr, pos: int + a_len, b_len := len(a), len(buf) + switch params.eta { + case 2: + for ctr < a_len && pos < b_len { + t0 := u32(buf[pos] & 0x0F) + t1 := u32(buf[pos] >> 4) + pos += 1 + + if t0 < 15 { + t0 = t0 - (205 * t0 >> 10) * 5 + a[ctr] = i32(2 - t0) + ctr += 1 + } + if t1 < 15 && ctr < a_len { + t1 = t1 - (205 * t1 >> 10) * 5 + a[ctr] = i32(2 - t1) + ctr += 1 + } + } + case 4: + for ctr < a_len && pos < b_len { + t0 := u32(buf[pos] & 0x0F) + t1 := u32(buf[pos] >> 4) + pos += 1 + + if t0 < 9 { + a[ctr] = i32(4 - t0) + ctr += 1 + } + if t1 < 9 && ctr < a_len { + a[ctr] = i32(4 - t1) + ctr += 1 + } + } + case: + unreachable() + } + + return ctr +} + +poly_uniform_eta :: proc(a: ^Poly, seed: []byte, iv: u16, params: ^Params) { + POLY_UNIFORM_ETA2_NBLOCKS :: ((136 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) + POLY_UNIFORM_ETA4_NBLOCKS :: ((227 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) + + buf_: [POLY_UNIFORM_ETA4_NBLOCKS*STREAM256_BLOCKBYTES]byte = --- + buf: []byte + switch params.eta { + case 2: + buf = buf_[:POLY_UNIFORM_ETA2_NBLOCKS*STREAM256_BLOCKBYTES] + case 4: + buf = buf_[:POLY_UNIFORM_ETA4_NBLOCKS*STREAM256_BLOCKBYTES] + case: + unreachable() + } + defer crypto.zero_explicit(&buf_, size_of(buf_)) + + ctx: shake.Context = --- + defer shake.reset(&ctx) + stream256_init(&ctx, seed, iv) + + shake.read(&ctx, buf) + ctr := rej_eta(a.coeffs[:], buf, params) + + b := buf[:STREAM256_BLOCKBYTES] + for ctr < N { + shake.read(&ctx, b) + ctr += rej_eta(a.coeffs[ctr:], b, params) + } +} + +poly_uniform_gamma1 :: proc(a: ^Poly, seed: []byte, iv: u16, params: ^Params) { + POLY_UNIFORM_GAMMA1_NBLOCKS_MAX :: ((POLYZ_PACKEDBYTES_MAX + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES) + + n_blocks := (polyz_packedbytes(params) + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES + + buf_: [POLY_UNIFORM_GAMMA1_NBLOCKS_MAX*STREAM256_BLOCKBYTES]byte = --- + buf := buf_[:n_blocks*STREAM256_BLOCKBYTES] + defer crypto.zero_explicit(&buf_, size_of(buf_)) + + ctx: shake.Context = --- + defer shake.reset(&ctx) + stream256_init(&ctx, seed, iv) + + shake.read(&ctx, buf) + polyz_unpack(a, buf, params) +} + +poly_challenge :: proc(c: ^Poly, seed: []byte, params: ^Params) #no_bounds_check { + buf: [STREAM256_BLOCKBYTES]byte = --- + defer crypto.zero_explicit(&buf, size_of(buf)) + + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + shake.write(&ctx, seed) + shake.read(&ctx, buf[:]) + + signs: u64 + for i in uint(0)..<8 { + signs |= u64(buf[i]) << (8*i) + } + pos := 8 + + b: int + intrinsics.mem_zero(c, size_of(Poly)) + for i := N - params.tau; i < N; i+= 1 { + for { + if pos >= STREAM256_BLOCKBYTES { + shake.read(&ctx, buf[:]) + pos = 0 + } + + b = int(buf[pos]) + pos += 1 + if b <= i { + break + } + } + + c.coeffs[i] = c.coeffs[b] + c.coeffs[b] = i32(1 - 2 * (signs & 1)) + signs >>= 1 + } +} + +polyeta_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check { + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + eta := params.eta + switch eta { + case 2: + for i in 0..> 0) | (t[1] << 3) | (t[2] << 6) + r[3*i+1] = (t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7) + r[3*i+2] = (t[5] >> 1) | (t[6] << 2) | (t[7] << 5) + } + case 4: + for i in 0..> 0) & 7) + r.coeffs[8*i+1] = i32((a[3*i+0] >> 3) & 7) + r.coeffs[8*i+2] = i32(((a[3*i+0] >> 6) | (a[3*i+1] << 2)) & 7) + r.coeffs[8*i+3] = i32((a[3*i+1] >> 1) & 7) + r.coeffs[8*i+4] = i32((a[3*i+1] >> 4) & 7) + r.coeffs[8*i+5] = i32(((a[3*i+1] >> 7) | (a[3*i+2] << 1)) & 7) + r.coeffs[8*i+6] = i32((a[3*i+2] >> 2) & 7) + r.coeffs[8*i+7] = i32((a[3*i+2] >> 5) & 7) + + r.coeffs[8*i+0] = eta - r.coeffs[8*i+0] + r.coeffs[8*i+1] = eta - r.coeffs[8*i+1] + r.coeffs[8*i+2] = eta - r.coeffs[8*i+2] + r.coeffs[8*i+3] = eta - r.coeffs[8*i+3] + r.coeffs[8*i+4] = eta - r.coeffs[8*i+4] + r.coeffs[8*i+5] = eta - r.coeffs[8*i+5] + r.coeffs[8*i+6] = eta - r.coeffs[8*i+6] + r.coeffs[8*i+7] = eta - r.coeffs[8*i+7] + } + case 4: + for i in 0..> 4) + r.coeffs[2*i+0] = eta - r.coeffs[2*i+0] + r.coeffs[2*i+1] = eta - r.coeffs[2*i+1] + } + case: + unreachable() + } +} + +polyt1_pack :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + for i in 0..> 0) + r[5*i+1] = byte((a.coeffs[4*i+0] >> 8) | (a.coeffs[4*i+1] << 2)) + r[5*i+2] = byte((a.coeffs[4*i+1] >> 6) | (a.coeffs[4*i+2] << 4)) + r[5*i+3] = byte((a.coeffs[4*i+2] >> 4) | (a.coeffs[4*i+3] << 6)) + r[5*i+4] = byte(a.coeffs[4*i+3] >> 2) + } +} + +polyt1_unpack :: proc "contextless" (r: ^Poly, a: []byte) #no_bounds_check { + for i in 0..> 0) | (u32(a[5*i+1]) << 8)) & 0x3FF) + r.coeffs[4*i+1] = i32((u32(a[5*i+1] >> 2) | (u32(a[5*i+2]) << 6)) & 0x3FF) + r.coeffs[4*i+2] = i32((u32(a[5*i+2] >> 4) | (u32(a[5*i+3]) << 4)) & 0x3FF) + r.coeffs[4*i+3] = i32((u32(a[5*i+3] >> 6) | (u32(a[5*i+4]) << 2)) & 0x3FF) + } +} + +polyt0_pack :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check { + t: [8]byte = --- + defer crypto.zero_explicit(&t, size_of(t)) + + for i in 0..> 8 + r[13*i+ 1] |= t[1] << 5 + r[13*i+ 2] = t[1] >> 3 + r[13*i+ 3] = t[1] >> 11 + r[13*i+ 3] |= t[2] << 2 + r[13*i+ 4] = t[2] >> 6 + r[13*i+ 4] |= t[3] << 7 + r[13*i+ 5] = t[3] >> 1 + r[13*i+ 6] = t[3] >> 9 + r[13*i+ 6] |= t[4] << 4 + r[13*i+ 7] = t[4] >> 4 + r[13*i+ 8] = t[4] >> 12 + r[13*i+ 8] |= t[5] << 1 + r[13*i+ 9] = t[5] >> 7 + r[13*i+ 9] |= t[6] << 6 + r[13*i+10] = t[6] >> 2 + r[13*i+11] = t[6] >> 10 + r[13*i+11] |= t[7] << 3 + r[13*i+12] = t[7] >> 5 + } +} + +polyt0_unpack :: proc "contextless" (r: ^Poly, a: []byte) #no_bounds_check { + for i in 0..> 5) + r.coeffs[8*i+1] |= i32(u32(a[13*i+2]) << 3) + r.coeffs[8*i+1] |= i32(u32(a[13*i+3]) << 11) + r.coeffs[8*i+1] &= 0x1FFF + + r.coeffs[8*i+2] = i32(a[13*i+3] >> 2) + r.coeffs[8*i+2] |= i32(u32(a[13*i+4]) << 6) + r.coeffs[8*i+2] &= 0x1FFF + + r.coeffs[8*i+3] = i32(a[13*i+4] >> 7) + r.coeffs[8*i+3] |= i32(u32(a[13*i+5]) << 1) + r.coeffs[8*i+3] |= i32(u32(a[13*i+6]) << 9) + r.coeffs[8*i+3] &= 0x1FFF + + r.coeffs[8*i+4] = i32(a[13*i+6] >> 4) + r.coeffs[8*i+4] |= i32(u32(a[13*i+7]) << 4) + r.coeffs[8*i+4] |= i32(u32(a[13*i+8]) << 12) + r.coeffs[8*i+4] &= 0x1FFF + + r.coeffs[8*i+5] = i32(a[13*i+8] >> 1) + r.coeffs[8*i+5] |= i32(u32(a[13*i+9]) << 7) + r.coeffs[8*i+5] &= 0x1FFF + + r.coeffs[8*i+6] = i32(a[13*i+9] >> 6) + r.coeffs[8*i+6] |= i32(u32(a[13*i+10]) << 2) + r.coeffs[8*i+6] |= i32(u32(a[13*i+11]) << 10) + r.coeffs[8*i+6] &= 0x1FFF + + r.coeffs[8*i+7] = i32(a[13*i+11] >> 3) + r.coeffs[8*i+7] |= i32(u32(a[13*i+12]) << 5) + r.coeffs[8*i+7] &= 0x1FFF + + r.coeffs[8*i+0] = (1 << (D-1)) - r.coeffs[8*i+0] + r.coeffs[8*i+1] = (1 << (D-1)) - r.coeffs[8*i+1] + r.coeffs[8*i+2] = (1 << (D-1)) - r.coeffs[8*i+2] + r.coeffs[8*i+3] = (1 << (D-1)) - r.coeffs[8*i+3] + r.coeffs[8*i+4] = (1 << (D-1)) - r.coeffs[8*i+4] + r.coeffs[8*i+5] = (1 << (D-1)) - r.coeffs[8*i+5] + r.coeffs[8*i+6] = (1 << (D-1)) - r.coeffs[8*i+6] + r.coeffs[8*i+7] = (1 << (D-1)) - r.coeffs[8*i+7] + } +} + +polyz_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check { + t: [4]u32 = --- + defer crypto.zero_explicit(&t, size_of(t)) + + gamma1 := params.gamma1 + switch gamma1 { + case 1 << 17: + for i in 0..> 8) + r[9*i+2] = byte(t[0] >> 16) + r[9*i+2] |= byte(t[1] << 2) + r[9*i+3] = byte(t[1] >> 6) + r[9*i+4] = byte(t[1] >> 14) + r[9*i+4] |= byte(t[2] << 4) + r[9*i+5] = byte(t[2] >> 4) + r[9*i+6] = byte(t[2] >> 12) + r[9*i+6] |= byte(t[3] << 6) + r[9*i+7] = byte(t[3] >> 2) + r[9*i+8] = byte(t[3] >> 10) + } + case 1 << 19: + for i in 0..> 8) + r[5*i+2] = byte(t[0] >> 16) + r[5*i+2] |= byte(t[1] << 4) + r[5*i+3] = byte(t[1] >> 4) + r[5*i+4] = byte(t[1] >> 12) + } + case: + unreachable() + } +} + +polyz_unpack :: proc "contextless" (r: ^Poly, a: []byte, params: ^Params) #no_bounds_check { + gamma1 := params.gamma1 + switch gamma1 { + case 1 << 17: + for i in 0..> 2) + r.coeffs[4*i+1] |= i32(u32(a[9*i+3]) << 6) + r.coeffs[4*i+1] |= i32(u32(a[9*i+4]) << 14) + r.coeffs[4*i+1] &= 0x3FFFF + + r.coeffs[4*i+2] = i32(a[9*i+4] >> 4) + r.coeffs[4*i+2] |= i32(u32(a[9*i+5]) << 4) + r.coeffs[4*i+2] |= i32(u32(a[9*i+6]) << 12) + r.coeffs[4*i+2] &= 0x3FFFF + + r.coeffs[4*i+3] = i32(a[9*i+6] >> 6) + r.coeffs[4*i+3] |= i32(u32(a[9*i+7]) << 2) + r.coeffs[4*i+3] |= i32(u32(a[9*i+8]) << 10) + r.coeffs[4*i+3] &= 0x3FFFF + + r.coeffs[4*i+0] = gamma1 - r.coeffs[4*i+0] + r.coeffs[4*i+1] = gamma1 - r.coeffs[4*i+1] + r.coeffs[4*i+2] = gamma1 - r.coeffs[4*i+2] + r.coeffs[4*i+3] = gamma1 - r.coeffs[4*i+3] + } + case 1 << 19: + for i in 0..> 4) + r.coeffs[2*i+1] |= i32(u32(a[5*i+3]) << 4) + r.coeffs[2*i+1] |= i32(u32(a[5*i+4]) << 12) + /* r.coeffs[2*i+1] &= 0xFFFFF */ /* No effect, since we're anyway at 20 bits */ + + r.coeffs[2*i+0] = gamma1 - r.coeffs[2*i+0] + r.coeffs[2*i+1] = gamma1 - r.coeffs[2*i+1] + } + case: + unreachable() + } +} + +polyw1_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check { + switch params.gamma2 { + case (Q-1)/88: + for i in 0..> 2) + r[3*i+1] |= byte(a.coeffs[4*i+2] << 4) + r[3*i+2] = byte(a.coeffs[4*i+2] >> 4) + r[3*i+2] |= byte(a.coeffs[4*i+3] << 2) + } + case (Q-1)/32: + for i in 0.. bool #no_bounds_check { + for i in 0.. bool #no_bounds_check { + for i in 0.. uint #no_bounds_check { + s: uint + + for i in 0.. i32 { + QINV :: 58728449 // q^(-1) mod 2^32 + + t := i32(i64(i32(a)) * QINV) + t = i32((a - i64(t) * Q) >> 32) + return t +} + +@(require_results) +reduce32 :: #force_inline proc "contextless" (a: i32) -> i32 { + t := (a + (1 << 22)) >> 23 + t = a - t * Q + return t +} + +@(require_results) +caddq :: #force_inline proc "contextless" (a: i32) -> i32 { + a := a + a += (a >> 31) & Q + return a +} + +// @(require_results) +// freeze :: #force_inline proc "contextless" (a: i32) -> i32 { +// a := a +// a = reduce32(a) +// a = caddq(a) +// return a +// } diff --git a/core/crypto/_mldsa/rounding.odin b/core/crypto/_mldsa/rounding.odin new file mode 100644 index 000000000..a879d6159 --- /dev/null +++ b/core/crypto/_mldsa/rounding.odin @@ -0,0 +1,56 @@ +#+private +package _mldsa + +power2round :: proc "contextless" (a: i32) -> (i32, i32) { + a1 := (a + (1 << (D-1)) - 1) >> D + a0 := a - (a1 << D) + return a0, a1 +} + +decompose :: proc "contextless" (a: i32, gamma2: i32) -> (i32, i32) { + a1 := (a + 127) >> 7 + switch gamma2 { + case (Q - 1)/32: + a1 = (a1 * 1025 + (1 << 21)) >> 22 + a1 &= 15 + case (Q - 1)/88: + a1 = (a1 * 11275 + (1 << 23)) >> 24 + a1 ~= ((43 - a1) >> 31) & a1 + } + + a0 := a - a1 * 2 * gamma2 + a0 -= (((Q - 1)/2 - a0) >> 31) & Q + return a0, a1 +} + +make_hint :: proc "contextless" (a0, a1: i32, gamma2: i32) -> uint { + if (a0 > gamma2 || a0 < -gamma2 || (a0 == -gamma2 && a1 != 0)) { + return 1 + } + return 0 +} + +use_hint :: proc "contextless" (a: i32, hint: uint, gamma2: i32) -> i32 { + a0, a1 := decompose(a, gamma2) + if hint == 0 { + return a1 + } + + switch gamma2 { + case (Q - 1)/32: + if (a0 > 0) { + return (a1 + 1) & 15 + } else { + return (a1 -1) & 15 + } + case (Q - 1)/88: + if (a0 > 0) { + return (a1 == 43) ? 0 : a1 + 1 + } else { + return (a1 == 0) ? 43 : a1 - 1 + } + } + + unreachable() +} + diff --git a/core/crypto/_mldsa/symmetric_shake.odin b/core/crypto/_mldsa/symmetric_shake.odin new file mode 100644 index 000000000..5c0aa99d6 --- /dev/null +++ b/core/crypto/_mldsa/symmetric_shake.odin @@ -0,0 +1,39 @@ +#+private +package _mldsa + +import "core:crypto/_sha3" +import "core:crypto/shake" + +STREAM128_BLOCKBYTES :: _sha3.RATE_128 +STREAM256_BLOCKBYTES :: _sha3.RATE_256 + +stream128_init :: proc(ctx: ^shake.Context, seed: []byte, iv: u16) { + t: [2]byte = --- + t[0] = byte(iv) + t[1] = byte(iv >> 8) + + shake.init_128(ctx) + shake.write(ctx, seed) + shake.write(ctx, t[:]) +} + +stream256_init :: proc(ctx: ^shake.Context, seed: []byte, iv: u16) { + t: [2]byte = --- + t[0] = byte(iv) + t[1] = byte(iv >> 8) + + shake.init_256(ctx) + shake.write(ctx, seed) + shake.write(ctx, t[:]) +} + +shake256 :: proc(dst: []byte, srcs: ..[]byte) { + ctx: shake.Context = --- + defer shake.reset(&ctx) + + shake.init_256(&ctx) + for src in srcs { + shake.write(&ctx, src) + } + shake.read(&ctx, dst) +} diff --git a/core/crypto/mldsa/api.odin b/core/crypto/mldsa/api.odin new file mode 100644 index 000000000..ce4f9cffd --- /dev/null +++ b/core/crypto/mldsa/api.odin @@ -0,0 +1,290 @@ +package mldsa + +import "core:crypto" +import "core:crypto/_mldsa" + +// Parameters are the supported ML-DSA parameter sets. +Parameters :: enum { + Invalid, + ML_DSA_44, + ML_DSA_65, + ML_DSA_87, +} + +// PRIVATE_KEY_SEED_SIZE is the size of a private key in bytes. +PRIVATE_KEY_SEED_SIZE :: _mldsa.SEEDBYTES // 32-bytes + +// MAX_CTX_SIZE is the maximum size of the signature context +// (domain separation tag) in bytes. +MAX_CTX_SIZE :: _mldsa.CTXBYTES_MAX // 255-bytes + +// PUBLIC_KEY_SIZES are the per-parameter sizes of a public +// key in bytes. +PUBLIC_KEY_SIZES := [Parameters]int { + .Invalid = 0, + .ML_DSA_44 = 1312, + .ML_DSA_65 = 1952, + .ML_DSA_87 = 2592, +} + +// SIGNATURE_SIZES are the per-parameter sizes of a signature +// in byte. +SIGNATURE_SIZES := [Parameters]int { + .Invalid = 0, + .ML_DSA_44 = 2420, + .ML_DSA_65 = 3309, + .ML_DSA_87 = 4627, +} + +@(private="file") +_PARAMS_TO_INTERNAL := [Parameters]^_mldsa.Params { + .Invalid = nil, + .ML_DSA_44 = &_mldsa.Params_44, + .ML_DSA_65 = &_mldsa.Params_65, + .ML_DSA_87 = &_mldsa.Params_87, +} + +// Private_Key is a ML-DSA private key. +Private_Key :: _mldsa.Private_Key + +// Public_Key is a ML-DSA public key. +Public_Key :: _mldsa.Public_Key + +// private_key_generate uses the system entropy source to generate a new +// Private_Key. This will only fail if and only if (⟺) the system entropy +// source is missing or broken. +@(require_results) +private_key_generate :: proc(priv_key: ^Private_Key, params: Parameters) -> bool { + private_key_clear(priv_key) + + if !crypto.HAS_RAND_BYTES { + return false + } + + params_ := _PARAMS_TO_INTERNAL[params] + if params_ == nil { + return false + } + + seed: [PRIVATE_KEY_SEED_SIZE]byte = --- + defer crypto.zero_explicit(&seed, size_of(seed)) + + crypto.rand_bytes(seed[:]) + + _mldsa.dsa_keygen_internal(priv_key, seed[:], params_) + + return true +} + +// private_key_set_bytes decodes a byte-encoded private key in "seed" format, +// and returns true if and only if (⟺) the operation was successful. +@(require_results) +private_key_set_bytes :: proc(priv_key: ^Private_Key, params: Parameters, b: []byte) -> bool { + private_key_clear(priv_key) + + params_ := _PARAMS_TO_INTERNAL[params] + if params_ == nil { + return false + } + if len(b) != PRIVATE_KEY_SEED_SIZE { + return false + } + + _mldsa.dsa_keygen_internal(priv_key, b, params_) + + return true +} + +// private_key_bytes sets dst to byte-encoding of priv_key in the "seed" +// format. +private_key_bytes :: proc(priv_key: ^Private_Key, dst: []byte) { + ensure(priv_key.params != nil, "crypto/mldsa: uninitialized private key") + ensure(len(dst) == PRIVATE_KEY_SEED_SIZE, "crypto/mldsa: invalid destination size") + + copy(dst, priv_key.seed[:]) +} + +// private_key_public_bytes sets dst to the byte-encoding of the public +// key corresponding to priv_key. +private_key_public_bytes :: proc(priv_key: ^Private_Key, dst: []byte) { + public_key_bytes(&priv_key.pub_key, dst) +} + +// private_key_set sets priv_key to src. +private_key_set :: proc(priv_key, src: ^Private_Key) { + if src == nil || internal_to_params(src.params) == .Invalid { + private_key_clear(priv_key) + return + } + + _mldsa.set_sk(priv_key, src) +} + +// private_key_equal returns true if and only if (⟺) the private keys are +// equal, in constant time. +@(require_results) +private_key_equal :: proc(p, q: ^Private_Key) -> bool { + if p.params != q.params { + return false + } + if p.params == nil { + return true + } + + // Just compare the seed that was passed to dsa_keygen_internal, + // since the process is completely deterministic. + return crypto.compare_constant_time(p.seed[:], q.seed[:]) == 1 +} + +// private_key_clear clears priv_key to the uninitialized state. +private_key_clear :: proc "contextless" (priv_key: ^Private_Key) { + _mldsa.clear_sk(priv_key) +} + +// public_key_set_bytes decodes a byte-encoded public key, and returns +// true if and only if (⟺) the operation was successful. +@(require_results) +public_key_set_bytes :: proc(pub_key: ^Public_Key, params: Parameters, b: []byte) -> bool { + params_ := _PARAMS_TO_INTERNAL[params] + if params_ == nil { + return false + } + + return _mldsa.unpack_pk(pub_key, b, params_) +} + +// public_key_set sets pub_key to src. +public_key_set :: proc(pub_key, src: ^Public_Key) { + if src == nil || internal_to_params(src.params) == .Invalid { + public_key_clear(pub_key) + return + } + + _mldsa.set_pk(pub_key, src) +} + +// public_key_set_priv sets pub_key to the public component of priv_key. +public_key_set_priv :: proc(pub_key: ^Public_Key, priv_key: ^Private_Key) { + ensure(priv_key.params != nil, "crypto/mldsa: uninitialized private key") + public_key_set(pub_key, &priv_key.pub_key) +} + +// public_key_bytes sets dst to byte-encoding of pub_key. +public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) { + ensure(pub_key.params != nil, "crypto/mldsa: uninitialized public key") + params := internal_to_params(pub_key.params) + ensure(len(dst) == PUBLIC_KEY_SIZES[params], "crypto/mldsa: invalid destination size") + + _ = _mldsa.pack_pk(dst, pub_key) +} + +// public_key_equal returns true if and only if (⟺) the public keys are equal, +// in constant time. +@(require_results) +public_key_equal :: proc(p, q: ^Public_Key) -> bool { + if p.params != q.params { + return false + } + if p.params == nil { + return true + } + + // Comparing the pre-computed hash should be enough, but pack + // both public keys and do the comparisons. + PUBLIC_KEY_SIZE_MAX :: 2592 + + l := PUBLIC_KEY_SIZES[internal_to_params(p.params)] + p_buf_, q_buf_: [PUBLIC_KEY_SIZE_MAX]byte = ---, --- + p_buf, q_buf := p_buf_[:l], q_buf_[:l] + + _ = _mldsa.pack_pk(p_buf, p) + _ = _mldsa.pack_pk(q_buf, q) + + return crypto.compare_constant_time(p_buf, q_buf) == 1 +} + +// public_key_clear clears pub_key to the uninitialized state. +public_key_clear :: proc "contextless" (pub_key: ^Public_Key) { + _mldsa.clear_pk(pub_key) +} + +// sign writes the signature by priv_key over (ctx, msg) to sig and +// returns true if and only if (⟺) the signing succeeded. +// +// ctx is an optional domain separation tag and may be omitted (nil). +@(require_results) +sign :: proc(priv_key: ^Private_Key, ctx, msg, sig: []byte, deterministic := !crypto.HAS_RAND_BYTES) -> bool { + params := internal_to_params(priv_key.params) + ensure(params != .Invalid, "crypto/mldsa: invalid private key") + ensure(len(sig) == SIGNATURE_SIZES[params], "crypto/mldsa: invalid destination size") + + if !deterministic && !crypto.HAS_RAND_BYTES { + return false + } + if len(ctx) > MAX_CTX_SIZE { + return false + } + + rnd: [_mldsa.RNDBYTES]byte + defer crypto.zero_explicit(&rnd, size_of(rnd)) + + if !deterministic { + crypto.rand_bytes(rnd[:]) + } + + return _mldsa.dsa_sign_internal(sig, msg, ctx, rnd[:], priv_key) +} + +// verify returns true if and only if (⟺) sig is a valid signature by pub_key +// over (ctx, msg). +@(require_results) +verify :: proc(pub_key: ^Public_Key, ctx, msg, sig: []byte) -> bool { + params := internal_to_params(pub_key.params) + ensure(params != .Invalid, "crypto/mldsa: invalid public key") + + if len(sig) != SIGNATURE_SIZES[params] { + return false + } + if len(ctx) > MAX_CTX_SIZE { + return false + } + + return _mldsa.dsa_verify_internal(sig, msg, ctx, pub_key) +} + +// params returns the Parameters used by a Private_Key or Public_Key +// instance. +@(require_results) +params :: proc(k: ^$T) -> Parameters where (T == Private_Key || T == Public_Key) { + return internal_to_params(k.params) +} + +// key_size returns the key size of a Private_Key or Public_Key in bytes. +@(require_results) +key_size :: proc(k: ^$T) -> int where (T == Private_Key || T == Public_Key) { + when T == Private_Key { + return PRIVATE_KEY_SEED_SIZE + } else { + return PUBLIC_KEY_SIZES[internal_to_params(k.params)] + } +} + +// signature_size returns the key size of a signature in bytes. +@(require_results) +signature_size :: proc(k: ^$T) -> int where (T == Private_Key || T == Public_Key) { + return SIGNATURE_SIZES[internal_to_params(k.params)] +} + +@(private="file",require_results) +internal_to_params :: proc "contextless" (params: ^_mldsa.Params) -> Parameters { + switch params { + case &_mldsa.Params_44: + return .ML_DSA_44 + case &_mldsa.Params_65: + return .ML_DSA_65 + case &_mldsa.Params_87: + return .ML_DSA_87 + case: + return .Invalid + } +} diff --git a/core/crypto/mldsa/doc.odin b/core/crypto/mldsa/doc.odin new file mode 100644 index 000000000..4801c070e --- /dev/null +++ b/core/crypto/mldsa/doc.odin @@ -0,0 +1,7 @@ +/* +Module-Lattice-Based Digital Signature Algorithm. + +See: +- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.pdf ]] +*/ +package mldsa diff --git a/examples/all/all_js.odin b/examples/all/all_js.odin index f557cecdb..becc8f522 100644 --- a/examples/all/all_js.odin +++ b/examples/all/all_js.odin @@ -43,6 +43,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mldsa" @(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" diff --git a/examples/all/all_main.odin b/examples/all/all_main.odin index 5565d225d..a35781338 100644 --- a/examples/all/all_main.odin +++ b/examples/all/all_main.odin @@ -48,6 +48,7 @@ package all @(require) import "core:crypto/legacy/keccak" @(require) import "core:crypto/legacy/md5" @(require) import "core:crypto/legacy/sha1" +@(require) import "core:crypto/mldsa" @(require) import "core:crypto/mlkem" @(require) import cnoise "core:crypto/noise" @(require) import "core:crypto/pbkdf2" diff --git a/tests/benchmark/crypto/benchmark_ecc.odin b/tests/benchmark/crypto/benchmark_ecc.odin index 95db33ab3..6315cedd9 100644 --- a/tests/benchmark/crypto/benchmark_ecc.odin +++ b/tests/benchmark/crypto/benchmark_ecc.odin @@ -18,8 +18,8 @@ import "core:crypto/hash" ECDH_ITERS :: 10000 @(private = "file") DSA_ITERS :: 10000 -@(private = "file") -MSG : string : "Got a job for you, 621." +@(private) +SIG_MSG : string : "Got a job for you, 621." @(test) benchmark_crypto_ecc :: proc(t: ^testing.T) { @@ -126,8 +126,8 @@ bench_dsa :: proc() { @(private = "file") bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) { - priv_str := "cafebabecafebabecafebabecafebabecafebabecafebabecafebabecafebabe" - priv_bytes, _ := hex.decode(transmute([]byte)(priv_str), context.temp_allocator) + SEED : string : "cafebabecafebabecafebabecafebabecafebabecafebabecafebabecafebabe" + priv_bytes, _ := hex.decode(transmute([]byte)(SEED), context.temp_allocator) priv_key: ed25519.Private_Key start := time.tick_now() for _ in 0 ..< DSA_ITERS { @@ -136,13 +136,11 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) { } sk = time.tick_since(start) / DSA_ITERS - pub_bytes := priv_key._pub_key._b[:] // "I know what I am doing" pub_key: ed25519.Public_Key - ok := ed25519.public_key_set_bytes(&pub_key, pub_bytes[:]) - assert(ok, "public key should deserialize") + ed25519.public_key_set_priv(&pub_key, &priv_key) sig_bytes: [ed25519.SIGNATURE_SIZE]byte - msg_bytes := transmute([]byte)(MSG) + msg_bytes := transmute([]byte)(SIG_MSG) start = time.tick_now() for _ in 0 ..< DSA_ITERS { ed25519.sign(&priv_key, msg_bytes, sig_bytes[:]) @@ -151,7 +149,7 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) { start = time.tick_now() for _ in 0 ..< DSA_ITERS { - ok = ed25519.verify(&pub_key, msg_bytes, sig_bytes[:]) + ok := ed25519.verify(&pub_key, msg_bytes, sig_bytes[:]) assert(ok, "signature should validate") } verif = time.tick_since(start) / DSA_ITERS @@ -175,7 +173,7 @@ bench_ecdsa :: proc(curve: ecdsa.Curve, hash: hash.Algorithm) -> (sk, sig, verif ecdsa.public_key_set_priv(&pub_key, &priv_key) sig_bytes := make([]byte, ecdsa.RAW_SIGNATURE_SIZES[curve], context.temp_allocator) - msg_bytes := transmute([]byte)(MSG) + msg_bytes := transmute([]byte)(SIG_MSG) start = time.tick_now() for _ in 0 ..< DSA_ITERS { ok := ecdsa.sign_raw(&priv_key, hash, msg_bytes, sig_bytes, true) diff --git a/tests/benchmark/crypto/benchmark_pqc.odin b/tests/benchmark/crypto/benchmark_pqc.odin index 314594ca5..20b2827fb 100644 --- a/tests/benchmark/crypto/benchmark_pqc.odin +++ b/tests/benchmark/crypto/benchmark_pqc.odin @@ -6,15 +6,19 @@ import "core:text/table" import "core:time" import "core:crypto" +import "core:crypto/mldsa" import "core:crypto/mlkem" @(private = "file") MLKEM_ITERS :: 50000 +@(private = "file") +MLDSA_ITERS :: 10000 @(test) benchmark_crypto_mlkem :: proc(t: ^testing.T) { if !crypto.HAS_RAND_BYTES { log.warnf("ML-KEM benchmarks skipped, no system entropy source") + return } tbl: table.Table @@ -75,6 +79,86 @@ benchmark_crypto_mlkem :: proc(t: ^testing.T) { log_table(&tbl) } +@(test) +bench_mldsa :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.warnf("ML-DSA benchmarks skipped, no system entropy source") + return + } + + tbl: table.Table + table.init(&tbl) + defer table.destroy(&tbl) + + table.caption(&tbl, "ML-DSA") + table.aligned_header_of_values(&tbl, .Right, "Parameters", "Op", "Time") + + append_tbl := proc(tbl: ^table.Table, algo_name, op: string, t: time.Duration) { + table.aligned_row_of_values( + tbl, + .Right, + algo_name, + op, + table.format(tbl, "%8M", t), + ) + } + + do_bench := proc(params: mldsa.Parameters) -> (sk, sig, verif: time.Duration) { + // The time taken is highly seed dependent due to rejection + // sampling using SHAKE, so we hit up the system entropy source + // and crank up the iteration count. + priv_key: mldsa.Private_Key + start := time.tick_now() + for _ in 0 ..< MLDSA_ITERS*2 { + ok := mldsa.private_key_generate(&priv_key, params) + assert(ok, "private key should generate") + } + sk = time.tick_since(start) / (MLDSA_ITERS*2) + + pub_key: mldsa.Public_Key + mldsa.public_key_set_priv(&pub_key, &priv_key) + + msg_bytes := transmute([]byte)(SIG_MSG) + sig_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params]) + defer delete(sig_bytes) + + // FIPS defaults to hedged mode with non-deterministic signatures. + start = time.tick_now() + for _ in 0 ..< MLDSA_ITERS { + ok := mldsa.sign(&priv_key, nil, msg_bytes, sig_bytes) + assert(ok, "signature should succeed") + } + sig = time.tick_since(start) / MLDSA_ITERS + + start = time.tick_now() + for _ in 0 ..< MLDSA_ITERS { + ok := mldsa.verify(&pub_key, nil, msg_bytes, sig_bytes) + assert(ok, "signature should validate") + } + verif = time.tick_since(start) / MLDSA_ITERS + + return + } + + for params in mldsa.Parameters { + if params == .Invalid { + continue + } + param_name := MLDSA_PARAMS_NAMES[params] + + sig, sk, verif := do_bench(params) + append_tbl(&tbl, param_name, "private_key_generate", sk) + append_tbl(&tbl, param_name, "sign", sig) + append_tbl(&tbl, param_name, "verify", verif) + + if params != .ML_DSA_87 { + table.row(&tbl) + } + } + + log_table(&tbl) +} + @(private="file") MLKEM_PARAMS_NAMES := [mlkem.Parameters]string { .Invalid = "invalid", @@ -82,3 +166,11 @@ MLKEM_PARAMS_NAMES := [mlkem.Parameters]string { .ML_KEM_768 = "ML-KEM-768", .ML_KEM_1024 = "ML-KEM-1024", } + +@(private="file") +MLDSA_PARAMS_NAMES := [mldsa.Parameters]string { + .Invalid = "invalid", + .ML_DSA_44 = "ML-DSA-44", + .ML_DSA_65 = "ML-DSA-65", + .ML_DSA_87 = "ML-DSA-87", +} diff --git a/tests/core/crypto/test_core_crypto_pqc.odin b/tests/core/crypto/test_core_crypto_pqc.odin index 8359053e4..f00e47b5a 100644 --- a/tests/core/crypto/test_core_crypto_pqc.odin +++ b/tests/core/crypto/test_core_crypto_pqc.odin @@ -1,10 +1,12 @@ package test_core_crypto import "core:bytes" +import "core:fmt" import "core:log" import "core:testing" import "core:crypto" +import "core:crypto/mldsa" import "core:crypto/mlkem" @(test) @@ -70,3 +72,88 @@ test_mlkem :: proc(t: ^testing.T) { ) } } + +@(test) +test_mldsa :: proc(t: ^testing.T) { + TEST_MSG : string : "ML-DSA test message" + msg_bytes := transmute([]byte)(TEST_MSG) + + // Test vectors are huge, and are covered by the wycheproof corpus, + // so do some casual tests. + for params in mldsa.Parameters { + if params == .Invalid { + continue + } + + seed: [mldsa.PRIVATE_KEY_SEED_SIZE]byte + fmt.bprintf(seed[:], "odin test - %v", params) + + priv_key: mldsa.Private_Key + if !testing.expectf( + t, + mldsa.private_key_set_bytes(&priv_key, params, seed[:]), + "%v: private_key_set_bytes", + params, + ) { + continue + } + + sig_det_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params]) + defer delete(sig_det_bytes) + + if !testing.expectf( + t, + mldsa.sign(&priv_key, nil, msg_bytes, sig_det_bytes, true), + "%v: sign (deterministic)", + params, + ) { + continue + } + + pub_key: mldsa.Public_Key + mldsa.public_key_set_priv(&pub_key, &priv_key) + + if !testing.expectf( + t, + mldsa.verify(&pub_key, nil, msg_bytes, sig_det_bytes), + "%v: verify (deterministic)", + params, + ) { + continue + } + + if !crypto.HAS_RAND_BYTES { + continue + } + + sig_hedged_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params]) + defer delete(sig_hedged_bytes) + + if !testing.expectf( + t, + mldsa.sign(&priv_key, nil, msg_bytes, sig_hedged_bytes), + "%v: sign (hedged)", + params, + ) { + continue + } + + if !testing.expectf( + t, + mldsa.verify(&pub_key, nil, msg_bytes, sig_hedged_bytes), + "%v: verify (hedged)", + params, + ) { + continue + } + + // False positive rate of 1/(2^256), assuming a functional + // entropy source. + testing.expectf( + t, + !bytes.equal(sig_det_bytes, sig_hedged_bytes), + "%v: deterministic sig should not equal hedged", + params, + ) + } +} diff --git a/tests/core/crypto/wycheproof/pqc.odin b/tests/core/crypto/wycheproof/pqc.odin index 210d7be1d..d485f621f 100644 --- a/tests/core/crypto/wycheproof/pqc.odin +++ b/tests/core/crypto/wycheproof/pqc.odin @@ -4,11 +4,15 @@ import "core:encoding/hex" import "core:log" import "core:mem" import "core:os" +import "core:slice" import "core:testing" import "core:crypto/_mlkem" import "core:crypto/mlkem" +import "core:crypto/_mldsa" +import "core:crypto/mldsa" + import "../common" @(test) @@ -81,7 +85,7 @@ test_mlkem :: proc(t: ^testing.T) { test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { params_str := test_vectors.test_groups[0].parameter_set - params := parameter_set_to_params(params_str) + params := mlkem_parameter_set_to_params(params_str) if params == .Invalid { return false } @@ -93,19 +97,32 @@ test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr for &test_vector in test_group.tests { num_ran += 1 + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/KeyGen/%d/%d: %s: %+v", + params_str, + tg_id, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/KeyGen/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags) + } + seed := common.hexbytes_decode(test_vector.seed) dk: mlkem.Decapsulation_Key if !testing.expectf( t, mlkem.decapsulation_key_set_bytes(&dk, params, seed), - "%s/KeyGen/%d/%d: failed to set decapsulation key from seed", + "%s/KeyGen/%d/%d: failed to set decapsulation key from seed: %s", params_str, tg_id, test_vector.tc_id, test_vector.seed, ) { - num_failed *= 1 + num_failed += 1 continue } @@ -184,7 +201,7 @@ test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { params_str := test_vectors.test_groups[0].parameter_set - params := parameter_set_to_params(params_str) + params := mlkem_parameter_set_to_params(params_str) if params == .Invalid { return false } @@ -196,6 +213,19 @@ test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr for &test_vector in test_group.tests { num_ran += 1 + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/Encaps/%d/%d: %s: %+v", + params_str, + tg_id, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/Encaps/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags) + } + ek: mlkem.Encapsulation_Key ok := mlkem.encapsulation_key_set_bytes( &ek, @@ -284,7 +314,7 @@ test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool { params_str := test_vectors.test_groups[0].parameter_set - params := parameter_set_to_params(params_str) + params := mlkem_parameter_set_to_params(params_str) if params == .Invalid { return false } @@ -296,6 +326,19 @@ test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr for &test_vector in test_group.tests { num_ran += 1 + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/Decaps/%d/%d: %s: %+v", + params_str, + tg_id, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/Decaps/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags) + } + // We do not have an API for decaps with raw seed. seed := common.hexbytes_decode(test_vector.seed) switch len(seed) { @@ -386,8 +429,8 @@ test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr return num_failed == 0 } -@(require_results, private="file") -parameter_set_to_params :: proc(s: string) -> mlkem.Parameters { +@(require_results,private="file") +mlkem_parameter_set_to_params :: proc(s: string) -> mlkem.Parameters { switch s { case "ML-KEM-512": return .ML_KEM_512 @@ -399,3 +442,321 @@ parameter_set_to_params :: proc(s: string) -> mlkem.Parameters { return .Invalid } } + +@(test) +test_mldsa :: proc(t: ^testing.T) { + arena: mem.Arena + arena_backing := make([]byte, ARENA_SIZE) + defer delete(arena_backing) + mem.arena_init(&arena, arena_backing) + context.allocator = mem.arena_allocator(&arena) + + log.debug("mldsa: starting") + + files_sign := []string { + "mldsa_44_sign_seed_test.json", + "mldsa_65_sign_seed_test.json", + "mldsa_87_sign_seed_test.json", + } + for f in files_sign { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Mldsa_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mldsa_sign(t, &test_vectors), "ML-DSA Sign failed") + } + + files_verify := []string { + "mldsa_44_verify_test.json", + "mldsa_65_verify_test.json", + "mldsa_87_verify_test.json", + } + for f in files_verify { + mem.free_all() + + fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator) + + test_vectors: Test_Vectors(Mldsa_Test_Group) + load_ok := load(&test_vectors, fn) + if !testing.expectf(t, load_ok, "Unable to load {}", f) { + continue + } + + testing.expectf(t, test_mldsa_verify(t, &test_vectors), "ML-DSA Verify failed") + } +} + +test_mldsa_sign :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Mldsa_Test_Group)) -> bool { + FLAG_INTERNAL :: "Internal" + + dummy_rnd: [_mldsa.RNDBYTES]byte + + params_str := test_vectors.algorithm + params := mldsa_parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Sign starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + seed := common.hexbytes_decode(test_group.private_seed) + priv_key: mldsa.Private_Key + + tg_len := len(test_group.tests) + if !testing.expectf( + t, + mldsa.private_key_set_bytes(&priv_key, params, seed), + "%s/Sign/%d: failed to set private key from seed: %s", + params_str, + tg_id, + test_group.private_seed, + ) { + num_ran += tg_len + num_failed += tg_len + continue + } + + pub_bytes := make([]byte, mldsa.PUBLIC_KEY_SIZES[params]) + mldsa.private_key_public_bytes(&priv_key, pub_bytes) + + ok := common.hexbytes_compare(test_group.public_key, pub_bytes) + if !ok { + x := transmute(string)(hex.encode(pub_bytes[:])) + log.errorf( + "%s/Sign/%d: public key: expected: %s actual: %s", + params_str, + tg_id, + test_group.public_key, + x, + ) + num_ran += tg_len + num_failed += tg_len + continue + } + + pub_key: mldsa.Public_Key + if !testing.expectf( + t, + mldsa.public_key_set_bytes(&pub_key, params, pub_bytes), + "%s/Sign/%d: failed to set public key", + params_str, + tg_id, + ) { + num_ran += tg_len + num_failed += tg_len + continue + } + + sig := make([]byte, mldsa.SIGNATURE_SIZES[params]) + for &test_vector in test_group.tests { + num_ran += 1 + + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/Sign/%d/%d: %s: %+v", + params_str, + tg_id, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/Sign/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags) + } + + ctx := common.hexbytes_decode(test_vector.ctx) + msg := common.hexbytes_decode(test_vector.msg) + + is_external_mu := slice.contains(test_vector.flags, FLAG_INTERNAL) + switch is_external_mu { + case false: + ok = mldsa.sign( + &priv_key, + ctx, + msg, + sig, + true, + ) + case true: + ok = _mldsa.dsa_sign_internal( + sig, + msg, + ctx, + dummy_rnd[:], + &priv_key, + common.hexbytes_decode(test_vector.mu), + ) + } + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Sign/%d/%d: unexpected sign result: %v", + params_str, + tg_id, + test_vector.tc_id, + ok, + ) + num_failed += 1 + continue + } + if result_is_invalid(test_vector.result) { + num_passed += 1 + continue + } + + ok = common.hexbytes_compare(test_vector.sig, sig) + if !ok { + x := transmute(string)(hex.encode(sig)) + log.errorf( + "%s/Sign/%d/%d: sign: expected: %s actual: %s", + params_str, + tg_id, + test_vector.tc_id, + test_vector.sig, + x, + ) + num_failed += 1 + continue + } + + // Might as well verify as well if we have the ctx/msg. + if !is_external_mu { + if !testing.expectf( + t, + mldsa.verify(&pub_key, ctx, msg, sig), + "%s/Sign/%d/%d: verify failed", + params_str, + tg_id, + test_vector.tc_id, + ) { + num_failed += 1 + continue + } + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Sign: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +test_mldsa_verify:: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Mldsa_Test_Group)) -> bool { + params_str := test_vectors.algorithm + params := mldsa_parameter_set_to_params(params_str) + if params == .Invalid { + return false + } + + log.debugf("%s: Verify starting", params_str) + + num_ran, num_passed, num_failed, num_skipped: int + for &test_group, tg_id in test_vectors.test_groups { + tg_len := len(test_group.tests) + + pub_key_bytes := common.hexbytes_decode(test_group.public_key) + pub_key: mldsa.Public_Key + + expected := len(pub_key_bytes) == mldsa.PUBLIC_KEY_SIZES[params] + ok := mldsa.public_key_set_bytes(&pub_key, params, pub_key_bytes) + if !testing.expectf( + t, + ok == expected, + "%s/Verify/%d: failed to set public key", + params_str, + tg_id, + ) { + num_ran += tg_len + num_failed += tg_len + continue + } + if expected == false { + num_ran += tg_len + num_passed += tg_len + continue + } + + for &test_vector in test_group.tests { + if comment := test_vector.comment; comment != "" { + log.debugf( + "%s/Verify/%d/%d: %s: %+v", + params_str, + tg_id, + test_vector.tc_id, + comment, + test_vector.flags, + ) + } else { + log.debugf("%s/Verify/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags) + } + + num_ran += 1 + + ctx := common.hexbytes_decode(test_vector.ctx) + msg := common.hexbytes_decode(test_vector.msg) + sig := common.hexbytes_decode(test_vector.sig) + + ok = mldsa.verify(&pub_key, ctx, msg, sig) + if !result_check(test_vector.result, ok) { + log.errorf( + "%s/Verify/%d/%d: unexpected verify result: %v", + params_str, + tg_id, + test_vector.tc_id, + ok, + ) + num_failed += 1 + continue + } + + num_passed += 1 + } + } + + assert(num_ran == test_vectors.number_of_tests) + assert(num_passed + num_failed + num_skipped == num_ran) + + log.infof( + "%s/Verify: ran %d, passed %d, failed %d, skipped %d", + params_str, + num_ran, + num_passed, + num_failed, + num_skipped, + ) + + return num_failed == 0 +} + +@(require_results,private="file") +mldsa_parameter_set_to_params :: proc(s: string) -> mldsa.Parameters { + switch s { + case "ML-DSA-44": + return .ML_DSA_44 + case "ML-DSA-65": + return .ML_DSA_65 + case "ML-DSA-87": + return .ML_DSA_87 + case: + return .Invalid + } +} diff --git a/tests/core/crypto/wycheproof/schemas.odin b/tests/core/crypto/wycheproof/schemas.odin index d801207a0..36fc22d10 100644 --- a/tests/core/crypto/wycheproof/schemas.odin +++ b/tests/core/crypto/wycheproof/schemas.odin @@ -29,10 +29,6 @@ result_is_invalid :: proc(r: Result) -> bool { return r == "invalid" } - -// The type namings are not following Odin convention, to better match -// the schema, though the fields do. - load :: proc(tvs: ^$T/Test_Vectors, fn: string) -> bool { raw_json, err := os.read_entire_file_from_path(fn, context.allocator) if err != os.ERROR_NONE { @@ -223,3 +219,24 @@ Kem_Test_Vector :: struct { k: common.Hex_Bytes `json:"K"`, result: Result `json:"result"`, } + +Mldsa_Test_Group :: struct { + type: string `json:"type"`, + private_seed: common.Hex_Bytes `json:"privateSeed"`, + private_key_pkcs8: common.Hex_Bytes `json:"privateKeyPkcs8"`, + public_key: common.Hex_Bytes `json:"publicKey"`, + public_key_der: common.Hex_Bytes `json:"publicKeyDer"`, + source: Test_Group_Source `json:"source"`, + tests: []Mldsa_Test_Vector `json:"tests"`, +} + +Mldsa_Test_Vector :: struct { + tc_id: int `json:"tcId"`, + comment: string `json:"comment"`, + msg: common.Hex_Bytes `json:"msg"`, + ctx: common.Hex_Bytes `json:"ctx"`, + mu: common.Hex_Bytes `json:"mu"`, + sig: common.Hex_Bytes `json:"sig"`, + result: Result `json:"result"`, + flags: []string `json:"flags"`, +}