Merge pull request #6635 from Yawning/feature/mlkem

core/crypto/mlkem: Support ML-KEM (FIPS 203)
This commit is contained in:
Jeroen van Rijn
2026-05-06 11:59:08 +02:00
committed by GitHub
30 changed files with 3541 additions and 1346 deletions

View File

@@ -0,0 +1,52 @@
#+private
package _mlkem
import "core:encoding/endian"
unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check {
r := u32(b[0])
r |= u32(b[1]) << 8
r |= u32(b[2]) << 16
return r
}
cbd3 :: proc "contextless" (r: ^Poly, buf: ^[3*N/4]byte) #no_bounds_check {
for i in 0..<N/4 {
t := unchecked_get_u24le(buf[3*i:])
d := t & 0x00249249
d += (t>>1) & 0x00249249
d += (t>>2) & 0x00249249
for j in uint(0)..<4 {
a := i16((d >> (6*j+0)) & 0x7)
b := i16((d >> (6*j+3)) & 0x7)
r.coeffs[4*i+int(j)] = a - b
}
}
}
cbd2 :: proc "contextless" (r: ^Poly, buf: ^[2*N/4]byte) #no_bounds_check {
for i in 0..<N/8 {
t := endian.unchecked_get_u32le(buf[4*i:])
d := t & 0x55555555
d += (t>>1) & 0x55555555
for j in uint(0)..<8 {
a := i16((d >> (4*j+0)) & 0x3)
b := i16((d >> (4*j+2)) & 0x3)
r.coeffs[8*i+int(j)] = a - b
}
}
}
poly_cbd_eta1_512 :: proc "contextless" (r: ^Poly, buf: ^[ETA1_512*N/4]byte) {
cbd3(r, buf)
}
poly_cbd_eta1 :: proc "contextless" (r: ^Poly, buf: ^[ETA1*N/4]byte) {
cbd2(r, buf)
}
poly_cbd_eta2 :: proc "contextless" (r: ^Poly, buf: ^[ETA2*N/4]byte) {
cbd2(r, buf)
}

View File

@@ -0,0 +1,53 @@
package _mlkem
K_512 :: 2
K_768 :: 3
K_1024 :: 4
K_MAX :: K_1024
N :: 256
Q :: 3329
ETA1_512 :: 3
ETA1 :: 2
ETA2 :: 2
POLYBYTES :: 384
SYMBYTES :: 32
POLYCOMPRESSEDBYTES_512 :: 128
POLYCOMPRESSEDBYTES_768 :: 128
POLYCOMPRESSEDBYTES_1024 :: 160
POLYVECBYTES_512 :: K_512 * POLYBYTES
POLYVECBYTES_768 :: K_768 * POLYBYTES
POLYVECBYTES_1024 :: K_1024 * POLYBYTES
POLYVECCOMPRESSEDBYTES_512 :: K_512 * 320
POLYVECCOMPRESSEDBYTES_768 :: K_768 * 320
POLYVECCOMPRESSEDBYTES_1024 :: K_1024 * 352
INDCPA_MSGBYTES :: SYMBYTES
INDCPA_PUBLICKEYBYTES_512 :: POLYVECBYTES_512 + SYMBYTES
INDCPA_SECRETKEYBYTES_512 :: POLYVECBYTES_512
INDCPA_PUBLICKEYBYTES_768 :: POLYVECBYTES_768 + SYMBYTES
INDCPA_SECRETKEYBYTES_768 :: POLYVECBYTES_768
INDCPA_PUBLICKEYBYTES_1024 :: POLYVECBYTES_1024 + SYMBYTES
INDCPA_SECRETKEYBYTES_1024 :: POLYVECBYTES_1024
INDCPA_PUBLICKEYBYTES_MAX :: INDCPA_PUBLICKEYBYTES_1024
INDCPA_BYTES_512 :: POLYVECCOMPRESSEDBYTES_512 + POLYCOMPRESSEDBYTES_512
INDCPA_BYTES_768 :: POLYVECCOMPRESSEDBYTES_768 + POLYCOMPRESSEDBYTES_768
INDCPA_BYTES_1024 :: POLYVECCOMPRESSEDBYTES_1024 + POLYCOMPRESSEDBYTES_1024
ENCAPSKEYBYTES_512 :: INDCPA_PUBLICKEYBYTES_512
ENCAPSKEYBYTES_768 :: INDCPA_PUBLICKEYBYTES_768
ENCAPSKEYBYTES_1024 :: INDCPA_PUBLICKEYBYTES_1024
DECAPSKEYBYTES_512 :: INDCPA_SECRETKEYBYTES_512 + INDCPA_PUBLICKEYBYTES_512 + 2 * SYMBYTES
DECAPSKEYBYTES_768 :: INDCPA_SECRETKEYBYTES_768 + INDCPA_PUBLICKEYBYTES_768 + 2 * SYMBYTES
DECAPSKEYBYTES_1024 :: INDCPA_SECRETKEYBYTES_1024 + INDCPA_PUBLICKEYBYTES_1024 + 2 * SYMBYTES
CIPHERTEXTBYTES_512 :: INDCPA_BYTES_512
CIPHERTEXTBYTES_768 :: INDCPA_BYTES_768
CIPHERTEXTBYTES_1024 :: INDCPA_BYTES_1024
CIPHERTEXTBYTES_MAX :: INDCPA_BYTES_1024

View File

@@ -0,0 +1,349 @@
#+private
package _mlkem
import "core:crypto"
import "core:crypto/shake"
@(require_results)
pack_pk :: proc "contextless" (r: []byte, pk: ^Polyvec, seed: []byte, k: int) -> bool {
pk_len := polyvec_byte_size(k)
switch {
case pk_len == 0:
return false
case len(seed) != SYMBYTES || len(r) != pk_len + SYMBYTES:
return false
}
polyvec_tobytes(r[:pk_len], pk, k)
copy(r[pk_len:], seed)
return true
}
@(require_results)
unpack_pk :: proc "contextless" (pk: ^Polyvec, seed, packedpk: []byte) -> bool {
pk_len := len(packedpk) - SYMBYTES
k: int
switch {
case pk_len == POLYVECBYTES_512:
k = K_512
case pk_len == POLYVECBYTES_768:
k = K_768
case pk_len == POLYVECBYTES_1024:
k = K_1024
case len(packedpk) - pk_len != SYMBYTES:
return false
case len(seed) != SYMBYTES:
return false
}
if k == 0 {
return false
}
ok := polyvec_frombytes(pk, packedpk[:pk_len], k)
copy(seed, packedpk[pk_len:])
return ok
}
@(require_results)
pack_sk :: proc "contextless" (r: []byte, sk: ^Polyvec, k: int) -> bool {
r_len := len(r)
if r_len == 0 || r_len != polyvec_byte_size(k) {
return false
}
polyvec_tobytes(r, sk, k)
return true
}
@(require_results)
unpack_sk :: proc "contextless" (sk: ^Polyvec, packedsk: []byte) -> bool {
k: int
switch len(packedsk) {
case POLYVECBYTES_512:
k = K_512
case POLYVECBYTES_768:
k = K_768
case POLYVECBYTES_1024:
k = K_1024
case:
return false
}
if k == 0 {
return false
}
return polyvec_frombytes(sk, packedsk, k)
}
@(require_results)
pack_ciphertext :: proc "contextless" (r: []byte, b: ^Polyvec, v: ^Poly, k: int) -> bool {
b_len := polyvec_compressed_byte_size(k)
if len(r) != b_len + poly_compressed_bytes(k) {
return false
}
polyvec_compress(r[:b_len], b, k)
poly_compress(r[b_len:], v)
return true
}
@(require_results)
unpack_ciphertext :: proc "contextless" (b: ^Polyvec, v: ^Poly, c: []byte) -> int {
b_len: int
k: int
switch len(c) {
case INDCPA_BYTES_512:
b_len = POLYVECCOMPRESSEDBYTES_512
k = K_512
case INDCPA_BYTES_768:
b_len = POLYVECCOMPRESSEDBYTES_768
k = K_768
case INDCPA_BYTES_1024:
b_len = POLYVECCOMPRESSEDBYTES_1024
k = K_1024
case:
return 0
}
polyvec_decompress(b, c[:b_len], k)
poly_decompress(v, c[b_len:])
return k
}
@(require_results)
rej_uniform :: proc "contextless" (r: []i16, buf: []byte) -> int {
r_len, b_len := len(r), len(buf)
ctr, pos: int
for ctr < r_len && pos + 3 <= b_len {
val0 := (u16(buf[pos+0] >> 0) | (u16(buf[pos+1]) << 8)) & 0xFFF
val1 := (u16(buf[pos+1] >> 4) | (u16(buf[pos+2]) << 4)) & 0xFFF
pos += 3
if val0 < Q {
r[ctr] = i16(val0)
ctr += 1
}
if(ctr < r_len && val1 < Q) {
r[ctr] = i16(val1)
ctr += 1
}
}
return ctr
}
gen_matrix :: proc(a: []Polyvec, seed: []byte, transposed: bool, k: int) {
GEN_MATRIX_NBLOCKS :: ((12*N/8*(1 << 12)/Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES)
buf: [GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]byte = ---
ctx: shake.Context = ---
ctr: int
defer shake.reset(&ctx)
defer crypto.zero_explicit(&buf, size_of(buf))
for i in 0..<k {
for j in 0..<k {
switch transposed {
case true:
xof_absorb(&ctx, seed, byte(i), byte(j))
case false:
xof_absorb(&ctx, seed, byte(j), byte(i))
}
shake.read(&ctx, buf[:])
ctr = rej_uniform(a[i].vec[j].coeffs[:], buf[:])
b := buf[:XOF_BLOCKBYTES]
for ctr < N {
shake.read(&ctx, b)
ctr += rej_uniform(a[i].vec[j].coeffs[ctr:], b)
}
}
}
}
K_PKE_Decryption_Key :: struct {
pv: Polyvec,
k: int,
}
K_PKE_Encryption_Key :: struct {
pv: Polyvec,
p: [SYMBYTES]byte,
k: int,
}
k_pke_encryption_key_set :: proc(dst, src: ^K_PKE_Encryption_Key) {
k_pke_key_clear(dst)
for i in 0..<src.k {
copy(dst.pv.vec[i].coeffs[:], src.pv.vec[i].coeffs[:])
}
copy(dst.p[:], src.p[:])
dst.k = src.k
}
k_pke_key_clear :: proc(k: $T) where T == ^K_PKE_Encryption_Key || T == ^K_PKE_Decryption_Key {
crypto.zero_explicit(k, size_of(k^))
}
k_pke_keygen :: proc(
ek: ^K_PKE_Encryption_Key,
dk: ^K_PKE_Decryption_Key,
d: []byte,
k: int,
) {
assert(len(d) == SYMBYTES, "crypto/mlkem: invalid K-PKE d")
ensure(k == K_512 || k == K_768 || k == K_1024, "crypto/mlkem: invalid k")
buf: [2*SYMBYTES]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
a_: [K_MAX]Polyvec = ---
e: Polyvec = ---
defer crypto.zero_explicit(&a_, size_of(Polyvec) * k)
defer polyvec_clear(&e)
a := a_[:k]
copy(buf[:], d)
buf[SYMBYTES] = byte(k)
hash_g(buf[:], buf[:SYMBYTES+1])
p, sigma := buf[:SYMBYTES], buf[SYMBYTES:]
gen_matrix(a, p, false, k)
n := byte(0)
for i in 0..<k {
if k != K_512 {
poly_getnoise_eta1(&dk.pv.vec[i], sigma, n)
} else {
poly_getnoise_eta1_512(&dk.pv.vec[i], sigma, n)
}
n += 1
}
for i in 0..<k {
if k != K_512 {
poly_getnoise_eta1(&e.vec[i], sigma, n)
} else {
poly_getnoise_eta1_512(&e.vec[i], sigma, n)
}
n += 1
}
polyvec_ntt(&dk.pv, k)
polyvec_ntt(&e, k)
for i in 0..<k {
polyvec_basemul_acc_montgomery(&ek.pv.vec[i], &a[i], &dk.pv, k)
poly_tomont(&ek.pv.vec[i])
}
polyvec_add(&ek.pv, &ek.pv, &e, k)
polyvec_reduce(&ek.pv, k)
copy(ek.p[:], p)
dk.k = k
ek.k = k
}
@(require_results)
k_pke_encrypt :: proc(
ciphertext: []byte,
ek: ^K_PKE_Encryption_Key,
m: []byte,
r: []byte,
) -> bool {
ensure(len(m) == INDCPA_MSGBYTES, "crypto/mlkem: invalid K-PKE m")
ensure(len(r) == SYMBYTES, "crypto/mlkem: invalid K-PKE r")
k := ek.k
at_: [K_MAX]Polyvec = ---
sp, ep, b: Polyvec = ---, ---, ---
kay, epp, v: Poly = ---, ---, ---
defer crypto.zero_explicit(&at_, size_of(Polyvec) * k)
defer polyvec_clear(&sp, &ep, &b)
defer poly_clear(&kay, &epp, &v)
poly_frommsg(&kay, m)
at := at_[:k]
gen_matrix(at, ek.p[:], true, k)
n := byte(0)
for i in 0..<k {
if k != K_512 {
poly_getnoise_eta1(&sp.vec[i], r, n)
} else {
poly_getnoise_eta1_512(&sp.vec[i], r, n)
}
n += 1
}
for i in 0..<k {
poly_getnoise_eta2(&ep.vec[i], r, n)
n += 1
}
poly_getnoise_eta2(&epp, r, n)
polyvec_ntt(&sp, k)
for i in 0..<k {
polyvec_basemul_acc_montgomery(&b.vec[i], &at[i], &sp, k)
}
polyvec_basemul_acc_montgomery(&v, &ek.pv, &sp, k)
polyvec_invntt_tomont(&b, k)
poly_invntt_tomont(&v)
polyvec_add(&b, &b, &ep, k)
poly_add(&v, &v, &epp)
poly_add(&v, &v, &kay)
polyvec_reduce(&b, k)
poly_reduce(&v)
return pack_ciphertext(ciphertext, &b, &v, k)
}
@(require_results)
k_pke_decrypt :: proc(
plaintext: []byte,
dk: ^K_PKE_Decryption_Key,
c: []byte,
) -> bool {
if len(plaintext) != INDCPA_MSGBYTES {
return false
}
k := dk.k
b: Polyvec = ---
v, mp: Poly = ---, ---
defer poly_clear(&v, &mp)
if unpack_ciphertext(&b, &v, c) != k {
return false
}
polyvec_ntt(&b, k)
polyvec_basemul_acc_montgomery(&mp, &dk.pv, &b, k)
poly_invntt_tomont(&mp)
poly_sub(&mp, &v, &mp)
poly_reduce(&mp)
poly_tomsg(plaintext, &mp)
return true
}

View File

@@ -0,0 +1,195 @@
package _mlkem
import "core:crypto"
import subtle "core:crypto/_subtle"
// This implementation is derived from the PQ-CRYSTALS reference
// implementation [[ https://github.com/pq-crystals/kyber ]],
// primarily for licensing reasons. Arguably mlkem-native is
// a more "up to date" codebase, but the changes to the
// ref code is minor and they slapped an attribution-required
// license on something that was originally CC-0/Apache 2.0.
// "Private Key"
Decapsulation_Key :: struct {
pke_dk: K_PKE_Decryption_Key,
ek: Encapsulation_Key,
seed: [SYMBYTES*2]byte, // (d, z)
}
// "Public Key"
Encapsulation_Key :: struct {
pke_ek: K_PKE_Encryption_Key,
raw_bytes: [INDCPA_PUBLICKEYBYTES_MAX]byte,
h: [SYMBYTES]byte,
}
decapsulation_key_expanded_bytes :: proc(
dk: ^Decapsulation_Key,
dst: []byte,
) {
sk := &dk.pke_dk
pv_len := polyvec_byte_size(sk.k)
ek_len := pv_len + SYMBYTES
ek_bytes := dk.ek.raw_bytes[:ek_len]
dst := dst
_ = pack_sk(dst[:pv_len], &sk.pv, sk.k)
dst = dst[pv_len:]
copy(dst, ek_bytes)
dst = dst[ek_len:]
hash_h(dst[:SYMBYTES], ek_bytes)
dst = dst[SYMBYTES:]
copy(dst, dk.seed[SYMBYTES:])
}
@(require_results)
encapsulation_key_set_bytes :: proc(
ek: ^Encapsulation_Key,
k: int,
b: []byte,
) -> bool {
k_len: int
switch k {
case K_512:
k_len = ENCAPSKEYBYTES_512
case K_768:
k_len = ENCAPSKEYBYTES_768
case K_1024:
k_len = ENCAPSKEYBYTES_1024
case:
return false
}
if len(b) != k_len {
return false
}
pke_ek := &ek.pke_ek
ok := unpack_pk(&pke_ek.pv, pke_ek.p[:], b)
pke_ek.k = k
copy(ek.raw_bytes[:k_len], b)
hash_h(ek.h[:], b)
// FIPS 203 unlike Kyber requires canonical encoding of
// encapsulation keys (Section 7,2), which is checked in
// unpack_pk.
if !ok {
crypto.zero_explicit(ek, size_of(Encapsulation_Key))
}
return ok
}
encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) {
dk_ek := &dk.ek.pke_ek
ensure(dk_ek.k == K_512 || dk_ek.k == K_768 || dk_ek.k == K_1024, "crypto/mlkem: invalid decaps k")
k_pke_encryption_key_set(&ek.pke_ek, dk_ek)
copy(ek.raw_bytes[:], dk.ek.raw_bytes[:])
copy(ek.h[:], dk.ek.h[:])
}
// NIST's version of this also returns an encapsulation key, but our
// internal representation includes it as part of the decapsulation key
// in a more traditional "keypair" approach.
kem_keygen_internal :: proc(
dk: ^Decapsulation_Key,
seed: []byte, // (d, z)
k: int,
) {
ensure(len(seed) == 2 * SYMBYTES, "crypto/mlkem: invalid seed")
dk_ek := &dk.ek
d, z := seed[:SYMBYTES], seed[SYMBYTES:]
k_pke_keygen(&dk_ek.pke_ek, &dk.pke_dk, d, k)
ek_len := polyvec_byte_size(k) + SYMBYTES
ek_bytes := dk_ek.raw_bytes[:ek_len]
ensure(
pack_pk(ek_bytes, &dk_ek.pke_ek.pv, dk_ek.pke_ek.p[:], k),
"crypto/mlkem: failed to pack K-PKE ek",
)
hash_h(dk_ek.h[:], ek_bytes)
copy(dk.seed[:SYMBYTES], d)
copy(dk.seed[SYMBYTES:], z)
}
// The `_internal` "de-randomized" versions of ML-KEM.Encaps and
// ML-KEM.Decaps are only ever to be called by the actual non-interal
// implementation or test cases.
kem_encaps_internal :: proc(
shared_secret: []byte,
ciphertext: []byte,
ek: ^Encapsulation_Key,
randomness: []byte,
) {
ensure(len(shared_secret) == SYMBYTES, "crypto/mlkem: invalid K")
ensure(len(randomness) == SYMBYTES, "crypto/mlkem: invalid m")
ensure(
len(ciphertext) == ct_len_for_k(ek.pke_ek.k),
"crypto/mlkem: invalid ciphertext length",
)
buf: [2*SYMBYTES]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
hash_g(buf[:], randomness, ek.h[:])
// Can't fail, ciphertext length is valid.
_ = k_pke_encrypt(ciphertext, &ek.pke_ek, randomness, buf[SYMBYTES:])
copy(shared_secret, buf[:SYMBYTES])
}
kem_decaps_internal :: proc(
shared_secret: []byte,
dk: ^Decapsulation_Key,
ciphertext: []byte,
) {
ct_len := ct_len_for_k(dk.pke_dk.k)
ensure(
len(ciphertext) == ct_len,
"crypto/mlkem: invalid ciphertext length",
)
m_: [SYMBYTES]byte
defer crypto.zero_explicit(&m_, size_of(m_))
// Can't fail, ciphertext length is valid.
_ = k_pke_decrypt(m_[:], &dk.pke_dk, ciphertext)
buf: [2*SYMBYTES]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
ek := &dk.ek
hash_g(buf[:], m_[:], ek.h[:])
rkprf(shared_secret, dk.seed[SYMBYTES:], ciphertext)
ct_buf: [CIPHERTEXTBYTES_MAX]byte = ---
defer crypto.zero_explicit(&ct_buf, size_of(ct_buf))
ct_ := ct_buf[:ct_len]
_ = k_pke_encrypt(ct_, &ek.pke_ek, m_[:], buf[SYMBYTES:])
ok := crypto.compare_constant_time(ciphertext, ct_)
subtle.cmov_bytes(shared_secret, buf[:SYMBYTES], ok)
}
@(private="file")
ct_len_for_k :: proc(k: int) -> int {
switch k {
case K_512:
return CIPHERTEXTBYTES_512
case K_768:
return CIPHERTEXTBYTES_768
case K_1024:
return CIPHERTEXTBYTES_1024
case:
panic("crypto/mlkem: invalid k for ciphertext length")
}
}

View File

@@ -0,0 +1,75 @@
#+private
package _mlkem
@(rodata)
ZETAS := [128]i16 {
-1044, -758, -359, -1517, 1493, 1422, 287, 202,
-171, 622, 1577, 182, 962, -1202, -1474, 1468,
573, -1325, 264, 383, -829, 1458, -1602, -130,
-681, 1017, 732, 608, -1542, 411, -205, -1571,
1223, 652, -552, 1015, -1293, 1491, -282, -1544,
516, -8, -320, -666, -1618, -1162, 126, 1469,
-853, -90, -271, 830, 107, -1421, -247, -951,
-398, 961, -1508, -725, 448, -1065, 677, -1275,
-1103, 430, 555, 843, -1251, 871, 1550, 105,
422, 587, 177, -235, -291, -460, 1574, 1653,
-246, 778, 1159, -147, -777, 1483, -602, 1119,
-1590, 644, -872, 349, 418, 329, -156, -75,
817, 1097, 603, 610, 1322, -1285, -1465, 384,
-1215, -136, 1218, -1335, -874, 220, -1187, -1659,
-1185, -1530, -1278, 794, -1510, -854, -870, 478,
-108, -308, 996, 991, 958, -1460, 1522, 1628,
}
@(require_results)
fqmul :: #force_inline proc "contextless" (a, b: i16) -> i16 {
return montgomery_reduce(i32(a) * i32(b))
}
ntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check {
j, k := 0, 1
for l := 128; l >= 2; l >>= 1 {
for start := 0; start < N; start = j + l {
zeta := ZETAS[k]
k += 1
for j = start; j < start + l; j += 1 {
t := fqmul(zeta, r[j+l])
r[j+l] = r[j] - t
r[j] = r[j] + t
}
}
}
}
invntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check {
F : i16 : 1441 // mont^2/128
j, k := 0, 127
for l := 2; l <= 128; l <<= 1 {
for start := 0; start < 256; start = j+l {
zeta := ZETAS[k]
k -= 1
for j = start; j < start + l; j += 1 {
t := r[j]
r[j] = barrett_reduce(t + r[j+l])
r[j+l] = r[j+l] - t
r[j+l] = fqmul(zeta, r[j+l])
}
}
}
for v, i in r {
r[i] = fqmul(v, F)
}
}
@(require_results)
base_case_multiply :: proc "contextless" (a_0, a_1, b_0, b_1, zeta: i16) -> (i16, i16) {
r_0 := fqmul(a_1, b_1)
r_0 = fqmul(r_0, zeta)
r_0 += fqmul(a_0, b_0)
r_1 := fqmul(a_0, b_1)
r_1 += fqmul(a_1, b_0)
return r_0, r_1
}

View File

@@ -0,0 +1,241 @@
#+private
package _mlkem
import "core:crypto"
import subtle "core:crypto/_subtle"
// Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial
// coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1]
Poly :: struct {
coeffs: [N]i16,
}
poly_compress :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
t: [8]byte = ---
defer crypto.zero_explicit(&t, size_of(t))
r := r
switch len(r) {
case POLYCOMPRESSEDBYTES_768: // Also covers _512
for i in 0..<N/8 {
for j in 0..<8 {
// map to positive standard representatives
u := a.coeffs[8*i+j]
u += (u >> 15) & Q
// t[j] = ((((uint16_t)u << 4) + Q/2)/Q) & 15
d0 := u32(u) << 4
d0 += 1665
d0 *= 80635
d0 >>= 28
t[j] = byte(d0) & 0xf
}
r[0] = t[0] | (t[1] << 4)
r[1] = t[2] | (t[3] << 4)
r[2] = t[4] | (t[5] << 4)
r[3] = t[6] | (t[7] << 4)
r = r[4:]
}
case POLYCOMPRESSEDBYTES_1024:
for i in 0..<N/8 {
for j in 0..<8 {
// map to positive standard representatives
u := a.coeffs[8*i+j]
u += (u >> 15) & Q
// t[j] = ((((uint16_t)u << 5) + Q/2)/Q) & 31
d0 := u32(u) << 5
d0 += 1664
d0 *= 40318
d0 >>= 27
t[j] = byte(d0) & 0x1f
}
r[0] = (t[0] >> 0) | (t[1] << 5)
r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7)
r[2] = (t[3] >> 1) | (t[4] << 4)
r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6)
r[4] = (t[6] >> 2) | (t[7] << 3)
r = r[5:]
}
case:
panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES")
}
}
poly_decompress :: proc "contextless" (r: ^Poly, a: []byte) {
a := a
switch len(a) {
case POLYCOMPRESSEDBYTES_768: // Also covers _512
for i in 0..<N/2 {
r.coeffs[2*i+0] = i16(((u16(a[0] & 15) * Q) + 8) >> 4)
r.coeffs[2*i+1] = i16(((u16(a[0] >> 4) * Q) + 8) >> 4)
a = a[1:]
}
case POLYCOMPRESSEDBYTES_1024:
t: [8]byte = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<N/8 {
t[0] = (a[0] >> 0)
t[1] = (a[0] >> 5) | (a[1] << 3)
t[2] = (a[1] >> 2)
t[3] = (a[1] >> 7) | (a[2] << 1)
t[4] = (a[2] >> 4) | (a[3] << 4)
t[5] = (a[3] >> 1)
t[6] = (a[3] >> 6) | (a[4] << 2)
t[7] = (a[4] >> 3)
a = a[5:]
for j in 0..<8 {
r.coeffs[8*i+j] = i16((u32(t[j] & 31) * Q + 16) >> 5)
}
}
case:
panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES")
}
}
poly_tobytes :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
ensure_contextless(len(r) >= POLYBYTES)
for i in 0..<N/2 {
// map to positive standard representatives
t0 := u16(a.coeffs[2*i])
t0 += u16((i16(t0) >> 15) & Q)
t1 := u16(a.coeffs[2*i+1])
t1 += u16((i16(t1) >> 15) & Q)
r[3*i+0] = byte(t0 >> 0)
r[3*i+1] = byte(t0 >> 8) | byte(t1 << 4)
r[3*i+2] = byte(t1 >> 4)
}
}
@(require_results)
poly_frombytes :: proc "contextless" (r: ^Poly, a: []byte) -> bool #no_bounds_check {
ensure_contextless(len(a) >= POLYBYTES)
ok := true
for i in 0..<N/2 {
r.coeffs[2*i] = i16(((u16(a[3*i+0]) >> 0) | (u16(a[3*i+1]) << 8)) & 0xFFF)
r.coeffs[2*i+1] = i16(((u16(a[3*i+1]) >> 4) | (u16(a[3*i+2]) << 4)) & 0xFFF)
ok &= r.coeffs[2*i] < Q && r.coeffs[2*i+1] < Q
}
return ok
}
poly_frommsg :: proc "contextless" (r: ^Poly, msg: []byte) #no_bounds_check {
#assert(INDCPA_MSGBYTES == N/8)
ensure_contextless(len(msg) == INDCPA_MSGBYTES)
for i in 0..<N/8 {
for j in 0..<8 {
r.coeffs[8*i+j] = subtle.csel_i16(0, (Q+1)/2, int(msg[i] >> uint(j))&1)
}
}
}
poly_tomsg :: proc "contextless" (msg: []byte, a: ^Poly) #no_bounds_check {
ensure_contextless(len(msg) == INDCPA_MSGBYTES)
for i in 0..<N/8 {
msg[i] = 0
for j in uint(0)..<8 {
t := u32(a.coeffs[8*i+int(j)])
// t += ((int16_t)t >> 15) & Q
// t = (((t << 1) + Q/2)/Q) & 1
t <<= 1
t += 1665
t *= 80635
t >>= 28
t &= 1
msg[i] |= byte(t << j)
}
}
}
poly_getnoise_eta1_512 :: proc(r: ^Poly, seed: []byte, iv: byte) {
buf: [ETA1_512*N/4]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
prf(buf[:], seed, iv)
poly_cbd_eta1_512(r, &buf)
}
poly_getnoise_eta1 :: proc(r: ^Poly, seed: []byte, iv: byte) {
buf: [ETA1*N/4]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
prf(buf[:], seed, iv)
poly_cbd_eta1(r, &buf)
}
poly_getnoise_eta2 :: proc(r: ^Poly, seed: []byte, iv: byte) {
buf: [ETA2*N/4]byte = ---
defer crypto.zero_explicit(&buf, size_of(buf))
prf(buf[:], seed, iv)
poly_cbd_eta2(r, &buf)
}
poly_ntt :: proc "contextless" (r: ^Poly) {
ntt(&r.coeffs)
poly_reduce(r)
}
poly_invntt_tomont :: proc "contextless" (r: ^Poly) {
invntt(&r.coeffs)
}
poly_basemul_montgomery :: proc "contextless" (r, a, b: ^Poly) #no_bounds_check {
for i in 0..<N/4 {
j := 4 * i
r.coeffs[j], r.coeffs[j+1] = base_case_multiply(a.coeffs[j], a.coeffs[j+1], b.coeffs[j], b.coeffs[j+1], ZETAS[64+i])
r.coeffs[j+2], r.coeffs[j+3] = base_case_multiply(a.coeffs[j+2], a.coeffs[j+3], b.coeffs[j+2], b.coeffs[j+3], -ZETAS[64+i])
}
}
poly_tomont :: proc "contextless" (r: ^Poly) {
F : i16 : (1 << 32) % Q
for v, i in r.coeffs {
r.coeffs[i] = montgomery_reduce(i32(v)*i32(F))
}
}
poly_reduce :: proc "contextless" (r: ^Poly) {
for v, i in r.coeffs {
r.coeffs[i] = barrett_reduce(v)
}
}
poly_add :: proc "contextless" (r, a, b: ^Poly) {
for i in 0..<N {
r.coeffs[i] = a.coeffs[i] + b.coeffs[i]
}
}
poly_sub :: proc "contextless" (r, a, b: ^Poly) {
for i in 0..<N {
r.coeffs[i] = a.coeffs[i] - b.coeffs[i]
}
}
poly_clear :: proc "contextless" (a: ..^Poly) {
for j in 0..<len(a) {
p := a[j]
crypto.zero_explicit(p, size_of(Poly))
}
}
poly_compressed_bytes :: #force_inline proc "contextless" (k: int) -> int {
switch k {
case K_512:
return POLYCOMPRESSEDBYTES_512
case K_768:
return POLYCOMPRESSEDBYTES_768
case K_1024:
return POLYCOMPRESSEDBYTES_1024
case:
panic_contextless("crypto/mlkem: invalid k")
}
}

View File

@@ -0,0 +1,224 @@
#+private
package _mlkem
import "core:crypto"
Polyvec :: struct {
vec: [K_MAX]Poly,
}
polyvec_compress :: proc "contextless" (r: []byte, a: ^Polyvec, kay: int) #no_bounds_check {
d0: u64
r := r
switch len(r) {
case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768:
ensure_contextless(kay == K_512 || kay == K_768)
t: [4]u16 = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<kay {
for j in 0..<N/4 {
for k in 0..<4 {
t[k] = u16(a.vec[i].coeffs[4*j+k])
t[k] += u16((i16(t[k]) >> 15) & Q)
// t[k] = ((((uint32_t)t[k] << 10) + Q/2)/Q) & 0x3ff
d0 = u64(t[k])
d0 <<= 10
d0 += 1665
d0 *= 1290167
d0 >>= 32
t[k] = u16(d0 & 0x3ff)
}
r[0] = byte(t[0] >> 0)
r[1] = byte((t[0] >> 8) | (t[1] << 2))
r[2] = byte((t[1] >> 6) | (t[2] << 4))
r[3] = byte((t[2] >> 4) | (t[3] << 6))
r[4] = byte(t[3] >> 2)
r = r[5:]
}
}
case POLYVECCOMPRESSEDBYTES_1024:
ensure_contextless(kay == K_1024)
t: [8]u16 = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<K_1024 {
for j in 0..<N/8 {
for k in 0..<8 {
t[k] = u16(a.vec[i].coeffs[8*j+k])
t[k] += u16((i16(t[k]) >> 15) & Q)
// t[k] = ((((uint32_t)t[k] << 11) + Q/2)/Q) & 0x7ff
d0 = u64(t[k])
d0 <<= 11
d0 += 1664
d0 *= 645084
d0 >>= 31
t[k] = u16(d0 & 0x7ff)
}
r[0] = byte(t[0] >> 0)
r[1] = byte((t[0] >> 8) | (t[1] << 3))
r[2] = byte((t[1] >> 5) | (t[2] << 6))
r[3] = byte(t[2] >> 2)
r[4] = byte((t[2] >> 10) | (t[3] << 1))
r[5] = byte((t[3] >> 7) | (t[4] << 4))
r[6] = byte((t[4] >> 4) | (t[5] << 7))
r[7] = byte(t[5] >> 1)
r[8] = byte((t[5] >> 9) | (t[6] << 2))
r[9] = byte((t[6] >> 6) | (t[7] << 5))
r[10] = byte(t[7] >> 3)
r = r[11:]
}
}
case:
panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES")
}
}
polyvec_decompress :: proc "contextless" (r: ^Polyvec, a: []byte, kay: int) #no_bounds_check {
a := a
switch len(a) {
case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768:
ensure_contextless(kay == K_512 || kay == K_768)
t: [4]u16 = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<kay {
for j in 0..<N/4 {
t[0] = u16(a[0] >> 0) | (u16(a[1]) << 8)
t[1] = u16(a[1] >> 2) | (u16(a[2]) << 6)
t[2] = u16(a[2] >> 4) | (u16(a[3]) << 4)
t[3] = u16(a[3] >> 6) | (u16(a[4]) << 2)
a = a[5:]
for k in 0..<4 {
r.vec[i].coeffs[4*j+k] = i16((u32(t[k] & 0x3FF) * Q + 512) >> 10)
}
}
}
case POLYVECCOMPRESSEDBYTES_1024:
t: [8]u16 = ---
defer crypto.zero_explicit(&t, size_of(t))
for i in 0..<K_1024 {
for j in 0..<N/8 {
t[0] = u16(a[0] >> 0) | (u16(a[1]) << 8)
t[1] = u16(a[1] >> 3) | (u16(a[2]) << 5)
t[2] = u16(a[2] >> 6) | (u16(a[3]) << 2) | (u16(a[4]) << 10)
t[3] = u16(a[4] >> 1) | (u16(a[5]) << 7)
t[4] = u16(a[5] >> 4) | (u16(a[6]) << 4)
t[5] = u16(a[6] >> 7) | (u16(a[7]) << 1) | (u16(a[8]) << 9)
t[6] = u16(a[8] >> 2) | (u16(a[9]) << 6)
t[7] = u16(a[9] >> 5) | (u16(a[10]) << 3)
a = a[11:]
for k in 0..<8 {
r.vec[i].coeffs[8*j+k] = i16((u32(t[k] & 0x7FF) * Q + 1024) >> 11)
}
}
}
case:
panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES")
}
}
polyvec_tobytes :: proc "contextless" (r: []byte, a: ^Polyvec, k: int) #no_bounds_check {
ensure_contextless(len(r) == k * POLYBYTES, "crypto/mlkem: invalid buffer")
r := r
for i in 0..<k {
poly_tobytes(r, &a.vec[i])
r = r[POLYBYTES:]
}
}
@(require_results)
polyvec_frombytes :: proc "contextless" (r: ^Polyvec, a: []byte, k: int) -> bool #no_bounds_check {
switch k {
case K_512, K_768, K_1024:
case:
panic_contextless("crypto/mlkem: invalid POLYVECBYTES")
}
ensure_contextless(len(a) == k * POLYBYTES, "crypto/mlkem: invalid buffer")
a := a
ok := true
for i in 0..<k {
ok &= poly_frombytes(&r.vec[i], a)
a = a[POLYBYTES:]
}
return ok
}
@(require_results)
polyvec_byte_size :: #force_inline proc "contextless" (k: int) -> int {
switch k {
case K_512, K_768, K_1024:
return k * POLYBYTES
case:
return 0
}
}
@(require_results)
polyvec_compressed_byte_size :: #force_inline proc "contextless" (k: int) -> int {
switch k {
case K_512:
return POLYVECCOMPRESSEDBYTES_512
case K_768:
return POLYVECCOMPRESSEDBYTES_768
case K_1024:
return POLYVECCOMPRESSEDBYTES_1024
case:
return 0
}
}
polyvec_ntt :: proc "contextless" (r: ^Polyvec, k: int) {
for i in 0..<k {
poly_ntt(&r.vec[i])
}
}
polyvec_invntt_tomont :: proc "contextless" (r: ^Polyvec, k: int) {
for i in 0..<k {
poly_invntt_tomont(&r.vec[i])
}
}
polyvec_basemul_acc_montgomery :: proc "contextless" (r: ^Poly, a, b: ^Polyvec, k: int) {
t: Poly = ---
defer crypto.zero_explicit(&t, size_of(t))
poly_basemul_montgomery(r, &a.vec[0], &b.vec[0])
for i in 1..<k {
poly_basemul_montgomery(&t, &a.vec[i], &b.vec[i])
poly_add(r, r, &t)
}
poly_reduce(r)
}
polyvec_reduce :: proc "contextless" (r: ^Polyvec, k: int) {
for i in 0..<k {
poly_reduce(&r.vec[i])
}
}
polyvec_add :: proc "contextless" (r, a, b: ^Polyvec, k: int) {
for i in 0..<k {
poly_add(&r.vec[i], &a.vec[i], &b.vec[i])
}
}
polyvec_clear :: proc "contextless" (rs: ..^Polyvec) {
for j in 0..<len(rs) {
r := rs[j]
crypto.zero_explicit(r, size_of(Polyvec))
}
}

View File

@@ -0,0 +1,19 @@
#+private
package _mlkem
@(require_results)
montgomery_reduce :: #force_inline proc "contextless" (a: i32) -> i16 {
QINV :: -3327 // q^-1 mod 2^16
t := i16(a) * QINV
return i16((a - i32(t) * Q) >> 16)
}
@(require_results)
barrett_reduce :: #force_inline proc "contextless" (a: i16) -> i16 {
V : i16 : ((1<<26) + Q / 2) / Q
t := i16((i32(V)*i32(a) + (1<<25)) >> 26)
t *= Q
return a - t
}

View File

@@ -0,0 +1,61 @@
#+private
package _mlkem
import "core:crypto"
import "core:crypto/_sha3"
import "core:crypto/sha3"
import "core:crypto/shake"
XOF_BLOCKBYTES :: _sha3.RATE_128
#assert(XOF_BLOCKBYTES % 3 == 0)
prf :: proc(out, key: []byte, iv: byte) {
ctx: shake.Context = ---
defer shake.reset(&ctx)
shake.init_256(&ctx)
shake.write(&ctx, key)
shake.write(&ctx, []byte{iv})
shake.read(&ctx, out)
}
rkprf :: proc(out, key, input: []byte) {
ctx: shake.Context = ---
defer shake.reset(&ctx)
shake.init_256(&ctx)
shake.write(&ctx, key)
shake.write(&ctx, input)
shake.read(&ctx, out)
}
xof_absorb :: proc(ctx: ^shake.Context, seed: []byte, x, y: byte) {
shake.init_128(ctx)
extseed: [SYMBYTES+2]byte = ---
defer crypto.zero_explicit(&extseed, size_of(extseed))
copy(extseed[:], seed)
extseed[SYMBYTES+0] = x
extseed[SYMBYTES+1] = y
shake.write(ctx, extseed[:])
}
hash_h :: proc(dst, src: []byte) {
ctx: sha3.Context = ---
sha3.init_256(&ctx)
sha3.update(&ctx, src)
sha3.final(&ctx, dst)
}
hash_g :: proc(dst: []byte, srcs: ..[]byte) {
ctx: sha3.Context = ---
sha3.init_512(&ctx)
for src in srcs {
sha3.update(&ctx, src)
}
sha3.final(&ctx, dst)
}

View File

@@ -3,6 +3,7 @@ Various useful bit operations in constant time.
*/
package _subtle
import "core:crypto/_fiat"
import "core:math/bits"
// byte_eq returns 1 if and only if (⟺) a == b, 0 otherwise.
@@ -40,3 +41,41 @@ u64_is_non_zero :: proc "contextless" (a: u64) -> u64 {
is_zero := u64_is_zero(a)
return (~is_zero) & 1
}
@(optimization_mode="none")
cmov_bytes :: proc "contextless" (dst, src: []byte, ctrl: int) {
s_len := len(src)
ensure_contextless(s_len == len(dst), "crypto: cmov length mismatch")
c := -(byte)(ctrl)
for i in 0..<s_len {
dst[i] ~= c & (dst[i] ~ src[i])
}
}
@(optimization_mode="none")
csel_i16 :: proc "contextless" (a, b: i16, ctrl: int) -> i16 {
c := -(u16)(ctrl)
return a ~ i16(c & u16(a ~ b))
}
@(optimization_mode="none")
csel_u16 :: proc "contextless" (a, b: u16, ctrl: int) -> u16 {
c := -(u16)(ctrl)
return a ~ (c & (a ~ b))
}
csel_u32 :: proc "contextless" (a, b: u32, ctrl: int) -> u32 {
return _fiat.cmovznz_u32(_fiat.u1(ctrl), a, b)
}
csel_u64 :: proc "contextless" (a, b: u64, ctrl: int) -> u64 {
return _fiat.cmovznz_u64(_fiat.u1(ctrl), a, b)
}
csel :: proc {
csel_i16,
csel_u16,
csel_u32,
csel_u64,
}

View File

@@ -85,18 +85,6 @@ zero_explicit :: proc "contextless" (data: rawptr, len: int) -> rawptr {
return data
}
/*
Set each byte of a memory range to a specific value.
This procedure copies value specified by the `value` parameter into each of the
`len` bytes of a memory range, located at address `data`.
This procedure returns the pointer to `data`.
*/
set :: proc "contextless" (data: rawptr, value: byte, len: int) -> rawptr {
return runtime.memset(data, i32(value), len)
}
// rand_bytes fills the dst buffer with cryptographic entropy taken from
// the system entropy source. This routine will block if the system entropy
// source is not ready yet. All system entropy source failures are treated

View File

@@ -106,6 +106,7 @@ Public_Key :: struct {
// private_key_generate uses the system entropy source to generate a new
// Private_Key. This will only fail if and only if (⟺) the system entropy source is
// missing or broken.
@(require_results)
private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
private_key_clear(priv_key)
@@ -143,6 +144,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
// private_key_set_bytes decodes a byte-encoded private key, and returns
// true if and only if (⟺) the operation was successful.
@(require_results)
private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool {
private_key_clear(priv_key)
@@ -281,6 +283,7 @@ private_key_bytes :: proc(priv_key: ^Private_Key, dst: []byte) {
// private_key_equal returns true if and only if (⟺) the private keys are equal,
// in constant time.
@(require_results)
private_key_equal :: proc(p, q: ^Private_Key) -> bool {
if p._curve != q._curve {
return false
@@ -311,6 +314,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) {
// public_key_set_bytes decodes a byte-encoded public key, and returns
// true if and only if (⟺) the operation was successful.
@(require_results)
public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool {
public_key_clear(pub_key)
@@ -411,6 +415,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) {
// public_key_equal returns true if and only if (⟺) the public keys are equal,
// in constant time.
@(require_results)
public_key_equal :: proc(p, q: ^Public_Key) -> bool {
if p._curve != q._curve {
return false
@@ -479,11 +484,13 @@ ecdh :: proc(priv_key: ^Private_Key, pub_key: ^Public_Key, dst: []byte) -> bool
}
// curve returns the Curve used by a Private_Key or Public_Key instance.
@(require_results)
curve :: proc(k: ^$T) -> Curve where(T == Private_Key || T == Public_Key) {
return k._curve
}
// key_size returns the key size of a Private_Key or Public_Key in bytes.
@(require_results)
key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
when T == Private_Key {
return PRIVATE_KEY_SIZES[k._curve]
@@ -494,6 +501,7 @@ key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
// shared_secret_size returns the shared secret size of a key exchange
// in bytes.
@(require_results)
shared_secret_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
return SHARED_SECRET_SIZES[k._curve]
}

View File

@@ -81,6 +81,7 @@ Public_Key :: struct {
// private_key_generate uses the system entropy source to generate a new
// Private_Key. This will only fail if and only if (⟺) the system entropy source is
// missing or broken.
@(require_results)
private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
private_key_clear(priv_key)
@@ -112,6 +113,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
// private_key_set_bytes decodes a byte-encoded private key, and returns
// true if and only if (⟺) the operation was successful.
@(require_results)
private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool {
private_key_clear(priv_key)
@@ -222,6 +224,7 @@ private_key_set :: proc(priv_key, src: ^Private_Key) {
// private_key_equal returns true if and only if (⟺) the private keys are equal,
// in constant time.
@(require_results)
private_key_equal :: proc(p, q: ^Private_Key) -> bool {
if p._curve != q._curve {
return false
@@ -246,6 +249,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) {
// public_key_set_bytes decodes a byte-encoded public key, and returns
// true if and only if (⟺) the operation was successful.
@(require_results)
public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool {
public_key_clear(pub_key)
@@ -334,6 +338,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) {
// public_key_equal returns true if and only if (⟺) the public keys are equal,
// in constant time.
@(require_results)
public_key_equal :: proc(p, q: ^Public_Key) -> bool {
if p._curve != q._curve {
return false

View File

@@ -18,6 +18,7 @@ package md5
zhibog, dotbmp: Initial implementation.
*/
import "base:intrinsics"
import "core:crypto"
import "core:encoding/endian"
import "core:math/bits"
@@ -100,7 +101,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) {
i += 1
}
transform(ctx, ctx.data[:])
crypto.set(&ctx.data, 0, 56)
intrinsics.mem_zero(&ctx.data, 56)
}
ctx.bitlen += u64(ctx.datalen * 8)

View File

@@ -19,6 +19,7 @@ package sha1
zhibog, dotbmp: Initial implementation.
*/
import "base:intrinsics"
import "core:crypto"
import "core:encoding/endian"
import "core:math/bits"
@@ -107,7 +108,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) {
i += 1
}
transform(ctx, ctx.data[:])
crypto.set(&ctx.data, 0, 56)
intrinsics.mem_zero(&ctx.data, 56)
}
ctx.bitlen += u64(ctx.datalen * 8)

268
core/crypto/mlkem/api.odin Normal file
View File

@@ -0,0 +1,268 @@
package mlkem
import "core:crypto"
import "core:crypto/_mlkem"
// Parameters are the supported ML-KEM parameter sets.
Parameters :: enum {
Invalid,
ML_KEM_512,
ML_KEM_768,
ML_KEM_1024,
}
// DECAPSULATION_KEY_SEED_SIZE is the size of a Decapsulation key in bytes.
DECAPSULATION_KEY_SEED_SIZE :: 64 // (d, z) in NIST terms.
// DECAPSULATION_KEY_EXPANDED_SIZES are the per-parameter sizes of the
// decapsulation key in bytes.
DECAPSULATION_KEY_EXPANDED_SIZES := [Parameters]int {
.Invalid = 0,
.ML_KEM_512 = _mlkem.DECAPSKEYBYTES_512, // 1632-bytes
.ML_KEM_768 = _mlkem.DECAPSKEYBYTES_768, // 2400-bytes
.ML_KEM_1024 = _mlkem.DECAPSKEYBYTES_1024, // 3168-bytes
}
// ENCAPSULATION_KEY_SIZES are the per-parameter sizes of the encapsulation
// key in bytes.
ENCAPSULATION_KEY_SIZES := [Parameters]int {
.Invalid = 0,
.ML_KEM_512 = _mlkem.ENCAPSKEYBYTES_512, // 800-bytes
.ML_KEM_768 = _mlkem.ENCAPSKEYBYTES_768, // 1184-bytes
.ML_KEM_1024 = _mlkem.ENCAPSKEYBYTES_1024, // 1568-bytes
}
// CIPHERTEXT_SIZES are the per-parameter set sizes of the ciphertext
// in bytes.
CIPHERTEXT_SIZES := [Parameters]int {
.Invalid = 0,
.ML_KEM_512 = _mlkem.CIPHERTEXTBYTES_512, // 768-bytes
.ML_KEM_768 = _mlkem.CIPHERTEXTBYTES_768, // 1088-bytes
.ML_KEM_1024 = _mlkem.CIPHERTEXTBYTES_1024, // 1568-bytes
}
// SHARED_SECRET_SIZE is the size of the final shared secret in bytes.
SHARED_SECRET_SIZE :: 32
// Decapsulation_Key is a ML-KEM decapsulation (aka "private") key.
// This implementation opts to include the encapsulation (aka "public")
// key as well for cases where the decapsulation key is reused (eg: HPKE
// with X-Wing).
Decapsulation_Key :: _mlkem.Decapsulation_Key
// Encapsulation_Key is a ML-KEM encapsulation (aka "public") key.
Encapsulation_Key :: _mlkem.Encapsulation_Key
// decapsulation_key_generate uses the system entropy source to generate
// a decapsulation key. This will only fail if and only if (⟺) the system
// entropy source is missing or broken.
@(require_results)
decapsulation_key_generate :: proc(dk: ^Decapsulation_Key, params: Parameters) -> bool {
decapsulation_key_clear(dk)
if !crypto.HAS_RAND_BYTES {
return false
}
k := params_to_k(params)
if k == 0 {
panic("crypto/mlkem: invalid parameter set")
}
seed: [DECAPSULATION_KEY_SEED_SIZE]byte = ---
defer crypto.zero_explicit(&seed, size_of(seed))
crypto.rand_bytes(seed[:])
_mlkem.kem_keygen_internal(dk, seed[:], k)
return true
}
// decapsulation_key_set_bytes decodes a byte-encoded decapsulation key
// in (d, z) "seed" format, and returns true if and only if (⟺) the
// operation was successful.
@(require_results)
decapsulation_key_set_bytes :: proc(dk: ^Decapsulation_Key, params: Parameters, seed: []byte) -> bool {
k := params_to_k(params)
if k == 0 {
return false
}
if len(seed) != DECAPSULATION_KEY_SEED_SIZE {
return false
}
_mlkem.kem_keygen_internal(dk, seed, k)
return true
}
// decapsulation_key_bytes sets dst to byte-encoding of dk in the (d, z)
// "seed" format.
decapsulation_key_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key")
ensure(len(dst) == DECAPSULATION_KEY_SEED_SIZE, "crypto/mlkem: invalid destination size")
copy(dst, dk.seed[:])
}
// decapsulation_key_expanded_bytes sets dst to the byte-encoding of dk.
// in the expanded FIPS 203 format. This primarily exists for export
// purposes.
decapsulation_key_expanded_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
dk_len: int
switch dk.pke_dk.k {
case _mlkem.K_512:
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_512]
case _mlkem.K_768:
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_768]
case _mlkem.K_1024:
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_1024]
case:
panic("crypto/mlkem: uninitialized Decapsulation_Key")
}
ensure(len(dst) == dk_len, "crypto/mlkem: invalid destination size")
_mlkem.decapsulation_key_expanded_bytes(dk, dst)
}
// decapsulation_key_encaps_bytes sets dst to the byte-encoding of the
// encasulation key corresponding to dk.
decapsulation_key_encaps_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
encapsulation_key_bytes(&dk.ek, dst)
}
// decapsulation_key_clear clears dk to the uninitialized state.
decapsulation_key_clear :: proc(dk: ^Decapsulation_Key) {
crypto.zero_explicit(dk, size_of(Decapsulation_Key))
}
// encapsulation_key_set_bytes decodes a byte-encoded encapsulation key,
// and returns true if and only if (⟺) the operation was successful.
@(require_results)
encapsulation_key_set_bytes :: proc(ek: ^Encapsulation_Key, params: Parameters, b: []byte) -> bool {
k := params_to_k(params)
if k == 0 {
return false
}
if len(b) != ENCAPSULATION_KEY_SIZES[params] {
return false
}
return _mlkem.encapsulation_key_set_bytes(ek, k, b)
}
// encapsulation_key_set_decaps sets ek to the encapsulation key corresponding
// to dk.
encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) {
ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key")
_mlkem.encapsulation_key_set_decaps(ek, dk)
}
// encapsulation_key_encaps_bytes sets dst to the byte-encoding of ek.
encapsulation_key_bytes :: proc(ek: ^Encapsulation_Key, dst: []byte) {
ensure(ek.pke_ek.k != 0, "crypto/mlkem: uninitialized Encapsulation_Key")
k_len: int
switch ek.pke_ek.k {
case _mlkem.K_512:
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_512]
case _mlkem.K_768:
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_768]
case _mlkem.K_1024:
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_1024]
case:
panic("crypto/mlkem: invalid destination size")
}
copy(dst, ek.raw_bytes[:k_len])
}
// encapsulation_key_clear clears ek to the uninitialized state.
encapsulation_key_clear :: proc(ek: ^Encapsulation_Key) {
crypto.zero_explicit(ek, size_of(Encapsulation_Key))
}
// encaps_raw_ek_bytes uses the byte encoded encapsulation key to generate
// a shared secret and an associated ciphertext. This routine will fail
// if the system entropy source is unavailable, or of the encapsulation key
// is invalid.
@(require_results)
encaps_ek_raw_bytes :: proc(params: Parameters, raw_ek, shared_secret, ciphertext: []byte) -> bool {
ek: Encapsulation_Key = ---
if !encapsulation_key_set_bytes(&ek, params, raw_ek) {
return false
}
defer encapsulation_key_clear(&ek)
return encaps_ek(&ek, shared_secret, ciphertext)
}
// encaps_ek uses the encapsulation key to generate a shared secret and an
// associated ciphertext. This routine will fail if the system entropy source
// is unavailable.
@(require_results)
encaps_ek :: proc(ek: ^Encapsulation_Key, shared_secret, ciphertext: []byte) -> bool {
ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size")
if !crypto.HAS_RAND_BYTES {
return false
}
m: [_mlkem.SYMBYTES]byte = ---
defer crypto.zero_explicit(&m, size_of(m))
crypto.rand_bytes(m[:])
_mlkem.kem_encaps_internal(shared_secret, ciphertext, ek, m[:])
return true
}
encaps :: proc {
encaps_ek,
encaps_ek_raw_bytes,
}
// decaps uses the decapsulation key to generate a shared secret from a
// ciphertext. Due to ML-KEM's implicit rejection mechanism, this function
// will only return false if and only if (⟺) the lengths of the inputs
// are invalid or the decapsulation key is uninitialized.
//
// This routine returning true does not guarantee that the shared secret
// matches that generated by the peer.
@(require_results)
decaps :: proc(dk: ^Decapsulation_Key, ciphertext, shared_secret: []byte) -> bool {
ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size")
ct_len: int
switch dk.pke_dk.k {
case _mlkem.K_512:
ct_len = CIPHERTEXT_SIZES[.ML_KEM_512]
case _mlkem.K_768:
ct_len = CIPHERTEXT_SIZES[.ML_KEM_768]
case _mlkem.K_1024:
ct_len = CIPHERTEXT_SIZES[.ML_KEM_1024]
case:
return false
}
if len(ciphertext) != ct_len {
return false
}
_mlkem.kem_decaps_internal(shared_secret, dk, ciphertext)
return true
}
@(private="file")
params_to_k :: #force_inline proc "contextless" (params: Parameters) -> int {
#partial switch params {
case .ML_KEM_512:
return _mlkem.K_512
case .ML_KEM_768:
return _mlkem.K_768
case .ML_KEM_1024:
return _mlkem.K_1024
}
return 0
}

View File

@@ -0,0 +1,7 @@
/*
ML-KEM Module-Lattice-Based Key-Encapsulation Mechanism.
See:
- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf ]]
*/
package mlkem

View File

@@ -58,7 +58,9 @@ generate_keypair :: proc(protocol: ^Protocol, private_key: ^ecdh.Private_Key) {
case: panic("crypto/noise: unsupported DH curve in protocol")
}
ecdh.private_key_generate(private_key, protocol.dh)
if !ecdh.private_key_generate(private_key, protocol.dh) {
panic("crypto/noise: entropy source unavailable")
}
}
// Performs a Diffie-Hellman calculation between the private key in key_pair
@@ -837,7 +839,9 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte
panic("crypto/noise: re was not empty when processing token 'e' during ReadMessage")
}
ecdh.public_key_set_bytes(&self.re, protocol.dh, re)
if !ecdh.public_key_set_bytes(&self.re, protocol.dh, re) {
return nil, .Invalid_Handshake_Message
}
symmetricstate_mix_hash(&self.symmetric_state, re)
if self.message_pattern.is_psk {
symmetricstate_mix_key(&self.symmetric_state, re)
@@ -864,7 +868,10 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte
panic("crypto/noise: rs was not empty when processing token 's' during ReadMessage")
}
ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs)
if !ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs) {
self.status = .Handshake_Failed
return nil, .Invalid_Handshake_Message
}
msg = msg[rs_len:]
case .ee:

View File

@@ -43,6 +43,7 @@ package all
@(require) import "core:crypto/legacy/keccak"
@(require) import "core:crypto/legacy/md5"
@(require) import "core:crypto/legacy/sha1"
@(require) import "core:crypto/mlkem"
@(require) import cnoise "core:crypto/noise"
@(require) import "core:crypto/pbkdf2"
@(require) import "core:crypto/poly1305"

View File

@@ -48,6 +48,7 @@ package all
@(require) import "core:crypto/legacy/keccak"
@(require) import "core:crypto/legacy/md5"
@(require) import "core:crypto/legacy/sha1"
@(require) import "core:crypto/mlkem"
@(require) import cnoise "core:crypto/noise"
@(require) import "core:crypto/pbkdf2"
@(require) import "core:crypto/poly1305"

View File

@@ -2,6 +2,7 @@ package benchmark_core_crypto
import "base:runtime"
import "core:encoding/hex"
import "core:mem"
import "core:log"
import "core:testing"
import "core:text/table"
@@ -161,7 +162,7 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) {
@(private="file")
bench_ecdsa :: proc(curve: ecdsa.Curve, hash: hash.Algorithm) -> (sk, sig, verif: time.Duration) {
priv_bytes := make([]byte, ecdsa.PRIVATE_KEY_SIZES[curve], context.temp_allocator)
crypto.set(raw_data(priv_bytes), 0x69, len(priv_bytes))
mem.set(raw_data(priv_bytes), 0x69, len(priv_bytes))
priv_key: ecdsa.Private_Key
start := time.tick_now()
for _ in 0 ..< DSA_ITERS {

View File

@@ -0,0 +1,84 @@
package benchmark_core_crypto
import "core:log"
import "core:testing"
import "core:text/table"
import "core:time"
import "core:crypto"
import "core:crypto/mlkem"
@(private = "file")
MLKEM_ITERS :: 50000
@(test)
benchmark_crypto_mlkem :: proc(t: ^testing.T) {
if !crypto.HAS_RAND_BYTES {
log.warnf("ML-KEM benchmarks skipped, no system entropy source")
}
tbl: table.Table
table.init(&tbl)
defer table.destroy(&tbl)
table.caption(&tbl, "ML-KEM")
table.aligned_header_of_values(&tbl, .Right, "Parameters", "Keygen", "Encaps", "Decaps")
append_tbl := proc(tbl: ^table.Table, algo_name: string, keygen, encaps, decaps: time.Duration) {
table.aligned_row_of_values(
tbl,
.Right,
algo_name,
table.format(tbl, "%8M", keygen),
table.format(tbl, "%8M", encaps),
table.format(tbl, "%8M", decaps),
)
}
for params in mlkem.Parameters {
if params == .Invalid {
continue
}
param_name := MLKEM_PARAMS_NAMES[params]
decaps_key: mlkem.Decapsulation_Key
start := time.tick_now()
for _ in 0 ..< MLKEM_ITERS {
_ = mlkem.decapsulation_key_generate(&decaps_key, params)
}
keygen := time.tick_since(start) / MLKEM_ITERS
encaps_key := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
defer delete(encaps_key)
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
defer delete(ciphertext)
mlkem.decapsulation_key_encaps_bytes(&decaps_key, encaps_key)
bob_shared: [mlkem.SHARED_SECRET_SIZE]byte
start = time.tick_now()
for _ in 0 ..< MLKEM_ITERS {
_ = mlkem.encaps(params, encaps_key, bob_shared[:], ciphertext)
}
encaps := time.tick_since(start) / MLKEM_ITERS
alice_shared: [mlkem.SHARED_SECRET_SIZE]byte
start = time.tick_now()
for _ in 0 ..< MLKEM_ITERS {
_ = mlkem.decaps(&decaps_key, ciphertext, alice_shared[:])
}
decaps := time.tick_since(start) / MLKEM_ITERS
append_tbl(&tbl, param_name, keygen, encaps, decaps)
}
log_table(&tbl)
}
@(private="file")
MLKEM_PARAMS_NAMES := [mlkem.Parameters]string {
.Invalid = "invalid",
.ML_KEM_512 = "ML-KEM-512",
.ML_KEM_768 = "ML-KEM-768",
.ML_KEM_1024 = "ML-KEM-1024",
}

View File

@@ -0,0 +1,72 @@
package test_core_crypto
import "core:bytes"
import "core:log"
import "core:testing"
import "core:crypto"
import "core:crypto/mlkem"
@(test)
test_mlkem :: proc(t: ^testing.T) {
if !crypto.HAS_RAND_BYTES {
log.info("rand_bytes not supported - skipping")
return
}
// Test vectors are huge, and are covered by the wycheproof corpus,
// so just test a full key exchange with all supported parameter
// sets.
for params in mlkem.Parameters {
if params == .Invalid {
continue
}
// Alice
decaps_key: mlkem.Decapsulation_Key
if !testing.expectf(
t,
mlkem.decapsulation_key_generate(&decaps_key, params),
"%v: decapsulation_key_generate",
params,
) {
continue
}
defer mlkem.decapsulation_key_clear(&decaps_key)
ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
defer delete(ek_bytes)
mlkem.decapsulation_key_encaps_bytes(&decaps_key, ek_bytes)
// Bob
bob_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
defer delete(ciphertext)
if !testing.expectf(
t,
mlkem.encaps(params, ek_bytes, bob_shared_secret[:], ciphertext),
"%v: encaps",
params,
) {
continue
}
// Alice
alice_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
if !testing.expectf(
t,
mlkem.decaps(&decaps_key, ciphertext, alice_shared_secret[:]),
"%v: decaps",
params,
) {
continue
}
testing.expectf(
t,
bytes.equal(alice_shared_secret[:], bob_shared_secret[:]),
"%v: shared secret mismatch",
params,
)
}
}

View File

@@ -0,0 +1,514 @@
package test_wycheproof
import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:slice"
import "core:testing"
import chacha_simd128 "core:crypto/_chacha20/simd128"
import chacha_simd256 "core:crypto/_chacha20/simd256"
import "core:crypto/aegis"
import "core:crypto/aes"
import "core:crypto/chacha20"
import "core:crypto/chacha20poly1305"
import "../common"
supported_aegis_impls :: proc() -> [dynamic]aes.Implementation {
impls := make([dynamic]aes.Implementation, 0, 2, context.temp_allocator)
append(&impls, aes.Implementation.Portable)
if aegis.is_hardware_accelerated() {
append(&impls, aes.Implementation.Hardware)
}
return impls
}
@(test)
test_aead_aegis :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
files := []string {
"aegis128L_test.json",
"aegis256_test.json",
}
log.debug("aead/aegis: starting")
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Aead_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
for impl in supported_aegis_impls() {
testing.expectf(t, test_aead_aegis_impl(&test_vectors, impl), "impl {} failed", impl)
}
}
}
test_aead_aegis_impl :: proc(
test_vectors: ^Test_Vectors(Aead_Test_Group),
impl: aes.Implementation,
) -> bool {
log.debug("aead/aegis/%v: starting", impl)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"aead/aegis/%v/%d: %s: %+v",
impl,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("aead/aegis/%v/%d: %+v",
impl,
test_vector.tc_id,
test_vector.flags,
)
}
key := common.hexbytes_decode(test_vector.key)
iv := common.hexbytes_decode(test_vector.iv)
aad := common.hexbytes_decode(test_vector.aad)
msg := common.hexbytes_decode(test_vector.msg)
ct := common.hexbytes_decode(test_vector.ct)
tag := common.hexbytes_decode(test_vector.tag)
if len(iv) == 0 {
log.infof(
"aead/aegis/%v/%d: skipped, invalid IVs panic",
impl,
test_vector.tc_id,
)
num_skipped += 1
continue
}
ctx: aegis.Context
aegis.init(&ctx, key, impl)
if result_is_valid(test_vector.result) {
ct_ := make([]byte, len(ct))
tag_ := make([]byte, len(tag))
aegis.seal(&ctx, ct_, tag_, iv, aad, msg)
ok := common.hexbytes_compare(test_vector.ct, ct_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(ct_))
log.errorf(
"aead/aegis/%v/%d: ciphertext: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.ct,
x,
)
num_failed += 1
continue
}
ok = common.hexbytes_compare(test_vector.tag, tag_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(tag_))
log.errorf(
"aead/aegis/%v/%d: tag: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.tag,
x,
)
num_failed += 1
continue
}
}
msg_ := make([]byte, len(msg))
ok := aegis.open(&ctx, msg_, iv, aad, ct, tag)
if !result_check(test_vector.result, ok) {
log.errorf("aead/aegis/%v/%d: decrypt failed", impl, test_vector.tc_id)
num_failed += 1
continue
}
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
x := transmute(string)(hex.encode(msg_))
log.errorf(
"aead/aegis/%v/%d: decrypt msg: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.msg,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"aead/aegis: ran %d, passed %d, failed %d, skipped %d",
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
supported_aes_impls :: proc() -> [dynamic]aes.Implementation {
impls := make([dynamic]aes.Implementation, 0, 2)
append(&impls, aes.Implementation.Portable)
if aes.is_hardware_accelerated() {
append(&impls, aes.Implementation.Hardware)
}
return impls
}
@(test)
test_aead_aes_gcm :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
fn, _ := os.join_path([]string{BASE_PATH, "aes_gcm_test.json"}, context.allocator)
log.debug("aead/aes-gcm: starting")
test_vectors: Test_Vectors(Aead_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) {
return
}
for impl in supported_aes_impls() {
testing.expectf(t, test_aead_aes_gcm_impl(&test_vectors, impl), "impl {} failed", impl)
}
}
test_aead_aes_gcm_impl :: proc(
test_vectors: ^Test_Vectors(Aead_Test_Group),
impl: aes.Implementation,
) -> bool {
log.debug("aead/aes-gcm/%v: starting", impl)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"aead/aes-gcm/%v/%d: %s: %+v",
impl,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("aead/aes-gcm/%v/%d: %+v",
impl,
test_vector.tc_id,
test_vector.flags,
)
}
key := common.hexbytes_decode(test_vector.key)
iv := common.hexbytes_decode(test_vector.iv)
aad := common.hexbytes_decode(test_vector.aad)
msg := common.hexbytes_decode(test_vector.msg)
ct := common.hexbytes_decode(test_vector.ct)
tag := common.hexbytes_decode(test_vector.tag)
if len(iv) == 0 {
log.infof(
"aead/aes-gcm/%v/%d: skipped, invalid IVs panic",
impl,
test_vector.tc_id,
)
num_skipped += 1
continue
}
ctx: aes.Context_GCM
aes.init_gcm(&ctx, key, impl)
if result_is_valid(test_vector.result) {
ct_ := make([]byte, len(ct))
tag_ := make([]byte, len(tag))
aes.seal_gcm(&ctx, ct_, tag_, iv, aad, msg)
ok := common.hexbytes_compare(test_vector.ct, ct_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(ct_))
log.errorf(
"aead/aes-gcm/%v/%d: ciphertext: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.ct,
x,
)
num_failed += 1
continue
}
ok = common.hexbytes_compare(test_vector.tag, tag_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(tag_))
log.errorf(
"aead/aes-gcm/%v/%d: tag: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.tag,
x,
)
num_failed += 1
continue
}
}
msg_ := make([]byte, len(msg))
ok := aes.open_gcm(&ctx, msg_, iv, aad, ct, tag)
if !result_check(test_vector.result, ok) {
log.errorf("aead/aes-gcm/%v/%d: decrypt failed", impl, test_vector.tc_id)
num_failed += 1
continue
}
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
x := transmute(string)(hex.encode(msg_))
log.errorf(
"aead/aes-gcm/%v/%d: decrypt msg: expected %s actual %s",
impl,
test_vector.tc_id,
test_vector.msg,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"aead/aes-gcm: ran %d, passed %d, failed %d, skipped %d",
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
supported_chacha_impls :: proc() -> [dynamic]chacha20.Implementation {
impls := make([dynamic]chacha20.Implementation, 0, 3)
append(&impls, chacha20.Implementation.Portable)
if chacha_simd128.is_performant() {
append(&impls, chacha20.Implementation.Simd128)
}
if chacha_simd256.is_performant() {
append(&impls, chacha20.Implementation.Simd256)
}
return impls
}
@(test)
test_aead_chacha20_poly1305 :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
files := []string {
"chacha20_poly1305_test.json",
"xchacha20_poly1305_test.json",
}
log.debug("aead/(x)chacha20poly1305: starting")
for f, i in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Aead_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
for impl in supported_chacha_impls() {
testing.expectf(t, test_aead_chacha20_poly1305_impl(&test_vectors, i == 1, impl), "impl {} failed", impl)
}
}
}
test_aead_chacha20_poly1305_impl :: proc(
test_vectors: ^Test_Vectors(Aead_Test_Group),
is_xchacha: bool,
impl: chacha20.Implementation,
) -> bool {
FLAG_INVALID_NONCE_SIZE :: "InvalidNonceSize"
alg_str := is_xchacha ? "xchacha20poly1305" : "chacha20poly1305"
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"aead/%s/%v/%d: %s: %+v",
alg_str,
impl,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("aead/%s/%v/%d: %+v",
alg_str,
impl,
test_vector.tc_id,
test_vector.flags,
)
}
key := common.hexbytes_decode(test_vector.key)
iv := common.hexbytes_decode(test_vector.iv)
aad := common.hexbytes_decode(test_vector.aad)
msg := common.hexbytes_decode(test_vector.msg)
ct := common.hexbytes_decode(test_vector.ct)
tag := common.hexbytes_decode(test_vector.tag)
if slice.contains(test_vector.flags, FLAG_INVALID_NONCE_SIZE) {
log.infof(
"aead/%s/%v/%d: skipped, invalid nonces panic",
alg_str,
impl,
test_vector.tc_id,
)
num_skipped += 1
continue
}
ctx: chacha20poly1305.Context
switch is_xchacha {
case true:
chacha20poly1305.init_xchacha(&ctx, key, impl)
case false:
chacha20poly1305.init(&ctx, key, impl)
}
if result_is_valid(test_vector.result) {
ct_ := make([]byte, len(ct))
tag_ := make([]byte, len(tag))
chacha20poly1305.seal(&ctx, ct_, tag_, iv, aad, msg)
ok := common.hexbytes_compare(test_vector.ct, ct_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(ct_))
log.errorf(
"aead/%s/%v/%d: ciphertext: expected %s actual %s",
alg_str,
impl,
test_vector.tc_id,
test_vector.ct,
x,
)
num_failed += 1
continue
}
ok = common.hexbytes_compare(test_vector.tag, tag_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(tag_))
log.errorf(
"aead/%s/%v/%d: tag: expected %s actual %s",
alg_str,
impl,
test_vector.tc_id,
test_vector.tag,
x,
)
num_failed += 1
continue
}
}
msg_ := make([]byte, len(msg))
ok := chacha20poly1305.open(&ctx, msg_, iv, aad, ct, tag)
if !result_check(test_vector.result, ok) {
log.errorf("aead/%s/%v/%d: decrypt failed",
alg_str,
impl,
test_vector.tc_id,
)
num_failed += 1
continue
}
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
x := transmute(string)(hex.encode(msg_))
log.errorf(
"aead/%s/%v/%d: decrypt msg: expected %s actual %s",
alg_str,
impl,
test_vector.tc_id,
test_vector.msg,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"aead/%s/%v: ran %d, passed %d, failed %d, skipped %d",
alg_str,
impl,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}

View File

@@ -0,0 +1,427 @@
package test_wycheproof
import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:slice"
import "core:strings"
import "core:testing"
import "core:crypto/hash"
import "core:crypto/ecdh"
import "core:crypto/ecdsa"
import "core:crypto/ed25519"
import "../common"
@(test)
test_eddsa_ed25519 :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
fn, _ := os.join_path([]string{BASE_PATH, "ed25519_test.json"}, context.allocator)
log.debug("eddsa/ed25519: starting")
test_vectors: Test_Vectors(Eddsa_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) {
return
}
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, i in test_vectors.test_groups {
mem.free_all() // Probably don't need this, but be safe.
pk_bytes := common.hexbytes_decode(test_group.public_key.pk)
pk: ed25519.Public_Key
pk_ok := ed25519.public_key_set_bytes(&pk, pk_bytes)
if !testing.expectf(t, pk_ok, "eddsa/ed25519/%d: invalid public key: %s", i, test_group.public_key.pk) {
num_failed += len(test_group.tests)
continue
}
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"eddsa/ed25519/%d: %s: %+v",
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("eddsa/ed25519/%d: %+v", test_vector.tc_id, test_vector.flags)
}
msg := common.hexbytes_decode(test_vector.msg)
sig := common.hexbytes_decode(test_vector.sig)
verify_ok := ed25519.verify(&pk, msg, sig)
if !testing.expectf(
t,
result_check(test_vector.result, verify_ok),
"eddsa/ed25519/%d: verify failed: expected %s actual %v",
test_vector.tc_id,
test_vector.result,
verify_ok,
) {
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"eddsa/ed25519: ran %d, passed %d, failed %d, skipped %d",
num_ran,
num_passed,
num_failed,
num_skipped,
)
}
@(test)
test_ecdsa :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("ecdsa: starting")
files := []string {
"ecdsa_secp256r1_sha256_test.json",
"ecdsa_secp256r1_sha512_test.json",
"ecdsa_secp384r1_sha384_test.json",
}
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Ecdsa_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
testing.expectf(t, test_ecdsa_impl(t, &test_vectors), "ecdsa failed")
}
}
test_ecdsa_impl :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Ecdsa_Test_Group)) -> bool {
curve_str := test_vectors.test_groups[0].public_key.curve
hash_str := test_vectors.test_groups[0].sha
curve_alg: ecdsa.Curve
switch curve_str {
case "secp256r1":
curve_alg = .SECP256R1
case "secp384r1":
curve_alg = .SECP384R1
case:
log.errorf("ecdsa: unsupported curve: %s", curve_str)
}
hash_alg: hash.Algorithm
switch hash_str {
case "SHA-256":
hash_alg = .SHA256
case "SHA-384":
hash_alg = .SHA384
case "SHA-512":
hash_alg = .SHA512
case:
log.errorf("ecdsa: unsupported hash: %s", hash_str)
}
log.debugf("ecdsa/%s/%s: starting", curve_str, hash_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, i in test_vectors.test_groups {
pk_bytes := common.hexbytes_decode(test_group.public_key.uncompressed)
pk: ecdsa.Public_Key
pk_ok := ecdsa.public_key_set_bytes(&pk, curve_alg, pk_bytes)
if !testing.expectf(t, pk_ok, "ecdsa/%s/%s/%d: invalid public key: %s", curve_str, hash_str, i, test_group.public_key.uncompressed) {
num_failed += len(test_group.tests)
continue
}
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"ecda/%s/%s/%d: %s: %+v",
curve_str,
hash_str,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("ecdsa/%s/%s/%d: %+v", curve_str, hash_str, test_vector.tc_id, test_vector.flags)
}
msg := common.hexbytes_decode(test_vector.msg)
sig := common.hexbytes_decode(test_vector.sig)
verify_ok := ecdsa.verify_asn1(&pk, hash_alg, msg, sig)
if !testing.expectf(
t,
result_check(test_vector.result, verify_ok),
"ecdsa/%s/%s/%d: verify failed: expected %s actual %v",
curve_str,
hash_str,
test_vector.tc_id,
test_vector.result,
verify_ok,
) {
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"ecdsa/%s/%s: ran %d, passed %d, failed %d, skipped %d",
curve_str,
hash_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
@(test)
test_ecdh :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
PREFIX_TEST_ECDH :: "ecdh_"
SUFFIX_TEST_ECPOINT :: "_ecpoint"
files := []string {
"ecdh_secp256r1_ecpoint_test.json",
"ecdh_secp384r1_ecpoint_test.json",
"x25519_test.json",
"x448_test.json",
}
log.debug("ecdh: starting")
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Ecdh_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
alg_str := strings.trim_suffix(f, SUFFIX_TEST_JSON)
alg_str = strings.trim_suffix(alg_str, SUFFIX_TEST_ECPOINT)
alg_str = strings.trim_prefix(alg_str, PREFIX_TEST_ECDH)
testing.expectf(t, test_ecdh_impl(&test_vectors, alg_str), "alg {} failed", alg_str)
}
}
test_ecdh_impl :: proc(
test_vectors: ^Test_Vectors(Ecdh_Test_Group),
alg_str: string,
) -> bool {
ALG_P256 :: "secp256r1"
ALG_P384 :: "secp384r1"
ALG_X25519 :: "x25519"
ALG_X448 :: "x448"
// XDH exceptions
FLAG_PUBLIC_KEY_TOO_LONG :: "PublicKeyTooLong"
FLAG_ZERO_SHARED_SECRET :: "ZeroSharedSecret"
// ECDH exceptions
FLAG_COMPRESSED_POINT :: "CompressedPoint"
FLAG_INVALID_CURVE :: "InvalidCurveAttack"
FLAG_INVALID_ENCODING :: "InvalidEncoding"
log.debugf("ecdh/%s: starting", alg_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf("ecdh/%s/%d: %s: %+v", alg_str, test_vector.tc_id, comment, test_vector.flags)
} else {
log.debugf("ecdh/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
}
raw_pub := common.hexbytes_decode(test_vector.public)
raw_priv := common.hexbytes_decode(test_vector.private)
curve: ecdh.Curve
priv_key: ecdh.Private_Key
pub_key: ecdh.Public_Key
is_nist, is_xdh: bool
switch alg_str {
case ALG_P256:
curve = .SECP256R1
// Ugh, ASN.1 :(
l := len(raw_priv)
if l == 33 {
if raw_priv[0] == 0 {
raw_priv = raw_priv[1:]
}
} else if l < 32 {
// left-pad.odin
tmp := make([]byte, 32)
copy(tmp[32-l:], raw_priv)
raw_priv = tmp
}
is_nist = true
case ALG_P384:
curve = .SECP384R1
// Ugh, ASN.1 :(
l := len(raw_priv)
if l == 49 {
if raw_priv[0] == 0 {
raw_priv = raw_priv[1:]
}
} else if l < 48 {
// left-pad.odin
tmp := make([]byte, 48)
copy(tmp[48-l:], raw_priv)
raw_priv = tmp
}
is_nist = true
case ALG_X25519:
curve = .X25519
is_xdh = true
case ALG_X448:
curve = .X448
is_xdh = true
case:
log.errorf("ecdh: unsupported algorithm: %s", alg_str)
return false
}
if ok := ecdh.private_key_set_bytes(&priv_key, curve, raw_priv); !ok {
log.errorf(
"ecdh/%s/%d: failed to deserialize private_key: %s %d %x",
alg_str,
test_vector.tc_id,
test_vector.private,
len(raw_priv),
raw_priv,
)
num_failed += 1
continue
}
if ok := ecdh.public_key_set_bytes(&pub_key, curve, raw_pub); !ok {
if is_nist {
if slice.contains(test_vector.flags, FLAG_COMPRESSED_POINT) {
num_passed += 1
continue
}
if slice.contains(test_vector.flags, FLAG_INVALID_CURVE) {
num_passed += 1
continue
}
if slice.contains(test_vector.flags, FLAG_INVALID_ENCODING) {
num_passed += 1
continue
}
}
if slice.contains(test_vector.flags, FLAG_PUBLIC_KEY_TOO_LONG) {
num_passed += 1
continue
}
log.errorf(
"ecdh/%s/%d: failed to deserialize public_key: %s",
alg_str,
test_vector.tc_id,
test_vector.public,
)
num_failed += 1
continue
}
shared := make([]byte, ecdh.SHARED_SECRET_SIZES[curve])
ok := ecdh.ecdh(&priv_key, &pub_key, shared)
if !ok {
if is_xdh && slice.contains(test_vector.flags, FLAG_ZERO_SHARED_SECRET) {
num_passed += 1
continue
}
// unused: x := transmute(string)(hex.encode(shared))
log.errorf(
"ecdh/%s/%d: ecdh failed",
alg_str,
test_vector.tc_id,
)
num_failed += 1
continue
}
ok = common.hexbytes_compare(test_vector.shared, shared)
// "acceptable" results are fine from here because we have
// checked for the all-zero shared secret XDH case already.
if !result_check(test_vector.result, ok, false) {
x := transmute(string)(hex.encode(shared))
log.errorf(
"ecdh/%s/%d: shared: expected %s actual %s",
alg_str,
test_vector.tc_id,
test_vector.shared,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"ecdh/%s: ran %d, passed %d, failed %d, skipped %d",
alg_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}

View File

@@ -0,0 +1,238 @@
package test_wycheproof
import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:slice"
import "core:strings"
import "core:testing"
import "core:crypto/hkdf"
import "core:crypto/pbkdf2"
import "../common"
@(test)
test_hkdf :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("hkdf: starting")
files := []string {
"hkdf_sha1_test.json",
"hkdf_sha256_test.json",
"hkdf_sha384_test.json",
"hkdf_sha512_test.json",
}
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Hkdf_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
testing.expectf(t, test_hkdf_impl(&test_vectors), "hkdf failed")
}
}
test_hkdf_impl :: proc(test_vectors: ^Test_Vectors(Hkdf_Test_Group)) -> bool {
PREFIX_HKDF :: "HKDF-"
FLAG_SIZE_TOO_LARGE :: "SizeTooLarge"
alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_HKDF)
alg, ok := hash_name_to_algorithm(alg_str)
if !ok {
return false
}
alg_str = strings.to_lower(alg_str)
log.debugf("hkdf/%s: starting", alg_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"hkdf/%s/%d: %s: %+v",
alg_str,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("hkdf/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
}
ikm := common.hexbytes_decode(test_vector.ikm)
salt := common.hexbytes_decode(test_vector.salt)
info := common.hexbytes_decode(test_vector.info)
if slice.contains(test_vector.flags, FLAG_SIZE_TOO_LARGE) {
log.infof(
"hkdf/%s/%d: skipped, oversized outputs panic",
alg_str,
test_vector.tc_id,
)
num_skipped += 1
continue
}
okm_ := make([]byte, test_vector.size)
hkdf.extract_and_expand(alg, salt, ikm, info, okm_)
ok = common.hexbytes_compare(test_vector.okm, okm_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(okm_))
log.errorf(
"hkdf/%s/%d: shared: expected %s actual %s",
alg_str,
test_vector.tc_id,
test_vector.okm,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"hkdf/%s: ran %d, passed %d, failed %d, skipped %d",
alg_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
@(test)
test_pbkdf2 :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("pbkdf2: starting")
files := []string {
"pbkdf2_hmacsha1_test.json",
"pbkdf2_hmacsha224_test.json",
"pbkdf2_hmacsha256_test.json",
"pbkdf2_hmacsha384_test.json",
"pbkdf2_hmacsha512_test.json",
}
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Pbkdf_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
testing.expectf(t, test_pbkdf2_impl(&test_vectors), "pbkdf2 failed")
}
}
test_pbkdf2_impl :: proc(
test_vectors: ^Test_Vectors(Pbkdf_Test_Group),
) -> bool {
PREFIX_PBKDF_HMAC :: "PBKDF2-HMAC"
FLAG_LARGE_ITERATION_COUNT :: "LargeIterationCount"
alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_PBKDF_HMAC)
alg, ok := hash_name_to_algorithm(alg_str)
if !ok {
return false
}
alg_str = strings.to_lower(alg_str)
log.debugf("pbkdf2/hmac-%s: starting", alg_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"pbkdf2/hmac-%s/%d: %s: %+v",
alg_str,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("pbkdf2/hmac-%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
}
if slice.contains(test_vector.flags, FLAG_LARGE_ITERATION_COUNT) {
log.infof(
"pbkdf2/hmac-%s/%d: skipped, takes fucking forever",
alg_str,
test_vector.tc_id,
)
num_skipped += 1
continue
}
password := common.hexbytes_decode(test_vector.password)
salt := common.hexbytes_decode(test_vector.salt)
dk_ := make([]byte, test_vector.dk_len)
pbkdf2.derive(alg, password, salt, test_vector.iteration_count, dk_)
ok = common.hexbytes_compare(test_vector.dk, dk_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(dk_))
log.errorf(
"pbkdf2/hmac-%s/%d: shared: expected %s actual %s",
alg_str,
test_vector.tc_id,
test_vector.dk,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"pbkdf2/%s: ran %d, passed %d, failed %d, skipped %d",
alg_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}

View File

@@ -0,0 +1,156 @@
package test_wycheproof
import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:testing"
import "core:crypto/hmac"
import "core:crypto/kmac"
import "core:crypto/siphash"
import "../common"
@(test)
test_mac :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("mac: starting")
files := []string {
"hmac_sha1_test.json",
"hmac_sha224_test.json",
"hmac_sha256_test.json",
"hmac_sha3_224_test.json",
"hmac_sha3_256_test.json",
"hmac_sha3_384_test.json",
"hmac_sha3_512_test.json",
"hmac_sha384_test.json",
// "hmac_sha512_224_test.json",
"hmac_sha512_256_test.json",
"hmac_sha512_test.json",
"hmac_sm3_test.json",
"kmac128_no_customization_test.json",
"kmac256_no_customization_test.json",
"siphash_1_3_test.json",
"siphash_2_4_test.json",
"siphash_4_8_test.json",
}
for f in files {
mem.free_all() // Probably don't need this, but be safe.
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Mac_Test_Group)
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mac_impl(&test_vectors), "hkdf failed")
}
}
test_mac_impl :: proc(test_vectors: ^Test_Vectors(Mac_Test_Group)) -> bool {
PREFIX_HMAC :: "HMAC"
PREFIX_KMAC :: "KMAC"
mac_alg, hmac_alg, alg_str, ok := mac_algorithm(test_vectors.algorithm)
if !ok {
log.errorf("mac: unsupported algorith: %s", test_vectors.algorithm)
return false
}
log.debugf("%s: starting", alg_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
if comment := test_vector.comment; comment != "" {
log.debugf(
"%s/%d: %s: %+v",
alg_str,
test_vector.tc_id,
comment,
test_vector.flags,
)
} else {
log.debugf("%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
}
key := common.hexbytes_decode(test_vector.key)
msg := common.hexbytes_decode(test_vector.msg)
tag_ := make([]byte, len(test_vector.tag) / 2)
#partial switch mac_alg {
case .HMAC:
ctx: hmac.Context
hmac.init(&ctx, hmac_alg, key)
hmac.update(&ctx, msg)
if l := hmac.tag_size(&ctx); l == len(tag_) {
hmac.final(&ctx, tag_)
} else {
// Our hmac package does not support truncation.
tmp := make([]byte, l)
hmac.final(&ctx, tmp)
copy(tag_, tmp)
}
case .KMAC128, .KMAC256:
ctx: kmac.Context
#partial switch mac_alg {
case .KMAC128:
kmac.init_128(&ctx, key, nil)
case .KMAC256:
kmac.init_256(&ctx, key, nil)
}
kmac.update(&ctx, msg)
kmac.final(&ctx, tag_)
case .SIPHASH_1_3:
siphash.sum_1_3(msg, key, tag_)
case .SIPHASH_2_4:
siphash.sum_2_4(msg, key, tag_)
case .SIPHASH_4_8:
siphash.sum_4_8(msg, key, tag_)
}
ok = common.hexbytes_compare(test_vector.tag, tag_)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(tag_))
log.errorf(
"%s/%d: tag: expected %s actual %s",
alg_str,
test_vector.tc_id,
test_vector.tag,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s: ran %d, passed %d, failed %d, skipped %d",
alg_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,401 @@
package test_wycheproof
import "core:encoding/hex"
import "core:log"
import "core:mem"
import "core:os"
import "core:testing"
import "core:crypto/_mlkem"
import "core:crypto/mlkem"
import "../common"
@(test)
test_mlkem :: proc(t: ^testing.T) {
arena: mem.Arena
arena_backing := make([]byte, ARENA_SIZE)
defer delete(arena_backing)
mem.arena_init(&arena, arena_backing)
context.allocator = mem.arena_allocator(&arena)
log.debug("mlkem: starting")
files_keygen := []string {
"mlkem_512_keygen_seed_test.json",
"mlkem_768_keygen_seed_test.json",
"mlkem_1024_keygen_seed_test.json",
}
for f in files_keygen {
mem.free_all()
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Kem_Test_Group)
load_ok := load(&test_vectors, fn)
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mlkem_keygen(t, &test_vectors), "ML-KEM KeyGen failed")
}
files_encaps := []string {
"mlkem_512_encaps_test.json",
"mlkem_768_encaps_test.json",
"mlkem_1024_encaps_test.json",
}
for f in files_encaps {
mem.free_all()
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Kem_Test_Group)
load_ok := load(&test_vectors, fn)
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mlkem_encaps(t, &test_vectors), "ML-KEM Encaps failed")
}
files_decaps := []string {
"mlkem_512_test.json",
"mlkem_768_test.json",
"mlkem_1024_test.json",
}
for f in files_decaps {
mem.free_all()
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
test_vectors: Test_Vectors(Kem_Test_Group)
load_ok := load(&test_vectors, fn)
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
continue
}
testing.expectf(t, test_mlkem_decaps(t, &test_vectors), "ML-KEM Decaps failed")
}
}
test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
log.debugf("%s: KeyGen starting", params_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, tg_id in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
seed := common.hexbytes_decode(test_vector.seed)
dk: mlkem.Decapsulation_Key
if !testing.expectf(
t,
mlkem.decapsulation_key_set_bytes(&dk, params, seed),
"%s/KeyGen/%d/%d: failed to set decapsulation key from seed",
params_str,
tg_id,
test_vector.tc_id,
test_vector.seed,
) {
num_failed *= 1
continue
}
ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
mlkem.decapsulation_key_encaps_bytes(&dk, ek_bytes)
ok := common.hexbytes_compare(test_vector.ek, ek_bytes)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(ek_bytes))
log.errorf(
"%s/KeyGen/%d/%d: ek: expected %s actual %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.ek,
x,
)
num_failed += 1
continue
}
dk_bytes := make([]byte, mlkem.DECAPSULATION_KEY_EXPANDED_SIZES[params])
mlkem.decapsulation_key_expanded_bytes(&dk, dk_bytes)
ok = common.hexbytes_compare(test_vector.dk, dk_bytes)
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(dk_bytes))
log.errorf(
"%s/KeyGen/%d/%d: dk: expected %s actual %s",
tg_id,
params_str,
test_vector.tc_id,
test_vector.dk,
x,
)
num_failed += 1
continue
}
seed_bytes: [mlkem.DECAPSULATION_KEY_SEED_SIZE]byte
mlkem.decapsulation_key_bytes(&dk, seed_bytes[:])
ok = common.hexbytes_compare(test_vector.seed, seed_bytes[:])
if !result_check(test_vector.result, ok) {
x := transmute(string)(hex.encode(seed_bytes[:]))
log.errorf(
"%s/KeyGen/%d/%d: seed: expected %s actual %s",
tg_id,
params_str,
test_vector.tc_id,
test_vector.seed,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s/KeyGen: ran %d, passed %d, failed %d, skipped %d",
params_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
log.debugf("%s: Encaps starting", params_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, tg_id in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
ek: mlkem.Encapsulation_Key
ok := mlkem.encapsulation_key_set_bytes(
&ek,
params,
common.hexbytes_decode(test_vector.ek),
)
// The current corpus can only fail if the encapsulation key
// is malformed in some way.
if !result_check(test_vector.result, ok) {
log.errorf(
"%s/Encaps/%d/%d: unexpected set encapsulation key from bytes: %s (%v != %v)",
params_str,
tg_id,
test_vector.tc_id,
test_vector.ek,
test_vector.result,
ok,
)
num_failed += 1
continue
}
if !ok {
num_passed += 1
continue
}
shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
_mlkem.kem_encaps_internal(
shared_secret[:],
ciphertext,
&ek,
common.hexbytes_decode(test_vector.m),
)
ok = common.hexbytes_compare(test_vector.c, ciphertext)
if !ok {
x := transmute(string)(hex.encode(ciphertext))
log.errorf(
"%s/Encaps/%d/%d: ciphertext: expected: %s actual: %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.c,
x,
)
num_failed += 1
continue
}
ok = common.hexbytes_compare(test_vector.k, shared_secret[:])
if !ok {
x := transmute(string)(hex.encode(shared_secret[:]))
log.errorf(
"%s/Encaps/%d/%d: shared_secret: expected: %s actual: %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.k,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s/Encaps: ran %d, passed %d, failed %d, skipped %d",
params_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
params_str := test_vectors.test_groups[0].parameter_set
params := parameter_set_to_params(params_str)
if params == .Invalid {
return false
}
log.debugf("%s: Decaps starting", params_str)
num_ran, num_passed, num_failed, num_skipped: int
for &test_group, tg_id in test_vectors.test_groups {
for &test_vector in test_group.tests {
num_ran += 1
// We do not have an API for decaps with raw seed.
seed := common.hexbytes_decode(test_vector.seed)
switch len(seed) {
case mlkem.DECAPSULATION_KEY_SEED_SIZE:
case:
if testing.expectf(
t,
result_is_invalid(test_vector.result),
"%s/Decaps/%d/%d: test vector expects success with invalid seed",
params_str,
tg_id,
test_vector.tc_id,
) {
num_passed += 1
} else {
num_failed += 1
}
continue
}
dk: mlkem.Decapsulation_Key
if !testing.expectf(
t,
mlkem.decapsulation_key_set_bytes(&dk, params, seed),
"%s/Decaps/%d/%d: failed to set decapsulation key from seed",
params_str,
tg_id,
test_vector.tc_id,
test_vector.seed,
) {
num_failed *= 1
continue
}
shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
ok := mlkem.decaps(
&dk,
common.hexbytes_decode(test_vector.c),
shared_secret[:],
)
if !result_check(test_vector.result, ok) {
log.errorf(
"%s/Decaps/%d/%d: unexpected decapsulation failure",
params_str,
tg_id,
test_vector.tc_id,
)
num_failed += 1
continue
}
if !ok {
num_passed += 1
continue
}
ok = common.hexbytes_compare(test_vector.k, shared_secret[:])
if !ok {
x := transmute(string)(hex.encode(shared_secret[:]))
log.errorf(
"%s/Decaps/%d/%d: shared_secret: expected: %s actual: %s",
params_str,
tg_id,
test_vector.tc_id,
test_vector.k,
x,
)
num_failed += 1
continue
}
num_passed += 1
}
}
assert(num_ran == test_vectors.number_of_tests)
assert(num_passed + num_failed + num_skipped == num_ran)
log.infof(
"%s/Decaps: ran %d, passed %d, failed %d, skipped %d",
params_str,
num_ran,
num_passed,
num_failed,
num_skipped,
)
return num_failed == 0
}
@(require_results, private="file")
parameter_set_to_params :: proc(s: string) -> mlkem.Parameters {
switch s {
case "ML-KEM-512":
return .ML_KEM_512
case "ML-KEM-768":
return .ML_KEM_768
case "ML-KEM-1024":
return .ML_KEM_1024
case:
return .Invalid
}
}

View File

@@ -64,6 +64,11 @@ Test_Vectors_Note :: struct {
links: []string `json:"links"`,
}
Test_Group_Source :: struct {
name: string `json:"name"`,
version: string `json:"version"`,
}
Aead_Test_Group :: struct {
iv_size: int `json:"ivSize"`,
key_size: int `json:"keySize"`,
@@ -198,3 +203,23 @@ Pbkdf_Test_Vector :: struct {
result: Result `json:"result"`,
flags: []string `json:"flags"`,
}
Kem_Test_Group :: struct {
type: string `json:"type"`,
source: Test_Group_Source `json:"source"`,
parameter_set: string `json:"parameterSet"`,
tests: []Kem_Test_Vector `json:"tests"`,
}
Kem_Test_Vector :: struct {
tc_id: int `json:"tcId"`,
flags: []string `json:"flags"`,
comment: string `json:"comment"`,
seed: common.Hex_Bytes `json:"seed"`,
m: common.Hex_Bytes `json:"m"`,
ek: common.Hex_Bytes `json:"ek"`,
dk: common.Hex_Bytes `json:"dk"`,
c: common.Hex_Bytes `json:"c"`,
k: common.Hex_Bytes `json:"K"`,
result: Result `json:"result"`,
}