core/crypto/mldsa: Initial import

This commit is contained in:
Yawning Angel
2026-05-11 17:17:52 +09:00
parent 0c1c0372c7
commit ccc17780b4
18 changed files with 2527 additions and 21 deletions

View File

@@ -0,0 +1,72 @@
#+private
package _mldsa
CRHBYTES :: 64
TRBYTES :: 64
N :: 256
Q :: 8380417
D :: 13
K_MAX :: 8
L_MAX :: 7
POLYZ_PACKEDBYTES_MAX :: 640
POLYT1_PACKEDBYTES :: 320
POLYT0_PACKEDBYTES :: 416
POLYVECT1_PACKEDBYTES_MAX :: K_MAX * POLYT1_PACKEDBYTES
POLYW1_PACKEDBYTES_MAX :: 192
CTILDBYTES_MAX :: 64
@(require_results)
polyeta_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int {
POLYETA_PACKEDBYTES_2 :: 96
POLYETA_PACKEDBYTES_4 :: 128
switch params.eta {
case 2:
return POLYETA_PACKEDBYTES_2
case 4:
return POLYETA_PACKEDBYTES_4
case:
unreachable()
}
}
@(require_results)
polyz_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int {
POLYZ_PACKEDBYTES_GAMMA1_17 :: 576
POLYZ_PACKEDBYTES_GAMMA1_19 :: 640
switch params.gamma1 {
case 1 << 17:
return POLYZ_PACKEDBYTES_GAMMA1_17
case 1 << 19:
return POLYZ_PACKEDBYTES_GAMMA1_19
case:
unreachable()
}
}
@(require_results)
polyw1_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int {
POLYW1_PACKEDBYTES_GAMMA2_95232 :: 192
POLYW1_PACKEDBYTES_GAMMA2_261888 :: 128
switch params.gamma2 {
case (Q-1)/88:
return POLYW1_PACKEDBYTES_GAMMA2_95232
case (Q-1)/32:
return POLYW1_PACKEDBYTES_GAMMA2_261888
case:
unreachable()
}
}
@(require_results)
polyvech_packedbytes :: #force_inline proc "contextless" (params: ^Params) -> int {
return params.omega + params.k
}

View File

@@ -0,0 +1,394 @@
package _mldsa
import "core:crypto"
import "core:crypto/shake"
// This implementation is derived from the PQ-CRYSTALS reference
// implementation [[ https://github.com/pq-crystals/dilithium ]],
// primarily for licensing reasons. Arguably mldsa-native is
// a more "up to date" codebase, but the changes to the
// ref code is minor and they slapped an attribution-required
// license on something that was originally CC-0/Apache 2.0.
SEEDBYTES :: 32
RNDBYTES :: 32
CTXBYTES_MAX :: 255
Params :: struct {
k: int,
l: int,
eta: i32,
tau: int,
beta: i32,
gamma1: i32,
gamma2: i32,
omega: int,
ctild_bytes: int,
}
@(rodata)
Params_44 := Params{
k = 4,
l = 4,
eta = 2,
tau = 39,
beta = 78,
gamma1 = 1 << 17,
gamma2 = (Q-1)/88,
omega = 80,
ctild_bytes = 32,
}
@(rodata)
Params_65 := Params{
k = 6,
l = 5,
eta = 4,
tau = 49,
beta = 196,
gamma1 = 1 << 19,
gamma2 = (Q-1)/32,
omega = 55,
ctild_bytes = 48,
}
@(rodata)
Params_87 := Params{
k = 8,
l = 7,
eta = 2,
tau = 60,
beta = 120,
gamma1 = 1 << 19,
gamma2 = (Q-1)/32,
omega = 75,
ctild_bytes = 64,
}
Private_Key :: struct {
params: ^Params,
rho: [SEEDBYTES]byte,
tr: [TRBYTES]byte,
key: [SEEDBYTES]byte,
t0: Polyvec_K,
s1: Polyvec_L,
s2: Polyvec_K,
pub_key: Public_Key,
seed: [SEEDBYTES]byte,
}
Public_Key :: struct {
params: ^Params,
t1: Polyvec_K,
rho: [SEEDBYTES]byte,
mu: [TRBYTES]byte,
}
@(private)
Signature :: struct {
params: ^Params,
c: [CTILDBYTES_MAX]byte,
z: Polyvec_L,
h: Polyvec_K,
}
dsa_keygen_internal :: proc(
priv_key: ^Private_Key,
seed: []byte,
params: ^Params,
) {
ensure(len(seed) == SEEDBYTES, "crypto/mldsa: invalid seed")
pub_key := &priv_key.pub_key
pub_key.params = params
priv_key.params = params
copy(priv_key.seed[:], seed)
seedbuf: [2*SEEDBYTES + CRHBYTES]byte = ---
mat_: [K_MAX]Polyvec_L = ---
defer crypto.zero_explicit(&seedbuf, size_of(seedbuf))
defer crypto.zero_explicit(&mat_, size_of(mat_))
// Expand randomness for rho, rhoprime and key
copy(seedbuf[:], seed)
seedbuf[SEEDBYTES] = byte(params.k)
seedbuf[SEEDBYTES+1] = byte(params.l)
shake256(seedbuf[:], seedbuf[:SEEDBYTES+2])
copy(priv_key.rho[:], seedbuf[:SEEDBYTES])
rhoprime := seedbuf[SEEDBYTES:SEEDBYTES+CRHBYTES]
copy(priv_key.key[:], seedbuf[SEEDBYTES+CRHBYTES:])
// Expand matrix
mat := mat_[:params.k]
polyvec_matrix_expand(mat, priv_key.rho[:], params)
// Sample short vectors s1 and s2
polyvec_l_uniform_eta(&priv_key.s1, rhoprime, 0, params)
polyvec_k_uniform_eta(&priv_key.s2, rhoprime, u16(params.l), params)
// Matrix-vector multiplication
s1hat: Polyvec_L = ---
defer crypto.zero_explicit(&s1hat, size_of(Polyvec_L))
polyvec_copy(&s1hat, &priv_key.s1, params)
polyvec_l_ntt(&s1hat, params)
polyvec_matrix_pointwise_montgomery(&pub_key.t1, mat, &s1hat, params)
polyvec_k_reduce(&pub_key.t1, params)
polyvec_k_invntt_tomont(&pub_key.t1, params)
// Add error vector s2
polyvec_k_add(&pub_key.t1, &pub_key.t1, &priv_key.s2, params)
// Extract t1 and write public key
pk_bytes_: [SEEDBYTES+POLYVECT1_PACKEDBYTES_MAX]byte = ---
pk_bytes := pk_bytes_[:public_key_size(params)]
polyvec_k_caddq(&pub_key.t1, params)
polyvec_k_power2round(&pub_key.t1, &priv_key.t0, &pub_key.t1, params)
copy(pub_key.rho[:], priv_key.rho[:])
_ = pack_pk(pk_bytes, pub_key)
// Compute H(rho, t1) and write secret key
shake256(pub_key.mu[:], pk_bytes)
copy(priv_key.tr[:], pub_key.mu[:])
}
dsa_sign_internal :: proc(
sig_bytes: []byte,
m: []byte,
ctx: []byte,
rnd: []byte,
priv_key: ^Private_Key,
external_mu: []byte = nil
) -> bool {
params := priv_key.params
switch params {
case &Params_44, &Params_65, &Params_87:
case:
return false
}
if len(sig_bytes) != signature_size(params) {
return false
}
ensure(len(ctx) <= CTXBYTES_MAX, "crypto/mlkem: invalid contxt size")
ensure(len(rnd) == RNDBYTES, "crypto/mlkem: invalid rnd size")
mu, rhoprime: [CRHBYTES]byte = ---, ---
mat_: [K_MAX]Polyvec_L
w1_bytes_: [SEEDBYTES+POLYVECT1_PACKEDBYTES_MAX]byte = ---
s1, y: Polyvec_L = ---, ---
t0, s2, w1, w0: Polyvec_K = ---, ---, ---, ---
cp: Poly
polyvec_copy(&s1, &priv_key.s1, params)
polyvec_copy(&s2, &priv_key.s2, params)
polyvec_copy(&t0, &priv_key.t0, params)
defer crypto.zero_explicit(&mu, size_of(mu))
defer crypto.zero_explicit(&rhoprime, size_of(rhoprime))
defer crypto.zero_explicit(&mat_, size_of(mat_))
defer crypto.zero_explicit(&w1_bytes_, size_of(w1_bytes_))
defer polyvec_clear([]^Polyvec_L{&s1, &y})
defer polyvec_clear([]^Polyvec_K{&t0, &s2, &w1, &w0})
defer crypto.zero_explicit(&cp, size_of(cp))
sig: Signature = ---
sig.params = params
h := &sig.h
z := &sig.z
c := sig.c[:params.ctild_bytes]
w1_bytes := w1_bytes_[:params.k*polyw1_packedbytes(params)]
// Compute mu = CRH(tr, pre, msg)
if len(external_mu) == 0 {
// The FIPS publication handles the shake prefix
// in the public sign operation, but doing it
// here makes more sense.
ctx_buf: [2]byte
shake_ctx: shake.Context = ---
defer shake.reset(&shake_ctx)
ctx_len := len(ctx)
shake.init_256(&shake_ctx)
shake.write(&shake_ctx, priv_key.tr[:])
if ctx_len > 0 {
ctx_buf[1] = byte(ctx_len)
}
shake.write(&shake_ctx, ctx_buf[:])
if ctx_len > 0 {
shake.write(&shake_ctx, ctx)
}
shake.write(&shake_ctx, m)
shake.read(&shake_ctx, mu[:])
} else {
ensure(len(external_mu) == CRHBYTES, "crypto/mlkem: invalid external mu")
copy(mu[:], external_mu)
}
// Compute rhoprime = CRH(key, rnd, mu)
shake256(rhoprime[:], priv_key.key[:], rnd, mu[:])
// Expand matrix and transform vectors
mat := mat_[:params.k]
polyvec_matrix_expand(mat, priv_key.rho[:], params)
polyvec_l_ntt(&s1, params)
polyvec_k_ntt(&s2, params)
polyvec_k_ntt(&t0, params)
// Rejection-sampling loop
iv: u32 // ref uses u16, but ML-DSA-87 will reuse the IV at p = ~2^{-23400}
for {
// Sample intermediate vector y
polyvec_l_uniform_gamma1(&y, rhoprime[:], iv, params)
iv += 1
// Matrix-vector multiplication
polyvec_copy(z, &y, params)
polyvec_l_ntt(z, params)
polyvec_matrix_pointwise_montgomery(&w1, mat, z, params)
polyvec_k_reduce(&w1, params)
polyvec_k_invntt_tomont(&w1, params)
// Decompose w and call the random oracle
polyvec_k_caddq(&w1, params)
polyvec_k_decompose(&w1, &w0, &w1, params)
polyvec_k_pack_w1(w1_bytes, &w1, params)
shake256(c, mu[:], w1_bytes)
poly_challenge(&cp, c, params)
poly_ntt(&cp)
// Compute z, reject if it reveals secret
polyvec_l_pointwise_poly_montgomery(z, &cp, &s1, params)
polyvec_l_invntt_tomont(z, params)
polyvec_l_add(z, z, &y, params)
polyvec_l_reduce(z, params)
if polyvec_l_chknorm(z, params.gamma1 - params.beta, params) {
continue
}
// Check that subtracting cs2 does not change high bits of w
// and low bits do not reveal secret information
polyvec_k_pointwise_poly_montgomery(h, &cp, &s2, params)
polyvec_k_invntt_tomont(h, params)
polyvec_k_sub(&w0, &w0, h, params)
polyvec_k_reduce(&w0, params)
if polyvec_k_chknorm(&w0, params.gamma2 - params.beta, params) {
continue
}
// Compute hints for w1
polyvec_k_pointwise_poly_montgomery(h, &cp, &t0, params)
polyvec_k_invntt_tomont(h, params)
polyvec_k_reduce(h, params)
if polyvec_k_chknorm(h, params.gamma2, params) {
continue
}
polyvec_k_add(&w0, &w0, h, params)
n := polyvec_k_make_hint(h, &w0, &w1, params)
if n <= uint(params.omega) {
break
}
}
// Write signature
return pack_sig(sig_bytes, &sig)
}
dsa_verify_internal :: proc(
sig_bytes: []byte,
m: []byte,
ctx: []byte,
pub_key: ^Public_Key,
) -> bool {
ensure(len(ctx) <= CTXBYTES_MAX, "crypto/mlkem: invalid contxt size")
params := pub_key.params
switch params {
case &Params_44, &Params_65, &Params_87:
case:
return false
}
sig: Signature = ---
if !unpack_sig(&sig, sig_bytes, params) {
return false
}
if polyvec_l_chknorm(&sig.z, params.gamma1 - params.beta, params) {
return false
}
c := sig.c[:params.ctild_bytes]
z := &sig.z
h := &sig.h
t1: Polyvec_K = ---
polyvec_copy(&t1, &pub_key.t1, params)
rho := pub_key.rho[:]
// Compute CRH(H(rho, t1), pre, msg)
mu: [CRHBYTES]byte
{
// The FIPS publication handles the shake prefix
// in the public sign operation, but doing it
// here makes more sense.
ctx_buf: [2]byte
shake_ctx: shake.Context = ---
defer shake.reset(&shake_ctx)
ctx_len := len(ctx)
shake.init_256(&shake_ctx)
shake.write(&shake_ctx, pub_key.mu[:])
if ctx_len > 0 {
ctx_buf[1] = byte(ctx_len)
}
shake.write(&shake_ctx, ctx_buf[:])
if ctx_len > 0 {
shake.write(&shake_ctx, ctx)
}
shake.write(&shake_ctx, m)
shake.read(&shake_ctx, mu[:])
}
// Matrix-vector multiplication; compute Az - c2^dt1
mat_: [K_MAX]Polyvec_L
w1: Polyvec_K = ---
cp: Poly = ---
mat := mat_[:params.l]
poly_challenge(&cp, c, params)
polyvec_matrix_expand(mat, rho, params)
polyvec_l_ntt(z, params)
polyvec_matrix_pointwise_montgomery(&w1, mat, z, params)
poly_ntt(&cp)
polyvec_k_shiftl(&t1, params)
polyvec_k_ntt(&t1, params)
polyvec_k_pointwise_poly_montgomery(&t1, &cp, &t1, params)
polyvec_k_sub(&w1, &w1, &t1, params)
polyvec_k_reduce(&w1, params)
polyvec_k_invntt_tomont(&w1, params)
// Reconstruct w1
buf_: [K_MAX*POLYW1_PACKEDBYTES_MAX]byte = ---
buf := buf_[:params.k*polyw1_packedbytes(params)]
polyvec_k_caddq(&w1, params)
polyvec_k_use_hint(&w1, &w1, h, params)
polyvec_k_pack_w1(buf, &w1, params)
// Call random oracle and verify challenge
c2_: [CTILDBYTES_MAX]byte
c2 := c2_[:params.ctild_bytes]
shake256(c2, mu[:], buf)
// Note/perf: Can be vartime
return crypto.compare_constant_time(c, c2) == 1
}

View File

@@ -0,0 +1,75 @@
#+private
package _mldsa
@(rodata)
ZETAS := [N]i32 {
0, 25847, -2608894, -518909, 237124, -777960, -876248, 466468,
1826347, 2353451, -359251, -2091905, 3119733, -2884855, 3111497, 2680103,
2725464, 1024112, -1079900, 3585928, -549488, -1119584, 2619752, -2108549,
-2118186, -3859737, -1399561, -3277672, 1757237, -19422, 4010497, 280005,
2706023, 95776, 3077325, 3530437, -1661693, -3592148, -2537516, 3915439,
-3861115, -3043716, 3574422, -2867647, 3539968, -300467, 2348700, -539299,
-1699267, -1643818, 3505694, -3821735, 3507263, -2140649, -1600420, 3699596,
811944, 531354, 954230, 3881043, 3900724, -2556880, 2071892, -2797779,
-3930395, -1528703, -3677745, -3041255, -1452451, 3475950, 2176455, -1585221,
-1257611, 1939314, -4083598, -1000202, -3190144, -3157330, -3632928, 126922,
3412210, -983419, 2147896, 2715295, -2967645, -3693493, -411027, -2477047,
-671102, -1228525, -22981, -1308169, -381987, 1349076, 1852771, -1430430,
-3343383, 264944, 508951, 3097992, 44288, -1100098, 904516, 3958618,
-3724342, -8578, 1653064, -3249728, 2389356, -210977, 759969, -1316856,
189548, -3553272, 3159746, -1851402, -2409325, -177440, 1315589, 1341330,
1285669, -1584928, -812732, -1439742, -3019102, -3881060, -3628969, 3839961,
2091667, 3407706, 2316500, 3817976, -3342478, 2244091, -2446433, -3562462,
266997, 2434439, -1235728, 3513181, -3520352, -3759364, -1197226, -3193378,
900702, 1859098, 909542, 819034, 495491, -1613174, -43260, -522500,
-655327, -3122442, 2031748, 3207046, -3556995, -525098, -768622, -3595838,
342297, 286988, -2437823, 4108315, 3437287, -3342277, 1735879, 203044,
2842341, 2691481, -2590150, 1265009, 4055324, 1247620, 2486353, 1595974,
-3767016, 1250494, 2635921, -3548272, -2994039, 1869119, 1903435, -1050970,
-1333058, 1237275, -3318210, -1430225, -451100, 1312455, 3306115, -1962642,
-1279661, 1917081, -2546312, -1374803, 1500165, 777191, 2235880, 3406031,
-542412, -2831860, -1671176, -1846953, -2584293, -3724270, 594136, -3776993,
-2013608, 2432395, 2454455, -164721, 1957272, 3369112, 185531, -1207385,
-3183426, 162844, 1616392, 3014001, 810149, 1652634, -3694233, -1799107,
-3038916, 3523897, 3866901, 269760, 2213111, -975884, 1717735, 472078,
-426683, 1723600, -1803090, 1910376, -1667432, -1104333, -260646, -3833893,
-2939036, -2235985, -420899, -2286327, 183443, -976891, 1612842, -3545687,
-554416, 3919660, -48306, -1362209, 3937738, 1400424, -846154, 1976782,
}
ntt :: proc "contextless" (a: ^[N]i32) #no_bounds_check {
j, k := 0, 1
for l := 128; l > 0; l >>= 1 {
for start := 0; start < N; start = j + l {
zeta := ZETAS[k]
k += 1
for j = start; j < start + l; j += 1 {
t := montgomery_reduce(i64(zeta) * i64(a[j + l]))
a[j + l] = a[j] - t
a[j] = a[j] + t
}
}
}
}
invntt_tomont :: proc "contextless" (a: ^[N]i32) #no_bounds_check {
F :: 41978 // mont^2/256
j, k := 0, 255
for l := 1; l < N; l <<= 1 {
for start := 0; start < N; start = j + l {
zeta := -ZETAS[k]
k -= 1
for j = start; j < start + l; j += 1 {
t := a[j]
a[j] = t + a[j + l]
a[j + l] = t - a[j + l]
a[j + l] = montgomery_reduce(i64(zeta) * i64(a[j + l]))
}
}
}
for i in 0..<N {
a[i] = montgomery_reduce(F * i64(a[i]))
}
}

View File

@@ -0,0 +1,208 @@
package _mldsa
import "base:intrinsics"
import "core:crypto"
@(require_results)
pack_pk :: proc "contextless" (pk_bytes: []byte, pub_key: ^Public_Key) -> bool {
if len(pk_bytes) != public_key_size(pub_key.params) {
return false
}
seed_bytes, t1_bytes := pk_bytes[:SEEDBYTES], pk_bytes[SEEDBYTES:]
copy(seed_bytes, pub_key.rho[:])
for i in 0..<pub_key.params.k {
polyt1_pack(t1_bytes[i*POLYT1_PACKEDBYTES:], &pub_key.t1.vec[i])
}
return true
}
@(require_results)
unpack_pk :: proc(pub_key: ^Public_Key, pk_bytes: []byte, params: ^Params) -> bool {
if len(pk_bytes) != public_key_size(params) {
return false
}
seed_bytes, t1_bytes := pk_bytes[:SEEDBYTES], pk_bytes[SEEDBYTES:]
pub_key.params = params
copy(pub_key.rho[:], seed_bytes)
for i in 0..<params.k {
polyt1_unpack(&pub_key.t1.vec[i], t1_bytes[i*POLYT1_PACKEDBYTES:])
}
shake256(pub_key.mu[:], pk_bytes)
return true
}
set_pk :: proc(dst, src: ^Public_Key) {
dst.params = src.params
polyvec_copy(&dst.t1, &src.t1, src.params)
copy(dst.rho[:], src.rho[:])
copy(dst.mu[:], src.mu[:])
}
clear_pk :: proc "contextless" (pub_key: ^Public_Key) {
crypto.zero_explicit(pub_key, size_of(Public_Key))
}
@(require_results)
pack_sk :: proc "contextless" (sk_bytes: []byte, priv_key: ^Private_Key) -> bool {
params := priv_key.params
if len(sk_bytes) != private_key_size(params) {
return false
}
sk_bytes := sk_bytes
polyeta_len := polyeta_packedbytes(params)
copy(sk_bytes, priv_key.rho[:])
sk_bytes = sk_bytes[SEEDBYTES:]
copy(sk_bytes, priv_key.key[:])
sk_bytes = sk_bytes[SEEDBYTES:]
copy(sk_bytes, priv_key.tr[:])
sk_bytes = sk_bytes[TRBYTES:]
for i in 0..<params.l {
polyeta_pack(sk_bytes[i*polyeta_len:], &priv_key.s1.vec[i], params)
}
sk_bytes = sk_bytes[polyeta_len*params.l:]
for i in 0..<params.k {
polyeta_pack(sk_bytes[i*polyeta_len:], &priv_key.s2.vec[i], params)
}
sk_bytes = sk_bytes[polyeta_len*params.k:]
for i in 0..<params.k {
polyt1_pack(sk_bytes[i*POLYT1_PACKEDBYTES:], &priv_key.t0.vec[i])
}
return true
}
set_sk :: proc(dst, src: ^Private_Key) {
dst.params = src.params
copy(dst.rho[:], src.rho[:])
copy(dst.tr[:], src.tr[:])
copy(dst.key[:], src.key[:])
polyvec_copy(&dst.t0, &src.t0, src.params)
polyvec_copy(&dst.s1, &src.s1, src.params)
polyvec_copy(&dst.s2, &src.s2, src.params)
set_pk(&dst.pub_key, &src.pub_key)
copy(dst.seed[:], src.seed[:])
}
clear_sk :: proc "contextless" (priv_key: ^Private_Key) {
crypto.zero_explicit(priv_key, size_of(Private_Key))
}
@(private,require_results)
pack_sig :: proc "contextless" (sig_bytes: []byte, sig: ^Signature) -> bool {
if len(sig_bytes) != signature_size(sig.params) {
return false
}
sig_bytes := sig_bytes
polyz_len := polyz_packedbytes(sig.params)
copy(sig_bytes, sig.c[:sig.params.ctild_bytes])
sig_bytes = sig_bytes[sig.params.ctild_bytes:]
for i in 0..<sig.params.l {
polyz_pack(sig_bytes[i*polyz_len:], &sig.z.vec[i], sig.params)
}
sig_bytes = sig_bytes[sig.params.l*polyz_len:]
intrinsics.mem_zero(raw_data(sig_bytes), len(sig_bytes))
k: int
for i in 0..<sig.params.k {
for j in 0..<N {
if sig.h.vec[i].coeffs[j] != 0 {
sig_bytes[k] = byte(j)
k += 1
}
}
sig_bytes[sig.params.omega + i] = byte(k)
}
return true
}
@(private,require_results)
unpack_sig :: proc "contextless" (sig: ^Signature, sig_bytes: []byte, params: ^Params) -> bool {
if len(sig_bytes) != signature_size(params) {
return false
}
intrinsics.mem_zero(sig, size_of(Signature))
sig_bytes := sig_bytes
polyz_len := polyz_packedbytes(params)
omega := params.omega
copy(sig.c[:], sig_bytes[:params.ctild_bytes])
sig_bytes = sig_bytes[params.ctild_bytes:]
for i in 0..<params.l {
polyz_unpack(&sig.z.vec[i], sig_bytes[i*polyz_len:], params)
}
sig_bytes = sig_bytes[params.l*polyz_len:]
// Decode h
k: int
for i in 0..<params.k {
if sig_bytes[omega + i] < byte(k) || sig_bytes[omega + i] > byte(omega) {
return false
}
for j := k; j < int(sig_bytes[omega + i]); j += 1 {
// Coefficients are ordered for strong unforgeability
if j > k && sig_bytes[j] <= sig_bytes[j-1] {
return false
}
sig.h.vec[i].coeffs[sig_bytes[j]] = 1
}
k = int(sig_bytes[omega + i])
}
// Extra indices are zero for strong unforgeability
for j := k; j < omega; j += 1 {
if sig_bytes[j] != 0 {
return false
}
}
sig.params = params
return true
}
@(private,require_results)
public_key_size :: #force_inline proc "contextless" (params: ^Params) -> int {
return SEEDBYTES + params.k * POLYT1_PACKEDBYTES
}
@(private,require_results)
private_key_size :: #force_inline proc "contextless" (params: ^Params) -> int {
return 2*SEEDBYTES + TRBYTES + (params.l + params.k) * polyeta_packedbytes(params) + params.k * POLYT0_PACKEDBYTES
}
@(private,require_results)
signature_size :: #force_inline proc "contextless" (params: ^Params) -> int {
return params.ctild_bytes + params.l * polyz_packedbytes(params) + polyvech_packedbytes(params)
}

View File

@@ -0,0 +1,564 @@
#+private
package _mldsa
import "base:intrinsics"
import "core:crypto"
import "core:crypto/shake"
Poly :: struct {
coeffs: [N]i32,
}
poly_reduce :: proc "contextless" (a: ^Poly) {
for v, i in a.coeffs {
a.coeffs[i] = reduce32(v)
}
}
poly_caddq :: proc "contextless" (a: ^Poly) {
for v, i in a.coeffs {
a.coeffs[i] = caddq(v)
}
}
poly_add :: proc "contextless" (c, a, b: ^Poly) #no_bounds_check {
for i in 0..<N {
c.coeffs[i] = a.coeffs[i] + b.coeffs[i]
}
}
poly_sub :: proc "contextless" (c, a, b: ^Poly) #no_bounds_check {
for i in 0..<N {
c.coeffs[i] = a.coeffs[i] - b.coeffs[i]
}
}
poly_shiftl :: proc "contextless" (a: ^Poly) {
for i in 0..<N {
a.coeffs[i] <<= D
}
}
poly_ntt :: proc "contextless" (a: ^Poly) {
ntt(&a.coeffs)
}
poly_invntt_tomont :: proc "contextless" (a: ^Poly) {
invntt_tomont(&a.coeffs)
}
poly_pointwise_montgomery :: proc "contextless" (c, a, b: ^Poly) #no_bounds_check {
for i in 0..<N {
c.coeffs[i] = montgomery_reduce(i64(a.coeffs[i]) * i64(b.coeffs[i]))
}
}
poly_power2round :: proc "contextless" (a1, a0, a: ^Poly) #no_bounds_check {
for i in 0..<N {
a0.coeffs[i], a1.coeffs[i] = power2round(a.coeffs[i])
}
}
poly_decompose :: proc "contextless" (a1, a0, a: ^Poly, params: ^Params) #no_bounds_check {
for i in 0..<N {
a0.coeffs[i], a1.coeffs[i] = decompose(a.coeffs[i], params.gamma2)
}
}
poly_make_hint :: proc "contextless" (h, a0, a1: ^Poly, params: ^Params) -> uint #no_bounds_check {
s: uint
for i in 0..<N {
h.coeffs[i] = i32(make_hint(a0.coeffs[i], a1.coeffs[i], params.gamma2))
s += uint(h.coeffs[i])
}
return s
}
poly_use_hint :: proc "contextless" (b, a, h: ^Poly, params: ^Params) {
for i in 0..<N {
b.coeffs[i] = use_hint(a.coeffs[i], uint(h.coeffs[i]), params.gamma2)
}
}
poly_chknorm :: proc "contextless" (a: ^Poly, bound: i32) -> bool #no_bounds_check {
// It is ok to leak which coefficient violates the bound since
// the probability for each coefficient is independent of secret
// data but we must not leak the sign of the centralized
// representative.
for i in 0..<N {
// Absolute value
t := a.coeffs[i] >> 31
t = a.coeffs[i] - (t & 2 * a.coeffs[i])
if t >= bound {
return true
}
}
return false
}
unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check {
r := u32(b[0])
r |= u32(b[1]) << 8
r |= u32(b[2]) << 16
return r
}
rej_uniform :: proc "contextless" (a: []i32, buf: []byte) -> int #no_bounds_check {
ctr, pos: int
a_len, b_len := len(a), len(buf)
for ctr < a_len && pos + 3 <= b_len {
t := unchecked_get_u24le(buf[pos:])
t &= 0x7FFFFF
pos += 3
if t < Q {
a[ctr] = i32(t)
ctr += 1
}
}
return ctr
}
poly_uniform :: proc(a: ^Poly, seed: []byte, iv: u16) #no_bounds_check {
// Note/yawning: The dilithium reference code does something
// inexplicably more complicated, but this is identical in
// behavior, and simpler.
#assert(STREAM128_BLOCKBYTES % 3 == 0)
POLY_UNIFORM_NBLOCKS :: ((768 + STREAM128_BLOCKBYTES - 1)/STREAM128_BLOCKBYTES)
buf: [POLY_UNIFORM_NBLOCKS*STREAM128_BLOCKBYTES]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
ctx: shake.Context = ---
defer shake.reset(&ctx)
stream128_init(&ctx, seed, iv)
shake.read(&ctx, buf[:])
ctr := rej_uniform(a.coeffs[:], buf[:])
b := buf[:STREAM128_BLOCKBYTES]
for ctr < N {
shake.read(&ctx, b)
ctr += rej_uniform(a.coeffs[ctr:], b)
}
}
rej_eta :: proc "contextless" (a: []i32, buf: []byte, params: ^Params) -> int {
ctr, pos: int
a_len, b_len := len(a), len(buf)
switch params.eta {
case 2:
for ctr < a_len && pos < b_len {
t0 := u32(buf[pos] & 0x0F)
t1 := u32(buf[pos] >> 4)
pos += 1
if t0 < 15 {
t0 = t0 - (205 * t0 >> 10) * 5
a[ctr] = i32(2 - t0)
ctr += 1
}
if t1 < 15 && ctr < a_len {
t1 = t1 - (205 * t1 >> 10) * 5
a[ctr] = i32(2 - t1)
ctr += 1
}
}
case 4:
for ctr < a_len && pos < b_len {
t0 := u32(buf[pos] & 0x0F)
t1 := u32(buf[pos] >> 4)
pos += 1
if t0 < 9 {
a[ctr] = i32(4 - t0)
ctr += 1
}
if t1 < 9 && ctr < a_len {
a[ctr] = i32(4 - t1)
ctr += 1
}
}
case:
unreachable()
}
return ctr
}
poly_uniform_eta :: proc(a: ^Poly, seed: []byte, iv: u16, params: ^Params) {
POLY_UNIFORM_ETA2_NBLOCKS :: ((136 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES)
POLY_UNIFORM_ETA4_NBLOCKS :: ((227 + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES)
buf_: [POLY_UNIFORM_ETA4_NBLOCKS*STREAM256_BLOCKBYTES]byte = ---
buf: []byte
switch params.eta {
case 2:
buf = buf_[:POLY_UNIFORM_ETA2_NBLOCKS*STREAM256_BLOCKBYTES]
case 4:
buf = buf_[:POLY_UNIFORM_ETA4_NBLOCKS*STREAM256_BLOCKBYTES]
case:
unreachable()
}
defer crypto.zero_explicit(&buf_, size_of(buf_))
ctx: shake.Context = ---
defer shake.reset(&ctx)
stream256_init(&ctx, seed, iv)
shake.read(&ctx, buf)
ctr := rej_eta(a.coeffs[:], buf, params)
b := buf[:STREAM256_BLOCKBYTES]
for ctr < N {
shake.read(&ctx, b)
ctr += rej_eta(a.coeffs[ctr:], b, params)
}
}
poly_uniform_gamma1 :: proc(a: ^Poly, seed: []byte, iv: u16, params: ^Params) {
POLY_UNIFORM_GAMMA1_NBLOCKS_MAX :: ((POLYZ_PACKEDBYTES_MAX + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES)
n_blocks := (polyz_packedbytes(params) + STREAM256_BLOCKBYTES - 1)/STREAM256_BLOCKBYTES
buf_: [POLY_UNIFORM_GAMMA1_NBLOCKS_MAX*STREAM256_BLOCKBYTES]byte = ---
buf := buf_[:n_blocks*STREAM256_BLOCKBYTES]
defer crypto.zero_explicit(&buf_, size_of(buf_))
ctx: shake.Context = ---
defer shake.reset(&ctx)
stream256_init(&ctx, seed, iv)
shake.read(&ctx, buf)
polyz_unpack(a, buf, params)
}
poly_challenge :: proc(c: ^Poly, seed: []byte, params: ^Params) #no_bounds_check {
buf: [STREAM256_BLOCKBYTES]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
ctx: shake.Context = ---
defer shake.reset(&ctx)
shake.init_256(&ctx)
shake.write(&ctx, seed)
shake.read(&ctx, buf[:])
signs: u64
for i in uint(0)..<8 {
signs |= u64(buf[i]) << (8*i)
}
pos := 8
b: int
intrinsics.mem_zero(c, size_of(Poly))
for i := N - params.tau; i < N; i+= 1 {
for {
if pos >= STREAM256_BLOCKBYTES {
shake.read(&ctx, buf[:])
pos = 0
}
b = int(buf[pos])
pos += 1
if b <= i {
break
}
}
c.coeffs[i] = c.coeffs[b]
c.coeffs[b] = i32(1 - 2 * (signs & 1))
signs >>= 1
}
}
polyeta_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check {
t: [8]byte = ---
defer crypto.zero_explicit(&t, size_of(t))
eta := params.eta
switch eta {
case 2:
for i in 0..<N/8 {
t[0] = byte(eta - a.coeffs[8*i+0])
t[1] = byte(eta - a.coeffs[8*i+1])
t[2] = byte(eta - a.coeffs[8*i+2])
t[3] = byte(eta - a.coeffs[8*i+3])
t[4] = byte(eta - a.coeffs[8*i+4])
t[5] = byte(eta - a.coeffs[8*i+5])
t[6] = byte(eta - a.coeffs[8*i+6])
t[7] = byte(eta - a.coeffs[8*i+7])
r[3*i+0] = (t[0] >> 0) | (t[1] << 3) | (t[2] << 6)
r[3*i+1] = (t[2] >> 2) | (t[3] << 1) | (t[4] << 4) | (t[5] << 7)
r[3*i+2] = (t[5] >> 1) | (t[6] << 2) | (t[7] << 5)
}
case 4:
for i in 0..<N/2 {
t[0] = byte(eta - a.coeffs[2*i+0])
t[1] = byte(eta - a.coeffs[2*i+1])
r[i] = t[0] | (t[1] << 4)
}
case:
unreachable()
}
}
polyeta_unpack :: proc "contextless" (r: ^Poly, a: []byte, params: ^Params) #no_bounds_check {
eta := params.eta
switch eta {
case 2:
for i in 0..<N/8 {
r.coeffs[8*i+0] = i32((a[3*i+0] >> 0) & 7)
r.coeffs[8*i+1] = i32((a[3*i+0] >> 3) & 7)
r.coeffs[8*i+2] = i32(((a[3*i+0] >> 6) | (a[3*i+1] << 2)) & 7)
r.coeffs[8*i+3] = i32((a[3*i+1] >> 1) & 7)
r.coeffs[8*i+4] = i32((a[3*i+1] >> 4) & 7)
r.coeffs[8*i+5] = i32(((a[3*i+1] >> 7) | (a[3*i+2] << 1)) & 7)
r.coeffs[8*i+6] = i32((a[3*i+2] >> 2) & 7)
r.coeffs[8*i+7] = i32((a[3*i+2] >> 5) & 7)
r.coeffs[8*i+0] = eta - r.coeffs[8*i+0]
r.coeffs[8*i+1] = eta - r.coeffs[8*i+1]
r.coeffs[8*i+2] = eta - r.coeffs[8*i+2]
r.coeffs[8*i+3] = eta - r.coeffs[8*i+3]
r.coeffs[8*i+4] = eta - r.coeffs[8*i+4]
r.coeffs[8*i+5] = eta - r.coeffs[8*i+5]
r.coeffs[8*i+6] = eta - r.coeffs[8*i+6]
r.coeffs[8*i+7] = eta - r.coeffs[8*i+7]
}
case 4:
for i in 0..<N/2 {
r.coeffs[2*i+0] = i32(a[i] & 0x0F)
r.coeffs[2*i+1] = i32(a[i] >> 4)
r.coeffs[2*i+0] = eta - r.coeffs[2*i+0]
r.coeffs[2*i+1] = eta - r.coeffs[2*i+1]
}
case:
unreachable()
}
}
polyt1_pack :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
for i in 0..<N/4 {
r[5*i+0] = byte(a.coeffs[4*i+0] >> 0)
r[5*i+1] = byte((a.coeffs[4*i+0] >> 8) | (a.coeffs[4*i+1] << 2))
r[5*i+2] = byte((a.coeffs[4*i+1] >> 6) | (a.coeffs[4*i+2] << 4))
r[5*i+3] = byte((a.coeffs[4*i+2] >> 4) | (a.coeffs[4*i+3] << 6))
r[5*i+4] = byte(a.coeffs[4*i+3] >> 2)
}
}
polyt1_unpack :: proc "contextless" (r: ^Poly, a: []byte) #no_bounds_check {
for i in 0..<N/4 {
r.coeffs[4*i+0] = i32((u32(a[5*i+0] >> 0) | (u32(a[5*i+1]) << 8)) & 0x3FF)
r.coeffs[4*i+1] = i32((u32(a[5*i+1] >> 2) | (u32(a[5*i+2]) << 6)) & 0x3FF)
r.coeffs[4*i+2] = i32((u32(a[5*i+2] >> 4) | (u32(a[5*i+3]) << 4)) & 0x3FF)
r.coeffs[4*i+3] = i32((u32(a[5*i+3] >> 6) | (u32(a[5*i+4]) << 2)) & 0x3FF)
}
}
polyt0_pack :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
t: [8]byte = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<N/8 {
t[0] = byte((1 << (D-1)) - a.coeffs[8*i+0])
t[1] = byte((1 << (D-1)) - a.coeffs[8*i+1])
t[2] = byte((1 << (D-1)) - a.coeffs[8*i+2])
t[3] = byte((1 << (D-1)) - a.coeffs[8*i+3])
t[4] = byte((1 << (D-1)) - a.coeffs[8*i+4])
t[5] = byte((1 << (D-1)) - a.coeffs[8*i+5])
t[6] = byte((1 << (D-1)) - a.coeffs[8*i+6])
t[7] = byte((1 << (D-1)) - a.coeffs[8*i+7])
r[13*i+ 0] = t[0]
r[13*i+ 1] = t[0] >> 8
r[13*i+ 1] |= t[1] << 5
r[13*i+ 2] = t[1] >> 3
r[13*i+ 3] = t[1] >> 11
r[13*i+ 3] |= t[2] << 2
r[13*i+ 4] = t[2] >> 6
r[13*i+ 4] |= t[3] << 7
r[13*i+ 5] = t[3] >> 1
r[13*i+ 6] = t[3] >> 9
r[13*i+ 6] |= t[4] << 4
r[13*i+ 7] = t[4] >> 4
r[13*i+ 8] = t[4] >> 12
r[13*i+ 8] |= t[5] << 1
r[13*i+ 9] = t[5] >> 7
r[13*i+ 9] |= t[6] << 6
r[13*i+10] = t[6] >> 2
r[13*i+11] = t[6] >> 10
r[13*i+11] |= t[7] << 3
r[13*i+12] = t[7] >> 5
}
}
polyt0_unpack :: proc "contextless" (r: ^Poly, a: []byte) #no_bounds_check {
for i in 0..<N/8 {
r.coeffs[8*i+0] = i32(a[13*i+0])
r.coeffs[8*i+0] |= i32(u32(a[13*i+1]) << 8)
r.coeffs[8*i+0] &= 0x1FFF
r.coeffs[8*i+1] = i32(a[13*i+1] >> 5)
r.coeffs[8*i+1] |= i32(u32(a[13*i+2]) << 3)
r.coeffs[8*i+1] |= i32(u32(a[13*i+3]) << 11)
r.coeffs[8*i+1] &= 0x1FFF
r.coeffs[8*i+2] = i32(a[13*i+3] >> 2)
r.coeffs[8*i+2] |= i32(u32(a[13*i+4]) << 6)
r.coeffs[8*i+2] &= 0x1FFF
r.coeffs[8*i+3] = i32(a[13*i+4] >> 7)
r.coeffs[8*i+3] |= i32(u32(a[13*i+5]) << 1)
r.coeffs[8*i+3] |= i32(u32(a[13*i+6]) << 9)
r.coeffs[8*i+3] &= 0x1FFF
r.coeffs[8*i+4] = i32(a[13*i+6] >> 4)
r.coeffs[8*i+4] |= i32(u32(a[13*i+7]) << 4)
r.coeffs[8*i+4] |= i32(u32(a[13*i+8]) << 12)
r.coeffs[8*i+4] &= 0x1FFF
r.coeffs[8*i+5] = i32(a[13*i+8] >> 1)
r.coeffs[8*i+5] |= i32(u32(a[13*i+9]) << 7)
r.coeffs[8*i+5] &= 0x1FFF
r.coeffs[8*i+6] = i32(a[13*i+9] >> 6)
r.coeffs[8*i+6] |= i32(u32(a[13*i+10]) << 2)
r.coeffs[8*i+6] |= i32(u32(a[13*i+11]) << 10)
r.coeffs[8*i+6] &= 0x1FFF
r.coeffs[8*i+7] = i32(a[13*i+11] >> 3)
r.coeffs[8*i+7] |= i32(u32(a[13*i+12]) << 5)
r.coeffs[8*i+7] &= 0x1FFF
r.coeffs[8*i+0] = (1 << (D-1)) - r.coeffs[8*i+0]
r.coeffs[8*i+1] = (1 << (D-1)) - r.coeffs[8*i+1]
r.coeffs[8*i+2] = (1 << (D-1)) - r.coeffs[8*i+2]
r.coeffs[8*i+3] = (1 << (D-1)) - r.coeffs[8*i+3]
r.coeffs[8*i+4] = (1 << (D-1)) - r.coeffs[8*i+4]
r.coeffs[8*i+5] = (1 << (D-1)) - r.coeffs[8*i+5]
r.coeffs[8*i+6] = (1 << (D-1)) - r.coeffs[8*i+6]
r.coeffs[8*i+7] = (1 << (D-1)) - r.coeffs[8*i+7]
}
}
polyz_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check {
t: [4]u32 = ---
defer crypto.zero_explicit(&t, size_of(t))
gamma1 := params.gamma1
switch gamma1 {
case 1 << 17:
for i in 0..<N/4 {
t[0] = u32(gamma1 - a.coeffs[4*i+0])
t[1] = u32(gamma1 - a.coeffs[4*i+1])
t[2] = u32(gamma1 - a.coeffs[4*i+2])
t[3] = u32(gamma1 - a.coeffs[4*i+3])
r[9*i+0] = byte(t[0])
r[9*i+1] = byte(t[0] >> 8)
r[9*i+2] = byte(t[0] >> 16)
r[9*i+2] |= byte(t[1] << 2)
r[9*i+3] = byte(t[1] >> 6)
r[9*i+4] = byte(t[1] >> 14)
r[9*i+4] |= byte(t[2] << 4)
r[9*i+5] = byte(t[2] >> 4)
r[9*i+6] = byte(t[2] >> 12)
r[9*i+6] |= byte(t[3] << 6)
r[9*i+7] = byte(t[3] >> 2)
r[9*i+8] = byte(t[3] >> 10)
}
case 1 << 19:
for i in 0..<N/2 {
t[0] = u32(gamma1 - a.coeffs[2*i+0])
t[1] = u32(gamma1 - a.coeffs[2*i+1])
r[5*i+0] = byte(t[0])
r[5*i+1] = byte(t[0] >> 8)
r[5*i+2] = byte(t[0] >> 16)
r[5*i+2] |= byte(t[1] << 4)
r[5*i+3] = byte(t[1] >> 4)
r[5*i+4] = byte(t[1] >> 12)
}
case:
unreachable()
}
}
polyz_unpack :: proc "contextless" (r: ^Poly, a: []byte, params: ^Params) #no_bounds_check {
gamma1 := params.gamma1
switch gamma1 {
case 1 << 17:
for i in 0..<N/4 {
r.coeffs[4*i+0] = i32(a[9*i+0])
r.coeffs[4*i+0] |= i32(u32(a[9*i+1]) << 8)
r.coeffs[4*i+0] |= i32(u32(a[9*i+2]) << 16)
r.coeffs[4*i+0] &= 0x3FFFF
r.coeffs[4*i+1] = i32(a[9*i+2] >> 2)
r.coeffs[4*i+1] |= i32(u32(a[9*i+3]) << 6)
r.coeffs[4*i+1] |= i32(u32(a[9*i+4]) << 14)
r.coeffs[4*i+1] &= 0x3FFFF
r.coeffs[4*i+2] = i32(a[9*i+4] >> 4)
r.coeffs[4*i+2] |= i32(u32(a[9*i+5]) << 4)
r.coeffs[4*i+2] |= i32(u32(a[9*i+6]) << 12)
r.coeffs[4*i+2] &= 0x3FFFF
r.coeffs[4*i+3] = i32(a[9*i+6] >> 6)
r.coeffs[4*i+3] |= i32(u32(a[9*i+7]) << 2)
r.coeffs[4*i+3] |= i32(u32(a[9*i+8]) << 10)
r.coeffs[4*i+3] &= 0x3FFFF
r.coeffs[4*i+0] = gamma1 - r.coeffs[4*i+0]
r.coeffs[4*i+1] = gamma1 - r.coeffs[4*i+1]
r.coeffs[4*i+2] = gamma1 - r.coeffs[4*i+2]
r.coeffs[4*i+3] = gamma1 - r.coeffs[4*i+3]
}
case 1 << 19:
for i in 0..<N/2 {
r.coeffs[2*i+0] = i32(a[5*i+0])
r.coeffs[2*i+0] |= i32(u32(a[5*i+1]) << 8)
r.coeffs[2*i+0] |= i32(u32(a[5*i+2]) << 16)
r.coeffs[2*i+0] &= 0xFFFFF
r.coeffs[2*i+1] = i32(a[5*i+2] >> 4)
r.coeffs[2*i+1] |= i32(u32(a[5*i+3]) << 4)
r.coeffs[2*i+1] |= i32(u32(a[5*i+4]) << 12)
/* r.coeffs[2*i+1] &= 0xFFFFF */ /* No effect, since we're anyway at 20 bits */
r.coeffs[2*i+0] = gamma1 - r.coeffs[2*i+0]
r.coeffs[2*i+1] = gamma1 - r.coeffs[2*i+1]
}
case:
unreachable()
}
}
polyw1_pack :: proc "contextless" (r: []byte, a: ^Poly, params: ^Params) #no_bounds_check {
switch params.gamma2 {
case (Q-1)/88:
for i in 0..<N/4 {
r[3*i+0] = byte(a.coeffs[4*i+0])
r[3*i+0] |= byte(a.coeffs[4*i+1] << 6)
r[3*i+1] = byte(a.coeffs[4*i+1] >> 2)
r[3*i+1] |= byte(a.coeffs[4*i+2] << 4)
r[3*i+2] = byte(a.coeffs[4*i+2] >> 4)
r[3*i+2] |= byte(a.coeffs[4*i+3] << 2)
}
case (Q-1)/32:
for i in 0..<N/2 {
r[i] = byte(a.coeffs[2*i+0] | (a.coeffs[2*i+1] << 4))
}
case:
unreachable()
}
}

View File

@@ -0,0 +1,209 @@
#+private
package _mldsa
import "core:crypto"
Polyvec_L :: struct {
vec: [L_MAX]Poly,
}
Polyvec_K :: struct {
vec: [K_MAX]Poly,
}
polyvec_copy :: proc "contextless" (dst, src: ^$T, params: ^Params) where T == Polyvec_L || T == Polyvec_K {
when T == Polyvec_L {
n := params.l
} else {
n := params.k
}
for i in 0..<n {
copy(dst.vec[i].coeffs[:], src.vec[i].coeffs[:])
}
}
polyvec_clear :: proc "contextless" (vecs: []^$T) where T == Polyvec_L || T == Polyvec_K {
for _, i in vecs {
crypto.zero_explicit(vecs[i], size_of(T))
}
}
polyvec_matrix_expand :: proc(mat: []Polyvec_L, rho: []byte, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
for j in 0..<params.l {
poly_uniform(&mat[i].vec[j], rho, u16((i << 8) + j))
}
}
}
polyvec_matrix_pointwise_montgomery :: proc "contextless" (t: ^Polyvec_K, mat: []Polyvec_L, v: ^Polyvec_L, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
polyvec_l_pointwise_acc_montgomery(&t.vec[i], &mat[i], v, params)
}
}
polyvec_l_uniform_eta :: proc(v: ^Polyvec_L, seed: []byte, iv: u16, params: ^Params) #no_bounds_check {
iv := iv
for i in 0..<params.l {
poly_uniform_eta(&v.vec[i], seed, iv, params)
iv += 1
}
}
polyvec_l_uniform_gamma1 :: proc(v: ^Polyvec_L, seed: []byte, iv: u32, params: ^Params) #no_bounds_check {
for i in 0..<params.l {
poly_uniform_gamma1(&v.vec[i], seed, u16(u32(params.l) * iv + u32(i)), params)
}
}
polyvec_l_reduce :: proc "contextless" (v: ^Polyvec_L, params: ^Params) #no_bounds_check {
for i in 0..<params.l {
poly_reduce(&v.vec[i])
}
}
polyvec_l_add :: proc "contextless" (w, u, v: ^Polyvec_L, params: ^Params) #no_bounds_check {
for i in 0..<params.l {
poly_add(&w.vec[i], &u.vec[i], &v.vec[i])
}
}
polyvec_l_ntt :: proc "contextless" (v: ^Polyvec_L, params: ^Params) {
for i in 0..<params.l {
poly_ntt(&v.vec[i])
}
}
polyvec_l_invntt_tomont :: proc "contextless" (v: ^Polyvec_L, params: ^Params) {
for i in 0..<params.l {
poly_invntt_tomont(&v.vec[i])
}
}
polyvec_l_pointwise_poly_montgomery :: proc "contextless" (r: ^Polyvec_L, a: ^Poly, v: ^Polyvec_L, params: ^Params) #no_bounds_check {
for i in 0..<params.l {
poly_pointwise_montgomery(&r.vec[i], a, &v.vec[i])
}
}
polyvec_l_pointwise_acc_montgomery :: proc "contextless" (w: ^Poly, u, v: ^Polyvec_L, params: ^Params) #no_bounds_check {
t: Poly
poly_pointwise_montgomery(w, &u.vec[0], &v.vec[0])
for i in 1..<params.l {
poly_pointwise_montgomery(&t, &u.vec[i], &v.vec[i])
poly_add(w, w, &t)
}
}
polyvec_l_chknorm :: proc "contextless" (v: ^Polyvec_L, bound: i32, params: ^Params) -> bool #no_bounds_check {
for i in 0..<params.l {
if poly_chknorm(&v.vec[i],bound) {
return true
}
}
return false
}
polyvec_k_uniform_eta :: proc (v: ^Polyvec_K, seed: []byte, iv: u16, params: ^Params) #no_bounds_check {
iv := iv
for i in 0..<params.k {
poly_uniform_eta(&v.vec[i], seed, iv, params)
iv += 1
}
}
polyvec_k_reduce :: proc "contextless" (v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_reduce(&v.vec[i])
}
}
polyvec_k_caddq :: proc "contextless" (v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_caddq(&v.vec[i])
}
}
polyvec_k_add :: proc "contextless" (w, u, v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_add(&w.vec[i], &u.vec[i], &v.vec[i])
}
}
polyvec_k_sub :: proc "contextless" (w, u, v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_sub(&w.vec[i], &u.vec[i], &v.vec[i])
}
}
polyvec_k_shiftl :: proc "contextless" (v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_shiftl(&v.vec[i])
}
}
polyvec_k_ntt :: proc "contextless" (v: ^Polyvec_K, params: ^Params) {
for i in 0..<params.k {
poly_ntt(&v.vec[i])
}
}
polyvec_k_invntt_tomont :: proc "contextless" (v: ^Polyvec_K, params: ^Params) {
for i in 0..<params.k {
poly_invntt_tomont(&v.vec[i])
}
}
polyvec_k_pointwise_poly_montgomery :: proc "contextless" (r: ^Polyvec_K, a: ^Poly, v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_pointwise_montgomery(&r.vec[i], a, &v.vec[i])
}
}
polyvec_k_chknorm :: proc "contextless" (v: ^Polyvec_K, bound: i32, params: ^Params) -> bool #no_bounds_check {
for i in 0..<params.k {
if poly_chknorm(&v.vec[i],bound) {
return true
}
}
return false
}
polyvec_k_power2round :: proc "contextless" (v1, v0, v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_power2round(&v1.vec[i], &v0.vec[i], &v.vec[i])
}
}
polyvec_k_decompose :: proc "contextless" (v1, v0, v: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_decompose(&v1.vec[i], &v0.vec[i], &v.vec[i], params)
}
}
polyvec_k_make_hint :: proc "contextless" (h, v0, v1: ^Polyvec_K, params: ^Params) -> uint #no_bounds_check {
s: uint
for i in 0..<params.k {
s += poly_make_hint(&h.vec[i], &v0.vec[i], &v1.vec[i], params)
}
return s
}
polyvec_k_use_hint :: proc "contextless" (w, u, h: ^Polyvec_K, params: ^Params) #no_bounds_check {
for i in 0..<params.k {
poly_use_hint(&w.vec[i],&u.vec[i], &h.vec[i], params)
}
}
polyvec_k_pack_w1 :: proc "contextless" (r: []byte, w1: ^Polyvec_K, params: ^Params) #no_bounds_check {
packed_len := polyw1_packedbytes(params)
for i in 0..<params.k {
polyw1_pack(r[i*packed_len:], &w1.vec[i], params)
}
}

View File

@@ -0,0 +1,35 @@
#+private
package _mldsa
// MONT :: -4186625 // 2^32 % Q
@(require_results)
montgomery_reduce :: proc "contextless" (a: i64) -> i32 {
QINV :: 58728449 // q^(-1) mod 2^32
t := i32(i64(i32(a)) * QINV)
t = i32((a - i64(t) * Q) >> 32)
return t
}
@(require_results)
reduce32 :: #force_inline proc "contextless" (a: i32) -> i32 {
t := (a + (1 << 22)) >> 23
t = a - t * Q
return t
}
@(require_results)
caddq :: #force_inline proc "contextless" (a: i32) -> i32 {
a := a
a += (a >> 31) & Q
return a
}
// @(require_results)
// freeze :: #force_inline proc "contextless" (a: i32) -> i32 {
// a := a
// a = reduce32(a)
// a = caddq(a)
// return a
// }

View File

@@ -0,0 +1,56 @@
#+private
package _mldsa
power2round :: proc "contextless" (a: i32) -> (i32, i32) {
a1 := (a + (1 << (D-1)) - 1) >> D
a0 := a - (a1 << D)
return a0, a1
}
decompose :: proc "contextless" (a: i32, gamma2: i32) -> (i32, i32) {
a1 := (a + 127) >> 7
switch gamma2 {
case (Q - 1)/32:
a1 = (a1 * 1025 + (1 << 21)) >> 22
a1 &= 15
case (Q - 1)/88:
a1 = (a1 * 11275 + (1 << 23)) >> 24
a1 ~= ((43 - a1) >> 31) & a1
}
a0 := a - a1 * 2 * gamma2
a0 -= (((Q - 1)/2 - a0) >> 31) & Q
return a0, a1
}
make_hint :: proc "contextless" (a0, a1: i32, gamma2: i32) -> uint {
if (a0 > gamma2 || a0 < -gamma2 || (a0 == -gamma2 && a1 != 0)) {
return 1
}
return 0
}
use_hint :: proc "contextless" (a: i32, hint: uint, gamma2: i32) -> i32 {
a0, a1 := decompose(a, gamma2)
if hint == 0 {
return a1
}
switch gamma2 {
case (Q - 1)/32:
if (a0 > 0) {
return (a1 + 1) & 15
} else {
return (a1 -1) & 15
}
case (Q - 1)/88:
if (a0 > 0) {
return (a1 == 43) ? 0 : a1 + 1
} else {
return (a1 == 0) ? 43 : a1 - 1
}
}
unreachable()
}

View File

@@ -0,0 +1,39 @@
#+private
package _mldsa
import "core:crypto/_sha3"
import "core:crypto/shake"
STREAM128_BLOCKBYTES :: _sha3.RATE_128
STREAM256_BLOCKBYTES :: _sha3.RATE_256
stream128_init :: proc(ctx: ^shake.Context, seed: []byte, iv: u16) {
t: [2]byte = ---
t[0] = byte(iv)
t[1] = byte(iv >> 8)
shake.init_128(ctx)
shake.write(ctx, seed)
shake.write(ctx, t[:])
}
stream256_init :: proc(ctx: ^shake.Context, seed: []byte, iv: u16) {
t: [2]byte = ---
t[0] = byte(iv)
t[1] = byte(iv >> 8)
shake.init_256(ctx)
shake.write(ctx, seed)
shake.write(ctx, t[:])
}
shake256 :: proc(dst: []byte, srcs: ..[]byte) {
ctx: shake.Context = ---
defer shake.reset(&ctx)
shake.init_256(&ctx)
for src in srcs {
shake.write(&ctx, src)
}
shake.read(&ctx, dst)
}

290
core/crypto/mldsa/api.odin Normal file
View File

@@ -0,0 +1,290 @@
package mldsa
import "core:crypto"
import "core:crypto/_mldsa"
// Parameters are the supported ML-DSA parameter sets.
Parameters :: enum {
Invalid,
ML_DSA_44,
ML_DSA_65,
ML_DSA_87,
}
// PRIVATE_KEY_SEED_SIZE is the size of a private key in bytes.
PRIVATE_KEY_SEED_SIZE :: _mldsa.SEEDBYTES // 32-bytes
// MAX_CTX_SIZE is the maximum size of the signature context
// (domain separation tag) in bytes.
MAX_CTX_SIZE :: _mldsa.CTXBYTES_MAX // 255-bytes
// PUBLIC_KEY_SIZES are the per-parameter sizes of a public
// key in bytes.
PUBLIC_KEY_SIZES := [Parameters]int {
.Invalid = 0,
.ML_DSA_44 = 1312,
.ML_DSA_65 = 1952,
.ML_DSA_87 = 2592,
}
// SIGNATURE_SIZES are the per-parameter sizes of a signature
// in byte.
SIGNATURE_SIZES := [Parameters]int {
.Invalid = 0,
.ML_DSA_44 = 2420,
.ML_DSA_65 = 3309,
.ML_DSA_87 = 4627,
}
@(private="file")
_PARAMS_TO_INTERNAL := [Parameters]^_mldsa.Params {
.Invalid = nil,
.ML_DSA_44 = &_mldsa.Params_44,
.ML_DSA_65 = &_mldsa.Params_65,
.ML_DSA_87 = &_mldsa.Params_87,
}
// Private_Key is a ML-DSA private key.
Private_Key :: _mldsa.Private_Key
// Public_Key is a ML-DSA public key.
Public_Key :: _mldsa.Public_Key
// private_key_generate uses the system entropy source to generate a new
// Private_Key. This will only fail if and only if (⟺) the system entropy
// source is missing or broken.
@(require_results)
private_key_generate :: proc(priv_key: ^Private_Key, params: Parameters) -> bool {
private_key_clear(priv_key)
if !crypto.HAS_RAND_BYTES {
return false
}
params_ := _PARAMS_TO_INTERNAL[params]
if params_ == nil {
return false
}
seed: [PRIVATE_KEY_SEED_SIZE]byte = ---
defer crypto.zero_explicit(&seed, size_of(seed))
crypto.rand_bytes(seed[:])
_mldsa.dsa_keygen_internal(priv_key, seed[:], params_)
return true
}
// private_key_set_bytes decodes a byte-encoded private key in "seed" format,
// and returns true if and only if (⟺) the operation was successful.
@(require_results)
private_key_set_bytes :: proc(priv_key: ^Private_Key, params: Parameters, b: []byte) -> bool {
private_key_clear(priv_key)
params_ := _PARAMS_TO_INTERNAL[params]
if params_ == nil {
return false
}
if len(b) != PRIVATE_KEY_SEED_SIZE {
return false
}
_mldsa.dsa_keygen_internal(priv_key, b, params_)
return true
}
// private_key_bytes sets dst to byte-encoding of priv_key in the "seed"
// format.
private_key_bytes :: proc(priv_key: ^Private_Key, dst: []byte) {
ensure(priv_key.params != nil, "crypto/mldsa: uninitialized private key")
ensure(len(dst) == PRIVATE_KEY_SEED_SIZE, "crypto/mldsa: invalid destination size")
copy(dst, priv_key.seed[:])
}
// private_key_public_bytes sets dst to the byte-encoding of the public
// key corresponding to priv_key.
private_key_public_bytes :: proc(priv_key: ^Private_Key, dst: []byte) {
public_key_bytes(&priv_key.pub_key, dst)
}
// private_key_set sets priv_key to src.
private_key_set :: proc(priv_key, src: ^Private_Key) {
if src == nil || internal_to_params(src.params) == .Invalid {
private_key_clear(priv_key)
return
}
_mldsa.set_sk(priv_key, src)
}
// private_key_equal returns true if and only if (⟺) the private keys are
// equal, in constant time.
@(require_results)
private_key_equal :: proc(p, q: ^Private_Key) -> bool {
if p.params != q.params {
return false
}
if p.params == nil {
return true
}
// Just compare the seed that was passed to dsa_keygen_internal,
// since the process is completely deterministic.
return crypto.compare_constant_time(p.seed[:], q.seed[:]) == 1
}
// private_key_clear clears priv_key to the uninitialized state.
private_key_clear :: proc "contextless" (priv_key: ^Private_Key) {
_mldsa.clear_sk(priv_key)
}
// public_key_set_bytes decodes a byte-encoded public key, and returns
// true if and only if (⟺) the operation was successful.
@(require_results)
public_key_set_bytes :: proc(pub_key: ^Public_Key, params: Parameters, b: []byte) -> bool {
params_ := _PARAMS_TO_INTERNAL[params]
if params_ == nil {
return false
}
return _mldsa.unpack_pk(pub_key, b, params_)
}
// public_key_set sets pub_key to src.
public_key_set :: proc(pub_key, src: ^Public_Key) {
if src == nil || internal_to_params(src.params) == .Invalid {
public_key_clear(pub_key)
return
}
_mldsa.set_pk(pub_key, src)
}
// public_key_set_priv sets pub_key to the public component of priv_key.
public_key_set_priv :: proc(pub_key: ^Public_Key, priv_key: ^Private_Key) {
ensure(priv_key.params != nil, "crypto/mldsa: uninitialized private key")
public_key_set(pub_key, &priv_key.pub_key)
}
// public_key_bytes sets dst to byte-encoding of pub_key.
public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) {
ensure(pub_key.params != nil, "crypto/mldsa: uninitialized public key")
params := internal_to_params(pub_key.params)
ensure(len(dst) == PUBLIC_KEY_SIZES[params], "crypto/mldsa: invalid destination size")
_ = _mldsa.pack_pk(dst, pub_key)
}
// public_key_equal returns true if and only if (⟺) the public keys are equal,
// in constant time.
@(require_results)
public_key_equal :: proc(p, q: ^Public_Key) -> bool {
if p.params != q.params {
return false
}
if p.params == nil {
return true
}
// Comparing the pre-computed hash should be enough, but pack
// both public keys and do the comparisons.
PUBLIC_KEY_SIZE_MAX :: 2592
l := PUBLIC_KEY_SIZES[internal_to_params(p.params)]
p_buf_, q_buf_: [PUBLIC_KEY_SIZE_MAX]byte = ---, ---
p_buf, q_buf := p_buf_[:l], q_buf_[:l]
_ = _mldsa.pack_pk(p_buf, p)
_ = _mldsa.pack_pk(q_buf, q)
return crypto.compare_constant_time(p_buf, q_buf) == 1
}
// public_key_clear clears pub_key to the uninitialized state.
public_key_clear :: proc "contextless" (pub_key: ^Public_Key) {
_mldsa.clear_pk(pub_key)
}
// sign writes the signature by priv_key over (ctx, msg) to sig and
// returns true if and only if (⟺) the signing succeeded.
//
// ctx is an optional domain separation tag and may be omitted (nil).
@(require_results)
sign :: proc(priv_key: ^Private_Key, ctx, msg, sig: []byte, deterministic := !crypto.HAS_RAND_BYTES) -> bool {
params := internal_to_params(priv_key.params)
ensure(params != .Invalid, "crypto/mldsa: invalid private key")
ensure(len(sig) == SIGNATURE_SIZES[params], "crypto/mldsa: invalid destination size")
if !deterministic && !crypto.HAS_RAND_BYTES {
return false
}
if len(ctx) > MAX_CTX_SIZE {
return false
}
rnd: [_mldsa.RNDBYTES]byte
defer crypto.zero_explicit(&rnd, size_of(rnd))
if !deterministic {
crypto.rand_bytes(rnd[:])
}
return _mldsa.dsa_sign_internal(sig, msg, ctx, rnd[:], priv_key)
}
// verify returns true if and only if (⟺) sig is a valid signature by pub_key
// over (ctx, msg).
@(require_results)
verify :: proc(pub_key: ^Public_Key, ctx, msg, sig: []byte) -> bool {
params := internal_to_params(pub_key.params)
ensure(params != .Invalid, "crypto/mldsa: invalid public key")
if len(sig) != SIGNATURE_SIZES[params] {
return false
}
if len(ctx) > MAX_CTX_SIZE {
return false
}
return _mldsa.dsa_verify_internal(sig, msg, ctx, pub_key)
}
// params returns the Parameters used by a Private_Key or Public_Key
// instance.
@(require_results)
params :: proc(k: ^$T) -> Parameters where (T == Private_Key || T == Public_Key) {
return internal_to_params(k.params)
}
// key_size returns the key size of a Private_Key or Public_Key in bytes.
@(require_results)
key_size :: proc(k: ^$T) -> int where (T == Private_Key || T == Public_Key) {
when T == Private_Key {
return PRIVATE_KEY_SEED_SIZE
} else {
return PUBLIC_KEY_SIZES[internal_to_params(k.params)]
}
}
// signature_size returns the key size of a signature in bytes.
@(require_results)
signature_size :: proc(k: ^$T) -> int where (T == Private_Key || T == Public_Key) {
return SIGNATURE_SIZES[internal_to_params(k.params)]
}
@(private="file",require_results)
internal_to_params :: proc "contextless" (params: ^_mldsa.Params) -> Parameters {
switch params {
case &_mldsa.Params_44:
return .ML_DSA_44
case &_mldsa.Params_65:
return .ML_DSA_65
case &_mldsa.Params_87:
return .ML_DSA_87
case:
return .Invalid
}
}

View File

@@ -0,0 +1,7 @@
/*
Module-Lattice-Based Digital Signature Algorithm.
See:
- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.204.pdf ]]
*/
package mldsa

View File

@@ -43,6 +43,7 @@ package all
@(require) import "core:crypto/legacy/keccak"
@(require) import "core:crypto/legacy/md5"
@(require) import "core:crypto/legacy/sha1"
@(require) import "core:crypto/mldsa"
@(require) import "core:crypto/mlkem"
@(require) import cnoise "core:crypto/noise"
@(require) import "core:crypto/pbkdf2"

View File

@@ -48,6 +48,7 @@ package all
@(require) import "core:crypto/legacy/keccak"
@(require) import "core:crypto/legacy/md5"
@(require) import "core:crypto/legacy/sha1"
@(require) import "core:crypto/mldsa"
@(require) import "core:crypto/mlkem"
@(require) import cnoise "core:crypto/noise"
@(require) import "core:crypto/pbkdf2"

View File

@@ -18,8 +18,8 @@ import "core:crypto/hash"
ECDH_ITERS :: 10000
@(private = "file")
DSA_ITERS :: 10000
@(private = "file")
MSG : string : "Got a job for you, 621."
@(private)
SIG_MSG : string : "Got a job for you, 621."
@(test)
benchmark_crypto_ecc :: proc(t: ^testing.T) {
@@ -126,8 +126,8 @@ bench_dsa :: proc() {
@(private = "file")
bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) {
priv_str := "cafebabecafebabecafebabecafebabecafebabecafebabecafebabecafebabe"
priv_bytes, _ := hex.decode(transmute([]byte)(priv_str), context.temp_allocator)
SEED : string : "cafebabecafebabecafebabecafebabecafebabecafebabecafebabecafebabe"
priv_bytes, _ := hex.decode(transmute([]byte)(SEED), context.temp_allocator)
priv_key: ed25519.Private_Key
start := time.tick_now()
for _ in 0 ..< DSA_ITERS {
@@ -136,13 +136,11 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) {
}
sk = time.tick_since(start) / DSA_ITERS
pub_bytes := priv_key._pub_key._b[:] // "I know what I am doing"
pub_key: ed25519.Public_Key
ok := ed25519.public_key_set_bytes(&pub_key, pub_bytes[:])
assert(ok, "public key should deserialize")
ed25519.public_key_set_priv(&pub_key, &priv_key)
sig_bytes: [ed25519.SIGNATURE_SIZE]byte
msg_bytes := transmute([]byte)(MSG)
msg_bytes := transmute([]byte)(SIG_MSG)
start = time.tick_now()
for _ in 0 ..< DSA_ITERS {
ed25519.sign(&priv_key, msg_bytes, sig_bytes[:])
@@ -151,7 +149,7 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) {
start = time.tick_now()
for _ in 0 ..< DSA_ITERS {
ok = ed25519.verify(&pub_key, msg_bytes, sig_bytes[:])
ok := ed25519.verify(&pub_key, msg_bytes, sig_bytes[:])
assert(ok, "signature should validate")
}
verif = time.tick_since(start) / DSA_ITERS
@@ -175,7 +173,7 @@ bench_ecdsa :: proc(curve: ecdsa.Curve, hash: hash.Algorithm) -> (sk, sig, verif
ecdsa.public_key_set_priv(&pub_key, &priv_key)
sig_bytes := make([]byte, ecdsa.RAW_SIGNATURE_SIZES[curve], context.temp_allocator)
msg_bytes := transmute([]byte)(MSG)
msg_bytes := transmute([]byte)(SIG_MSG)
start = time.tick_now()
for _ in 0 ..< DSA_ITERS {
ok := ecdsa.sign_raw(&priv_key, hash, msg_bytes, sig_bytes, true)

View File

@@ -6,15 +6,19 @@ import "core:text/table"
import "core:time"
import "core:crypto"
import "core:crypto/mldsa"
import "core:crypto/mlkem"
@(private = "file")
MLKEM_ITERS :: 50000
@(private = "file")
MLDSA_ITERS :: 10000
@(test)
benchmark_crypto_mlkem :: proc(t: ^testing.T) {
if !crypto.HAS_RAND_BYTES {
log.warnf("ML-KEM benchmarks skipped, no system entropy source")
return
}
tbl: table.Table
@@ -75,6 +79,86 @@ benchmark_crypto_mlkem :: proc(t: ^testing.T) {
log_table(&tbl)
}
@(test)
bench_mldsa :: proc(t: ^testing.T) {
if !crypto.HAS_RAND_BYTES {
log.warnf("ML-DSA benchmarks skipped, no system entropy source")
return
}
tbl: table.Table
table.init(&tbl)
defer table.destroy(&tbl)
table.caption(&tbl, "ML-DSA")
table.aligned_header_of_values(&tbl, .Right, "Parameters", "Op", "Time")
append_tbl := proc(tbl: ^table.Table, algo_name, op: string, t: time.Duration) {
table.aligned_row_of_values(
tbl,
.Right,
algo_name,
op,
table.format(tbl, "%8M", t),
)
}
do_bench := proc(params: mldsa.Parameters) -> (sk, sig, verif: time.Duration) {
// The time taken is highly seed dependent due to rejection
// sampling using SHAKE, so we hit up the system entropy source
// and crank up the iteration count.
priv_key: mldsa.Private_Key
start := time.tick_now()
for _ in 0 ..< MLDSA_ITERS*2 {
ok := mldsa.private_key_generate(&priv_key, params)
assert(ok, "private key should generate")
}
sk = time.tick_since(start) / (MLDSA_ITERS*2)
pub_key: mldsa.Public_Key
mldsa.public_key_set_priv(&pub_key, &priv_key)
msg_bytes := transmute([]byte)(SIG_MSG)
sig_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params])
defer delete(sig_bytes)
// FIPS defaults to hedged mode with non-deterministic signatures.
start = time.tick_now()
for _ in 0 ..< MLDSA_ITERS {
ok := mldsa.sign(&priv_key, nil, msg_bytes, sig_bytes)
assert(ok, "signature should succeed")
}
sig = time.tick_since(start) / MLDSA_ITERS
start = time.tick_now()
for _ in 0 ..< MLDSA_ITERS {
ok := mldsa.verify(&pub_key, nil, msg_bytes, sig_bytes)
assert(ok, "signature should validate")
}
verif = time.tick_since(start) / MLDSA_ITERS
return
}
for params in mldsa.Parameters {
if params == .Invalid {
continue
}
param_name := MLDSA_PARAMS_NAMES[params]
sig, sk, verif := do_bench(params)
append_tbl(&tbl, param_name, "private_key_generate", sk)
append_tbl(&tbl, param_name, "sign", sig)
append_tbl(&tbl, param_name, "verify", verif)
if params != .ML_DSA_87 {
table.row(&tbl)
}
}
log_table(&tbl)
}
@(private="file")
MLKEM_PARAMS_NAMES := [mlkem.Parameters]string {
.Invalid = "invalid",
@@ -82,3 +166,11 @@ MLKEM_PARAMS_NAMES := [mlkem.Parameters]string {
.ML_KEM_768 = "ML-KEM-768",
.ML_KEM_1024 = "ML-KEM-1024",
}
@(private="file")
MLDSA_PARAMS_NAMES := [mldsa.Parameters]string {
.Invalid = "invalid",
.ML_DSA_44 = "ML-DSA-44",
.ML_DSA_65 = "ML-DSA-65",
.ML_DSA_87 = "ML-DSA-87",
}

View File

@@ -1,10 +1,12 @@
package test_core_crypto
import "core:bytes"
import "core:fmt"
import "core:log"
import "core:testing"
import "core:crypto"
import "core:crypto/mldsa"
import "core:crypto/mlkem"
@(test)
@@ -70,3 +72,88 @@ test_mlkem :: proc(t: ^testing.T) {
)
}
}
@(test)
test_mldsa :: proc(t: ^testing.T) {
TEST_MSG : string : "ML-DSA test message"
msg_bytes := transmute([]byte)(TEST_MSG)
// Test vectors are huge, and are covered by the wycheproof corpus,
// so do some casual tests.
for params in mldsa.Parameters {
if params == .Invalid {
continue
}
seed: [mldsa.PRIVATE_KEY_SEED_SIZE]byte
fmt.bprintf(seed[:], "odin test - %v", params)
priv_key: mldsa.Private_Key
if !testing.expectf(
t,
mldsa.private_key_set_bytes(&priv_key, params, seed[:]),
"%v: private_key_set_bytes",
params,
) {
continue
}
sig_det_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params])
defer delete(sig_det_bytes)
if !testing.expectf(
t,
mldsa.sign(&priv_key, nil, msg_bytes, sig_det_bytes, true),
"%v: sign (deterministic)",
params,
) {
continue
}
pub_key: mldsa.Public_Key
mldsa.public_key_set_priv(&pub_key, &priv_key)
if !testing.expectf(
t,
mldsa.verify(&pub_key, nil, msg_bytes, sig_det_bytes),
"%v: verify (deterministic)",
params,
) {
continue
}
if !crypto.HAS_RAND_BYTES {
continue
}
sig_hedged_bytes := make([]byte, mldsa.SIGNATURE_SIZES[params])
defer delete(sig_hedged_bytes)
if !testing.expectf(
t,
mldsa.sign(&priv_key, nil, msg_bytes, sig_hedged_bytes),
"%v: sign (hedged)",
params,
) {
continue
}
if !testing.expectf(
t,
mldsa.verify(&pub_key, nil, msg_bytes, sig_hedged_bytes),
"%v: verify (hedged)",
params,
) {
continue
}
// False positive rate of 1/(2^256), assuming a functional
// entropy source.
testing.expectf(
t,
!bytes.equal(sig_det_bytes, sig_hedged_bytes),
"%v: deterministic sig should not equal hedged",
params,
)
}
}

View File

@@ -4,11 +4,15 @@ import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:slice"
import "core:testing"
import "core:crypto/_mlkem"
import "core:crypto/mlkem"
import "core:crypto/_mldsa"
import "core:crypto/mldsa"
import "../common"
@(test)
@@ -81,7 +85,7 @@ test_mlkem :: proc(t: ^testing.T) {
test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
params := mlkem_parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
@@ -93,19 +97,32 @@ test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/KeyGen/%d/%d: %s: %+v",
params_str,
tg_id,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/KeyGen/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags)
}
seed := common.hexbytes_decode(test_vector.seed)
dk: mlkem.Decapsulation_Key
if !testing.expectf(
t,
mlkem.decapsulation_key_set_bytes(&dk, params, seed),
"%s/KeyGen/%d/%d: failed to set decapsulation key from seed",
"%s/KeyGen/%d/%d: failed to set decapsulation key from seed: %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.seed,
) {
num_failed *= 1
num_failed += 1
continue
}
@@ -184,7 +201,7 @@ test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
params := mlkem_parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
@@ -196,6 +213,19 @@ test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/Encaps/%d/%d: %s: %+v",
params_str,
tg_id,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/Encaps/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags)
}
ek: mlkem.Encapsulation_Key
ok := mlkem.encapsulation_key_set_bytes(
&ek,
@@ -284,7 +314,7 @@ test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
params := mlkem_parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
@@ -296,6 +326,19 @@ test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/Decaps/%d/%d: %s: %+v",
params_str,
tg_id,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/Decaps/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags)
}
// We do not have an API for decaps with raw seed.
seed := common.hexbytes_decode(test_vector.seed)
switch len(seed) {
@@ -386,8 +429,8 @@ test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Gr
return num_failed == 0
}
@(require_results, private="file")
parameter_set_to_params :: proc(s: string) -> mlkem.Parameters {
@(require_results,private="file")
mlkem_parameter_set_to_params :: proc(s: string) -> mlkem.Parameters {
switch s {
case "ML-KEM-512":
return .ML_KEM_512
@@ -399,3 +442,321 @@ parameter_set_to_params :: proc(s: string) -> mlkem.Parameters {
return .Invalid
}
}
@(test)
test_mldsa :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("mldsa: starting")
files_sign := []string {
"mldsa_44_sign_seed_test.json",
"mldsa_65_sign_seed_test.json",
"mldsa_87_sign_seed_test.json",
}
for f in files_sign {
mem.free_all()
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Mldsa_Test_Group)
load_ok := load(&test_vectors, fn)
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mldsa_sign(t, &test_vectors), "ML-DSA Sign failed")
}
files_verify := []string {
"mldsa_44_verify_test.json",
"mldsa_65_verify_test.json",
"mldsa_87_verify_test.json",
}
for f in files_verify {
mem.free_all()
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Mldsa_Test_Group)
load_ok := load(&test_vectors, fn)
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mldsa_verify(t, &test_vectors), "ML-DSA Verify failed")
}
}
test_mldsa_sign :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Mldsa_Test_Group)) -> bool {
FLAG_INTERNAL :: "Internal"
dummy_rnd: [_mldsa.RNDBYTES]byte
params_str := test_vectors.algorithm
params := mldsa_parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
log.debugf("%s: Sign starting", params_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, tg_id in test_vectors.test_groups {
seed := common.hexbytes_decode(test_group.private_seed)
priv_key: mldsa.Private_Key
tg_len := len(test_group.tests)
if !testing.expectf(
t,
mldsa.private_key_set_bytes(&priv_key, params, seed),
"%s/Sign/%d: failed to set private key from seed: %s",
params_str,
tg_id,
test_group.private_seed,
) {
num_ran += tg_len
num_failed += tg_len
continue
}
pub_bytes := make([]byte, mldsa.PUBLIC_KEY_SIZES[params])
mldsa.private_key_public_bytes(&priv_key, pub_bytes)
ok := common.hexbytes_compare(test_group.public_key, pub_bytes)
if !ok {
x := transmute(string)(hex.encode(pub_bytes[:]))
log.errorf(
"%s/Sign/%d: public key: expected: %s actual: %s",
params_str,
tg_id,
test_group.public_key,
x,
)
num_ran += tg_len
num_failed += tg_len
continue
}
pub_key: mldsa.Public_Key
if !testing.expectf(
t,
mldsa.public_key_set_bytes(&pub_key, params, pub_bytes),
"%s/Sign/%d: failed to set public key",
params_str,
tg_id,
) {
num_ran += tg_len
num_failed += tg_len
continue
}
sig := make([]byte, mldsa.SIGNATURE_SIZES[params])
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/Sign/%d/%d: %s: %+v",
params_str,
tg_id,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/Sign/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags)
}
ctx := common.hexbytes_decode(test_vector.ctx)
msg := common.hexbytes_decode(test_vector.msg)
is_external_mu := slice.contains(test_vector.flags, FLAG_INTERNAL)
switch is_external_mu {
case false:
ok = mldsa.sign(
&priv_key,
ctx,
msg,
sig,
true,
)
case true:
ok = _mldsa.dsa_sign_internal(
sig,
msg,
ctx,
dummy_rnd[:],
&priv_key,
common.hexbytes_decode(test_vector.mu),
)
}
if !result_check(test_vector.result, ok) {
log.errorf(
"%s/Sign/%d/%d: unexpected sign result: %v",
params_str,
tg_id,
test_vector.tc_id,
ok,
)
num_failed += 1
continue
}
if result_is_invalid(test_vector.result) {
num_passed += 1
continue
}
ok = common.hexbytes_compare(test_vector.sig, sig)
if !ok {
x := transmute(string)(hex.encode(sig))
log.errorf(
"%s/Sign/%d/%d: sign: expected: %s actual: %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.sig,
x,
)
num_failed += 1
continue
}
// Might as well verify as well if we have the ctx/msg.
if !is_external_mu {
if !testing.expectf(
t,
mldsa.verify(&pub_key, ctx, msg, sig),
"%s/Sign/%d/%d: verify failed",
params_str,
tg_id,
test_vector.tc_id,
) {
num_failed += 1
continue
}
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s/Sign: ran %d, passed %d, failed %d, skipped %d",
params_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
test_mldsa_verify:: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Mldsa_Test_Group)) -> bool {
params_str := test_vectors.algorithm
params := mldsa_parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
log.debugf("%s: Verify starting", params_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, tg_id in test_vectors.test_groups {
tg_len := len(test_group.tests)
pub_key_bytes := common.hexbytes_decode(test_group.public_key)
pub_key: mldsa.Public_Key
expected := len(pub_key_bytes) == mldsa.PUBLIC_KEY_SIZES[params]
ok := mldsa.public_key_set_bytes(&pub_key, params, pub_key_bytes)
if !testing.expectf(
t,
ok == expected,
"%s/Verify/%d: failed to set public key",
params_str,
tg_id,
) {
num_ran += tg_len
num_failed += tg_len
continue
}
if expected == false {
num_ran += tg_len
num_passed += tg_len
continue
}
for &test_vector in test_group.tests {
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/Verify/%d/%d: %s: %+v",
params_str,
tg_id,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/Verify/%d/%d: %+v", params_str, tg_id, test_vector.tc_id, test_vector.flags)
}
num_ran += 1
ctx := common.hexbytes_decode(test_vector.ctx)
msg := common.hexbytes_decode(test_vector.msg)
sig := common.hexbytes_decode(test_vector.sig)
ok = mldsa.verify(&pub_key, ctx, msg, sig)
if !result_check(test_vector.result, ok) {
log.errorf(
"%s/Verify/%d/%d: unexpected verify result: %v",
params_str,
tg_id,
test_vector.tc_id,
ok,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s/Verify: ran %d, passed %d, failed %d, skipped %d",
params_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
@(require_results,private="file")
mldsa_parameter_set_to_params :: proc(s: string) -> mldsa.Parameters {
switch s {
case "ML-DSA-44":
return .ML_DSA_44
case "ML-DSA-65":
return .ML_DSA_65
case "ML-DSA-87":
return .ML_DSA_87
case:
return .Invalid
}
}

View File

@@ -29,10 +29,6 @@ result_is_invalid :: proc(r: Result) -> bool {
return r == "invalid"
}
// The type namings are not following Odin convention, to better match
// the schema, though the fields do.
load :: proc(tvs: ^$T/Test_Vectors, fn: string) -> bool {
raw_json, err := os.read_entire_file_from_path(fn, context.allocator)
if err != os.ERROR_NONE {
@@ -223,3 +219,24 @@ Kem_Test_Vector :: struct {
k: common.Hex_Bytes `json:"K"`,
result: Result `json:"result"`,
}
Mldsa_Test_Group :: struct {
type: string `json:"type"`,
private_seed: common.Hex_Bytes `json:"privateSeed"`,
private_key_pkcs8: common.Hex_Bytes `json:"privateKeyPkcs8"`,
public_key: common.Hex_Bytes `json:"publicKey"`,
public_key_der: common.Hex_Bytes `json:"publicKeyDer"`,
source: Test_Group_Source `json:"source"`,
tests: []Mldsa_Test_Vector `json:"tests"`,
}
Mldsa_Test_Vector :: struct {
tc_id: int `json:"tcId"`,
comment: string `json:"comment"`,
msg: common.Hex_Bytes `json:"msg"`,
ctx: common.Hex_Bytes `json:"ctx"`,
mu: common.Hex_Bytes `json:"mu"`,
sig: common.Hex_Bytes `json:"sig"`,
result: Result `json:"result"`,
flags: []string `json:"flags"`,
}