mirror of
https://github.com/odin-lang/Odin.git
synced 2026-05-25 13:18:14 +00:00
Merge pull request #6635 from Yawning/feature/mlkem
core/crypto/mlkem: Support ML-KEM (FIPS 203)
This commit is contained in:
52
core/crypto/_mlkem/cbd.odin
Normal file
52
core/crypto/_mlkem/cbd.odin
Normal file
@@ -0,0 +1,52 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
import "core:encoding/endian"
|
||||
|
||||
unchecked_get_u24le :: #force_inline proc "contextless" (b: []byte) -> u32 #no_bounds_check {
|
||||
r := u32(b[0])
|
||||
r |= u32(b[1]) << 8
|
||||
r |= u32(b[2]) << 16
|
||||
return r
|
||||
}
|
||||
|
||||
cbd3 :: proc "contextless" (r: ^Poly, buf: ^[3*N/4]byte) #no_bounds_check {
|
||||
for i in 0..<N/4 {
|
||||
t := unchecked_get_u24le(buf[3*i:])
|
||||
d := t & 0x00249249
|
||||
d += (t>>1) & 0x00249249
|
||||
d += (t>>2) & 0x00249249
|
||||
|
||||
for j in uint(0)..<4 {
|
||||
a := i16((d >> (6*j+0)) & 0x7)
|
||||
b := i16((d >> (6*j+3)) & 0x7)
|
||||
r.coeffs[4*i+int(j)] = a - b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
cbd2 :: proc "contextless" (r: ^Poly, buf: ^[2*N/4]byte) #no_bounds_check {
|
||||
for i in 0..<N/8 {
|
||||
t := endian.unchecked_get_u32le(buf[4*i:])
|
||||
d := t & 0x55555555
|
||||
d += (t>>1) & 0x55555555
|
||||
|
||||
for j in uint(0)..<8 {
|
||||
a := i16((d >> (4*j+0)) & 0x3)
|
||||
b := i16((d >> (4*j+2)) & 0x3)
|
||||
r.coeffs[8*i+int(j)] = a - b
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
poly_cbd_eta1_512 :: proc "contextless" (r: ^Poly, buf: ^[ETA1_512*N/4]byte) {
|
||||
cbd3(r, buf)
|
||||
}
|
||||
|
||||
poly_cbd_eta1 :: proc "contextless" (r: ^Poly, buf: ^[ETA1*N/4]byte) {
|
||||
cbd2(r, buf)
|
||||
}
|
||||
|
||||
poly_cbd_eta2 :: proc "contextless" (r: ^Poly, buf: ^[ETA2*N/4]byte) {
|
||||
cbd2(r, buf)
|
||||
}
|
||||
53
core/crypto/_mlkem/constants.odin
Normal file
53
core/crypto/_mlkem/constants.odin
Normal file
@@ -0,0 +1,53 @@
|
||||
package _mlkem
|
||||
|
||||
K_512 :: 2
|
||||
K_768 :: 3
|
||||
K_1024 :: 4
|
||||
K_MAX :: K_1024
|
||||
|
||||
N :: 256
|
||||
Q :: 3329
|
||||
|
||||
ETA1_512 :: 3
|
||||
ETA1 :: 2
|
||||
ETA2 :: 2
|
||||
|
||||
POLYBYTES :: 384
|
||||
SYMBYTES :: 32
|
||||
|
||||
POLYCOMPRESSEDBYTES_512 :: 128
|
||||
POLYCOMPRESSEDBYTES_768 :: 128
|
||||
POLYCOMPRESSEDBYTES_1024 :: 160
|
||||
|
||||
POLYVECBYTES_512 :: K_512 * POLYBYTES
|
||||
POLYVECBYTES_768 :: K_768 * POLYBYTES
|
||||
POLYVECBYTES_1024 :: K_1024 * POLYBYTES
|
||||
|
||||
POLYVECCOMPRESSEDBYTES_512 :: K_512 * 320
|
||||
POLYVECCOMPRESSEDBYTES_768 :: K_768 * 320
|
||||
POLYVECCOMPRESSEDBYTES_1024 :: K_1024 * 352
|
||||
|
||||
INDCPA_MSGBYTES :: SYMBYTES
|
||||
INDCPA_PUBLICKEYBYTES_512 :: POLYVECBYTES_512 + SYMBYTES
|
||||
INDCPA_SECRETKEYBYTES_512 :: POLYVECBYTES_512
|
||||
INDCPA_PUBLICKEYBYTES_768 :: POLYVECBYTES_768 + SYMBYTES
|
||||
INDCPA_SECRETKEYBYTES_768 :: POLYVECBYTES_768
|
||||
INDCPA_PUBLICKEYBYTES_1024 :: POLYVECBYTES_1024 + SYMBYTES
|
||||
INDCPA_SECRETKEYBYTES_1024 :: POLYVECBYTES_1024
|
||||
INDCPA_PUBLICKEYBYTES_MAX :: INDCPA_PUBLICKEYBYTES_1024
|
||||
INDCPA_BYTES_512 :: POLYVECCOMPRESSEDBYTES_512 + POLYCOMPRESSEDBYTES_512
|
||||
INDCPA_BYTES_768 :: POLYVECCOMPRESSEDBYTES_768 + POLYCOMPRESSEDBYTES_768
|
||||
INDCPA_BYTES_1024 :: POLYVECCOMPRESSEDBYTES_1024 + POLYCOMPRESSEDBYTES_1024
|
||||
|
||||
ENCAPSKEYBYTES_512 :: INDCPA_PUBLICKEYBYTES_512
|
||||
ENCAPSKEYBYTES_768 :: INDCPA_PUBLICKEYBYTES_768
|
||||
ENCAPSKEYBYTES_1024 :: INDCPA_PUBLICKEYBYTES_1024
|
||||
|
||||
DECAPSKEYBYTES_512 :: INDCPA_SECRETKEYBYTES_512 + INDCPA_PUBLICKEYBYTES_512 + 2 * SYMBYTES
|
||||
DECAPSKEYBYTES_768 :: INDCPA_SECRETKEYBYTES_768 + INDCPA_PUBLICKEYBYTES_768 + 2 * SYMBYTES
|
||||
DECAPSKEYBYTES_1024 :: INDCPA_SECRETKEYBYTES_1024 + INDCPA_PUBLICKEYBYTES_1024 + 2 * SYMBYTES
|
||||
|
||||
CIPHERTEXTBYTES_512 :: INDCPA_BYTES_512
|
||||
CIPHERTEXTBYTES_768 :: INDCPA_BYTES_768
|
||||
CIPHERTEXTBYTES_1024 :: INDCPA_BYTES_1024
|
||||
CIPHERTEXTBYTES_MAX :: INDCPA_BYTES_1024
|
||||
349
core/crypto/_mlkem/k_pke.odin
Normal file
349
core/crypto/_mlkem/k_pke.odin
Normal file
@@ -0,0 +1,349 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
import "core:crypto"
|
||||
import "core:crypto/shake"
|
||||
|
||||
@(require_results)
|
||||
pack_pk :: proc "contextless" (r: []byte, pk: ^Polyvec, seed: []byte, k: int) -> bool {
|
||||
pk_len := polyvec_byte_size(k)
|
||||
switch {
|
||||
case pk_len == 0:
|
||||
return false
|
||||
case len(seed) != SYMBYTES || len(r) != pk_len + SYMBYTES:
|
||||
return false
|
||||
}
|
||||
|
||||
polyvec_tobytes(r[:pk_len], pk, k)
|
||||
copy(r[pk_len:], seed)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
unpack_pk :: proc "contextless" (pk: ^Polyvec, seed, packedpk: []byte) -> bool {
|
||||
pk_len := len(packedpk) - SYMBYTES
|
||||
k: int
|
||||
switch {
|
||||
case pk_len == POLYVECBYTES_512:
|
||||
k = K_512
|
||||
case pk_len == POLYVECBYTES_768:
|
||||
k = K_768
|
||||
case pk_len == POLYVECBYTES_1024:
|
||||
k = K_1024
|
||||
case len(packedpk) - pk_len != SYMBYTES:
|
||||
return false
|
||||
case len(seed) != SYMBYTES:
|
||||
return false
|
||||
}
|
||||
if k == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
ok := polyvec_frombytes(pk, packedpk[:pk_len], k)
|
||||
copy(seed, packedpk[pk_len:])
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
pack_sk :: proc "contextless" (r: []byte, sk: ^Polyvec, k: int) -> bool {
|
||||
r_len := len(r)
|
||||
if r_len == 0 || r_len != polyvec_byte_size(k) {
|
||||
return false
|
||||
}
|
||||
|
||||
polyvec_tobytes(r, sk, k)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
unpack_sk :: proc "contextless" (sk: ^Polyvec, packedsk: []byte) -> bool {
|
||||
k: int
|
||||
switch len(packedsk) {
|
||||
case POLYVECBYTES_512:
|
||||
k = K_512
|
||||
case POLYVECBYTES_768:
|
||||
k = K_768
|
||||
case POLYVECBYTES_1024:
|
||||
k = K_1024
|
||||
case:
|
||||
return false
|
||||
}
|
||||
if k == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
return polyvec_frombytes(sk, packedsk, k)
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
pack_ciphertext :: proc "contextless" (r: []byte, b: ^Polyvec, v: ^Poly, k: int) -> bool {
|
||||
b_len := polyvec_compressed_byte_size(k)
|
||||
if len(r) != b_len + poly_compressed_bytes(k) {
|
||||
return false
|
||||
}
|
||||
|
||||
polyvec_compress(r[:b_len], b, k)
|
||||
poly_compress(r[b_len:], v)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
unpack_ciphertext :: proc "contextless" (b: ^Polyvec, v: ^Poly, c: []byte) -> int {
|
||||
b_len: int
|
||||
k: int
|
||||
switch len(c) {
|
||||
case INDCPA_BYTES_512:
|
||||
b_len = POLYVECCOMPRESSEDBYTES_512
|
||||
k = K_512
|
||||
case INDCPA_BYTES_768:
|
||||
b_len = POLYVECCOMPRESSEDBYTES_768
|
||||
k = K_768
|
||||
case INDCPA_BYTES_1024:
|
||||
b_len = POLYVECCOMPRESSEDBYTES_1024
|
||||
k = K_1024
|
||||
case:
|
||||
return 0
|
||||
}
|
||||
|
||||
polyvec_decompress(b, c[:b_len], k)
|
||||
poly_decompress(v, c[b_len:])
|
||||
|
||||
return k
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
rej_uniform :: proc "contextless" (r: []i16, buf: []byte) -> int {
|
||||
r_len, b_len := len(r), len(buf)
|
||||
|
||||
ctr, pos: int
|
||||
for ctr < r_len && pos + 3 <= b_len {
|
||||
val0 := (u16(buf[pos+0] >> 0) | (u16(buf[pos+1]) << 8)) & 0xFFF
|
||||
val1 := (u16(buf[pos+1] >> 4) | (u16(buf[pos+2]) << 4)) & 0xFFF
|
||||
pos += 3
|
||||
|
||||
if val0 < Q {
|
||||
r[ctr] = i16(val0)
|
||||
ctr += 1
|
||||
}
|
||||
if(ctr < r_len && val1 < Q) {
|
||||
r[ctr] = i16(val1)
|
||||
ctr += 1
|
||||
}
|
||||
}
|
||||
|
||||
return ctr
|
||||
}
|
||||
|
||||
gen_matrix :: proc(a: []Polyvec, seed: []byte, transposed: bool, k: int) {
|
||||
GEN_MATRIX_NBLOCKS :: ((12*N/8*(1 << 12)/Q + XOF_BLOCKBYTES)/XOF_BLOCKBYTES)
|
||||
|
||||
buf: [GEN_MATRIX_NBLOCKS*XOF_BLOCKBYTES]byte = ---
|
||||
ctx: shake.Context = ---
|
||||
ctr: int
|
||||
|
||||
defer shake.reset(&ctx)
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
for i in 0..<k {
|
||||
for j in 0..<k {
|
||||
switch transposed {
|
||||
case true:
|
||||
xof_absorb(&ctx, seed, byte(i), byte(j))
|
||||
case false:
|
||||
xof_absorb(&ctx, seed, byte(j), byte(i))
|
||||
}
|
||||
|
||||
shake.read(&ctx, buf[:])
|
||||
ctr = rej_uniform(a[i].vec[j].coeffs[:], buf[:])
|
||||
|
||||
b := buf[:XOF_BLOCKBYTES]
|
||||
for ctr < N {
|
||||
shake.read(&ctx, b)
|
||||
ctr += rej_uniform(a[i].vec[j].coeffs[ctr:], b)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
K_PKE_Decryption_Key :: struct {
|
||||
pv: Polyvec,
|
||||
k: int,
|
||||
}
|
||||
|
||||
K_PKE_Encryption_Key :: struct {
|
||||
pv: Polyvec,
|
||||
p: [SYMBYTES]byte,
|
||||
k: int,
|
||||
}
|
||||
|
||||
k_pke_encryption_key_set :: proc(dst, src: ^K_PKE_Encryption_Key) {
|
||||
k_pke_key_clear(dst)
|
||||
|
||||
for i in 0..<src.k {
|
||||
copy(dst.pv.vec[i].coeffs[:], src.pv.vec[i].coeffs[:])
|
||||
}
|
||||
copy(dst.p[:], src.p[:])
|
||||
dst.k = src.k
|
||||
}
|
||||
|
||||
k_pke_key_clear :: proc(k: $T) where T == ^K_PKE_Encryption_Key || T == ^K_PKE_Decryption_Key {
|
||||
crypto.zero_explicit(k, size_of(k^))
|
||||
}
|
||||
|
||||
k_pke_keygen :: proc(
|
||||
ek: ^K_PKE_Encryption_Key,
|
||||
dk: ^K_PKE_Decryption_Key,
|
||||
d: []byte,
|
||||
k: int,
|
||||
) {
|
||||
assert(len(d) == SYMBYTES, "crypto/mlkem: invalid K-PKE d")
|
||||
ensure(k == K_512 || k == K_768 || k == K_1024, "crypto/mlkem: invalid k")
|
||||
|
||||
buf: [2*SYMBYTES]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
a_: [K_MAX]Polyvec = ---
|
||||
e: Polyvec = ---
|
||||
defer crypto.zero_explicit(&a_, size_of(Polyvec) * k)
|
||||
defer polyvec_clear(&e)
|
||||
|
||||
a := a_[:k]
|
||||
|
||||
copy(buf[:], d)
|
||||
buf[SYMBYTES] = byte(k)
|
||||
hash_g(buf[:], buf[:SYMBYTES+1])
|
||||
|
||||
p, sigma := buf[:SYMBYTES], buf[SYMBYTES:]
|
||||
|
||||
gen_matrix(a, p, false, k)
|
||||
|
||||
n := byte(0)
|
||||
for i in 0..<k {
|
||||
if k != K_512 {
|
||||
poly_getnoise_eta1(&dk.pv.vec[i], sigma, n)
|
||||
} else {
|
||||
poly_getnoise_eta1_512(&dk.pv.vec[i], sigma, n)
|
||||
}
|
||||
n += 1
|
||||
}
|
||||
for i in 0..<k {
|
||||
if k != K_512 {
|
||||
poly_getnoise_eta1(&e.vec[i], sigma, n)
|
||||
} else {
|
||||
poly_getnoise_eta1_512(&e.vec[i], sigma, n)
|
||||
}
|
||||
n += 1
|
||||
}
|
||||
|
||||
polyvec_ntt(&dk.pv, k)
|
||||
polyvec_ntt(&e, k)
|
||||
|
||||
for i in 0..<k {
|
||||
polyvec_basemul_acc_montgomery(&ek.pv.vec[i], &a[i], &dk.pv, k)
|
||||
poly_tomont(&ek.pv.vec[i])
|
||||
}
|
||||
|
||||
polyvec_add(&ek.pv, &ek.pv, &e, k)
|
||||
polyvec_reduce(&ek.pv, k)
|
||||
|
||||
copy(ek.p[:], p)
|
||||
|
||||
dk.k = k
|
||||
ek.k = k
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
k_pke_encrypt :: proc(
|
||||
ciphertext: []byte,
|
||||
ek: ^K_PKE_Encryption_Key,
|
||||
m: []byte,
|
||||
r: []byte,
|
||||
) -> bool {
|
||||
ensure(len(m) == INDCPA_MSGBYTES, "crypto/mlkem: invalid K-PKE m")
|
||||
ensure(len(r) == SYMBYTES, "crypto/mlkem: invalid K-PKE r")
|
||||
|
||||
k := ek.k
|
||||
|
||||
at_: [K_MAX]Polyvec = ---
|
||||
sp, ep, b: Polyvec = ---, ---, ---
|
||||
kay, epp, v: Poly = ---, ---, ---
|
||||
defer crypto.zero_explicit(&at_, size_of(Polyvec) * k)
|
||||
defer polyvec_clear(&sp, &ep, &b)
|
||||
defer poly_clear(&kay, &epp, &v)
|
||||
|
||||
poly_frommsg(&kay, m)
|
||||
|
||||
at := at_[:k]
|
||||
|
||||
gen_matrix(at, ek.p[:], true, k)
|
||||
|
||||
n := byte(0)
|
||||
for i in 0..<k {
|
||||
if k != K_512 {
|
||||
poly_getnoise_eta1(&sp.vec[i], r, n)
|
||||
} else {
|
||||
poly_getnoise_eta1_512(&sp.vec[i], r, n)
|
||||
}
|
||||
n += 1
|
||||
}
|
||||
for i in 0..<k {
|
||||
poly_getnoise_eta2(&ep.vec[i], r, n)
|
||||
n += 1
|
||||
}
|
||||
poly_getnoise_eta2(&epp, r, n)
|
||||
|
||||
polyvec_ntt(&sp, k)
|
||||
|
||||
for i in 0..<k {
|
||||
polyvec_basemul_acc_montgomery(&b.vec[i], &at[i], &sp, k)
|
||||
}
|
||||
|
||||
polyvec_basemul_acc_montgomery(&v, &ek.pv, &sp, k)
|
||||
|
||||
polyvec_invntt_tomont(&b, k)
|
||||
poly_invntt_tomont(&v)
|
||||
|
||||
polyvec_add(&b, &b, &ep, k)
|
||||
poly_add(&v, &v, &epp)
|
||||
poly_add(&v, &v, &kay)
|
||||
polyvec_reduce(&b, k)
|
||||
poly_reduce(&v)
|
||||
|
||||
return pack_ciphertext(ciphertext, &b, &v, k)
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
k_pke_decrypt :: proc(
|
||||
plaintext: []byte,
|
||||
dk: ^K_PKE_Decryption_Key,
|
||||
c: []byte,
|
||||
) -> bool {
|
||||
if len(plaintext) != INDCPA_MSGBYTES {
|
||||
return false
|
||||
}
|
||||
|
||||
k := dk.k
|
||||
|
||||
b: Polyvec = ---
|
||||
v, mp: Poly = ---, ---
|
||||
defer poly_clear(&v, &mp)
|
||||
|
||||
if unpack_ciphertext(&b, &v, c) != k {
|
||||
return false
|
||||
}
|
||||
|
||||
polyvec_ntt(&b, k)
|
||||
polyvec_basemul_acc_montgomery(&mp, &dk.pv, &b, k)
|
||||
poly_invntt_tomont(&mp)
|
||||
|
||||
poly_sub(&mp, &v, &mp)
|
||||
poly_reduce(&mp)
|
||||
|
||||
poly_tomsg(plaintext, &mp)
|
||||
|
||||
return true
|
||||
}
|
||||
195
core/crypto/_mlkem/kem_internal.odin
Normal file
195
core/crypto/_mlkem/kem_internal.odin
Normal file
@@ -0,0 +1,195 @@
|
||||
package _mlkem
|
||||
|
||||
import "core:crypto"
|
||||
import subtle "core:crypto/_subtle"
|
||||
|
||||
// This implementation is derived from the PQ-CRYSTALS reference
|
||||
// implementation [[ https://github.com/pq-crystals/kyber ]],
|
||||
// primarily for licensing reasons. Arguably mlkem-native is
|
||||
// a more "up to date" codebase, but the changes to the
|
||||
// ref code is minor and they slapped an attribution-required
|
||||
// license on something that was originally CC-0/Apache 2.0.
|
||||
|
||||
// "Private Key"
|
||||
Decapsulation_Key :: struct {
|
||||
pke_dk: K_PKE_Decryption_Key,
|
||||
ek: Encapsulation_Key,
|
||||
seed: [SYMBYTES*2]byte, // (d, z)
|
||||
}
|
||||
|
||||
// "Public Key"
|
||||
Encapsulation_Key :: struct {
|
||||
pke_ek: K_PKE_Encryption_Key,
|
||||
raw_bytes: [INDCPA_PUBLICKEYBYTES_MAX]byte,
|
||||
h: [SYMBYTES]byte,
|
||||
}
|
||||
|
||||
decapsulation_key_expanded_bytes :: proc(
|
||||
dk: ^Decapsulation_Key,
|
||||
dst: []byte,
|
||||
) {
|
||||
sk := &dk.pke_dk
|
||||
pv_len := polyvec_byte_size(sk.k)
|
||||
ek_len := pv_len + SYMBYTES
|
||||
|
||||
ek_bytes := dk.ek.raw_bytes[:ek_len]
|
||||
|
||||
dst := dst
|
||||
_ = pack_sk(dst[:pv_len], &sk.pv, sk.k)
|
||||
|
||||
dst = dst[pv_len:]
|
||||
copy(dst, ek_bytes)
|
||||
dst = dst[ek_len:]
|
||||
hash_h(dst[:SYMBYTES], ek_bytes)
|
||||
dst = dst[SYMBYTES:]
|
||||
copy(dst, dk.seed[SYMBYTES:])
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
encapsulation_key_set_bytes :: proc(
|
||||
ek: ^Encapsulation_Key,
|
||||
k: int,
|
||||
b: []byte,
|
||||
) -> bool {
|
||||
k_len: int
|
||||
switch k {
|
||||
case K_512:
|
||||
k_len = ENCAPSKEYBYTES_512
|
||||
case K_768:
|
||||
k_len = ENCAPSKEYBYTES_768
|
||||
case K_1024:
|
||||
k_len = ENCAPSKEYBYTES_1024
|
||||
case:
|
||||
return false
|
||||
}
|
||||
if len(b) != k_len {
|
||||
return false
|
||||
}
|
||||
|
||||
pke_ek := &ek.pke_ek
|
||||
ok := unpack_pk(&pke_ek.pv, pke_ek.p[:], b)
|
||||
pke_ek.k = k
|
||||
copy(ek.raw_bytes[:k_len], b)
|
||||
hash_h(ek.h[:], b)
|
||||
|
||||
// FIPS 203 unlike Kyber requires canonical encoding of
|
||||
// encapsulation keys (Section 7,2), which is checked in
|
||||
// unpack_pk.
|
||||
|
||||
if !ok {
|
||||
crypto.zero_explicit(ek, size_of(Encapsulation_Key))
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) {
|
||||
dk_ek := &dk.ek.pke_ek
|
||||
ensure(dk_ek.k == K_512 || dk_ek.k == K_768 || dk_ek.k == K_1024, "crypto/mlkem: invalid decaps k")
|
||||
|
||||
k_pke_encryption_key_set(&ek.pke_ek, dk_ek)
|
||||
copy(ek.raw_bytes[:], dk.ek.raw_bytes[:])
|
||||
copy(ek.h[:], dk.ek.h[:])
|
||||
}
|
||||
|
||||
// NIST's version of this also returns an encapsulation key, but our
|
||||
// internal representation includes it as part of the decapsulation key
|
||||
// in a more traditional "keypair" approach.
|
||||
kem_keygen_internal :: proc(
|
||||
dk: ^Decapsulation_Key,
|
||||
seed: []byte, // (d, z)
|
||||
k: int,
|
||||
) {
|
||||
ensure(len(seed) == 2 * SYMBYTES, "crypto/mlkem: invalid seed")
|
||||
|
||||
dk_ek := &dk.ek
|
||||
d, z := seed[:SYMBYTES], seed[SYMBYTES:]
|
||||
|
||||
k_pke_keygen(&dk_ek.pke_ek, &dk.pke_dk, d, k)
|
||||
|
||||
ek_len := polyvec_byte_size(k) + SYMBYTES
|
||||
ek_bytes := dk_ek.raw_bytes[:ek_len]
|
||||
ensure(
|
||||
pack_pk(ek_bytes, &dk_ek.pke_ek.pv, dk_ek.pke_ek.p[:], k),
|
||||
"crypto/mlkem: failed to pack K-PKE ek",
|
||||
)
|
||||
hash_h(dk_ek.h[:], ek_bytes)
|
||||
copy(dk.seed[:SYMBYTES], d)
|
||||
copy(dk.seed[SYMBYTES:], z)
|
||||
}
|
||||
|
||||
// The `_internal` "de-randomized" versions of ML-KEM.Encaps and
|
||||
// ML-KEM.Decaps are only ever to be called by the actual non-interal
|
||||
// implementation or test cases.
|
||||
|
||||
kem_encaps_internal :: proc(
|
||||
shared_secret: []byte,
|
||||
ciphertext: []byte,
|
||||
ek: ^Encapsulation_Key,
|
||||
randomness: []byte,
|
||||
) {
|
||||
ensure(len(shared_secret) == SYMBYTES, "crypto/mlkem: invalid K")
|
||||
ensure(len(randomness) == SYMBYTES, "crypto/mlkem: invalid m")
|
||||
ensure(
|
||||
len(ciphertext) == ct_len_for_k(ek.pke_ek.k),
|
||||
"crypto/mlkem: invalid ciphertext length",
|
||||
)
|
||||
|
||||
buf: [2*SYMBYTES]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
hash_g(buf[:], randomness, ek.h[:])
|
||||
|
||||
// Can't fail, ciphertext length is valid.
|
||||
_ = k_pke_encrypt(ciphertext, &ek.pke_ek, randomness, buf[SYMBYTES:])
|
||||
|
||||
copy(shared_secret, buf[:SYMBYTES])
|
||||
}
|
||||
|
||||
kem_decaps_internal :: proc(
|
||||
shared_secret: []byte,
|
||||
dk: ^Decapsulation_Key,
|
||||
ciphertext: []byte,
|
||||
) {
|
||||
ct_len := ct_len_for_k(dk.pke_dk.k)
|
||||
ensure(
|
||||
len(ciphertext) == ct_len,
|
||||
"crypto/mlkem: invalid ciphertext length",
|
||||
)
|
||||
|
||||
m_: [SYMBYTES]byte
|
||||
defer crypto.zero_explicit(&m_, size_of(m_))
|
||||
|
||||
// Can't fail, ciphertext length is valid.
|
||||
_ = k_pke_decrypt(m_[:], &dk.pke_dk, ciphertext)
|
||||
|
||||
buf: [2*SYMBYTES]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
ek := &dk.ek
|
||||
hash_g(buf[:], m_[:], ek.h[:])
|
||||
|
||||
rkprf(shared_secret, dk.seed[SYMBYTES:], ciphertext)
|
||||
|
||||
ct_buf: [CIPHERTEXTBYTES_MAX]byte = ---
|
||||
defer crypto.zero_explicit(&ct_buf, size_of(ct_buf))
|
||||
ct_ := ct_buf[:ct_len]
|
||||
_ = k_pke_encrypt(ct_, &ek.pke_ek, m_[:], buf[SYMBYTES:])
|
||||
|
||||
ok := crypto.compare_constant_time(ciphertext, ct_)
|
||||
subtle.cmov_bytes(shared_secret, buf[:SYMBYTES], ok)
|
||||
}
|
||||
|
||||
@(private="file")
|
||||
ct_len_for_k :: proc(k: int) -> int {
|
||||
switch k {
|
||||
case K_512:
|
||||
return CIPHERTEXTBYTES_512
|
||||
case K_768:
|
||||
return CIPHERTEXTBYTES_768
|
||||
case K_1024:
|
||||
return CIPHERTEXTBYTES_1024
|
||||
case:
|
||||
panic("crypto/mlkem: invalid k for ciphertext length")
|
||||
}
|
||||
}
|
||||
75
core/crypto/_mlkem/ntt.odin
Normal file
75
core/crypto/_mlkem/ntt.odin
Normal file
@@ -0,0 +1,75 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
@(rodata)
|
||||
ZETAS := [128]i16 {
|
||||
-1044, -758, -359, -1517, 1493, 1422, 287, 202,
|
||||
-171, 622, 1577, 182, 962, -1202, -1474, 1468,
|
||||
573, -1325, 264, 383, -829, 1458, -1602, -130,
|
||||
-681, 1017, 732, 608, -1542, 411, -205, -1571,
|
||||
1223, 652, -552, 1015, -1293, 1491, -282, -1544,
|
||||
516, -8, -320, -666, -1618, -1162, 126, 1469,
|
||||
-853, -90, -271, 830, 107, -1421, -247, -951,
|
||||
-398, 961, -1508, -725, 448, -1065, 677, -1275,
|
||||
-1103, 430, 555, 843, -1251, 871, 1550, 105,
|
||||
422, 587, 177, -235, -291, -460, 1574, 1653,
|
||||
-246, 778, 1159, -147, -777, 1483, -602, 1119,
|
||||
-1590, 644, -872, 349, 418, 329, -156, -75,
|
||||
817, 1097, 603, 610, 1322, -1285, -1465, 384,
|
||||
-1215, -136, 1218, -1335, -874, 220, -1187, -1659,
|
||||
-1185, -1530, -1278, 794, -1510, -854, -870, 478,
|
||||
-108, -308, 996, 991, 958, -1460, 1522, 1628,
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
fqmul :: #force_inline proc "contextless" (a, b: i16) -> i16 {
|
||||
return montgomery_reduce(i32(a) * i32(b))
|
||||
}
|
||||
|
||||
ntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check {
|
||||
j, k := 0, 1
|
||||
for l := 128; l >= 2; l >>= 1 {
|
||||
for start := 0; start < N; start = j + l {
|
||||
zeta := ZETAS[k]
|
||||
k += 1
|
||||
for j = start; j < start + l; j += 1 {
|
||||
t := fqmul(zeta, r[j+l])
|
||||
r[j+l] = r[j] - t
|
||||
r[j] = r[j] + t
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
invntt :: proc "contextless" (r: ^[N]i16) #no_bounds_check {
|
||||
F : i16 : 1441 // mont^2/128
|
||||
|
||||
j, k := 0, 127
|
||||
for l := 2; l <= 128; l <<= 1 {
|
||||
for start := 0; start < 256; start = j+l {
|
||||
zeta := ZETAS[k]
|
||||
k -= 1
|
||||
for j = start; j < start + l; j += 1 {
|
||||
t := r[j]
|
||||
r[j] = barrett_reduce(t + r[j+l])
|
||||
r[j+l] = r[j+l] - t
|
||||
r[j+l] = fqmul(zeta, r[j+l])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for v, i in r {
|
||||
r[i] = fqmul(v, F)
|
||||
}
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
base_case_multiply :: proc "contextless" (a_0, a_1, b_0, b_1, zeta: i16) -> (i16, i16) {
|
||||
r_0 := fqmul(a_1, b_1)
|
||||
r_0 = fqmul(r_0, zeta)
|
||||
r_0 += fqmul(a_0, b_0)
|
||||
r_1 := fqmul(a_0, b_1)
|
||||
r_1 += fqmul(a_1, b_0)
|
||||
|
||||
return r_0, r_1
|
||||
}
|
||||
241
core/crypto/_mlkem/poly.odin
Normal file
241
core/crypto/_mlkem/poly.odin
Normal file
@@ -0,0 +1,241 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
import "core:crypto"
|
||||
import subtle "core:crypto/_subtle"
|
||||
|
||||
// Elements of R_q = Z_q[X]/(X^n + 1). Represents polynomial
|
||||
// coeffs[0] + X*coeffs[1] + X^2*coeffs[2] + ... + X^{n-1}*coeffs[n-1]
|
||||
Poly :: struct {
|
||||
coeffs: [N]i16,
|
||||
}
|
||||
|
||||
poly_compress :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
|
||||
t: [8]byte = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
r := r
|
||||
switch len(r) {
|
||||
case POLYCOMPRESSEDBYTES_768: // Also covers _512
|
||||
for i in 0..<N/8 {
|
||||
for j in 0..<8 {
|
||||
// map to positive standard representatives
|
||||
u := a.coeffs[8*i+j]
|
||||
u += (u >> 15) & Q
|
||||
// t[j] = ((((uint16_t)u << 4) + Q/2)/Q) & 15
|
||||
d0 := u32(u) << 4
|
||||
d0 += 1665
|
||||
d0 *= 80635
|
||||
d0 >>= 28
|
||||
t[j] = byte(d0) & 0xf
|
||||
}
|
||||
|
||||
r[0] = t[0] | (t[1] << 4)
|
||||
r[1] = t[2] | (t[3] << 4)
|
||||
r[2] = t[4] | (t[5] << 4)
|
||||
r[3] = t[6] | (t[7] << 4)
|
||||
r = r[4:]
|
||||
}
|
||||
case POLYCOMPRESSEDBYTES_1024:
|
||||
for i in 0..<N/8 {
|
||||
for j in 0..<8 {
|
||||
// map to positive standard representatives
|
||||
u := a.coeffs[8*i+j]
|
||||
u += (u >> 15) & Q
|
||||
// t[j] = ((((uint16_t)u << 5) + Q/2)/Q) & 31
|
||||
d0 := u32(u) << 5
|
||||
d0 += 1664
|
||||
d0 *= 40318
|
||||
d0 >>= 27
|
||||
t[j] = byte(d0) & 0x1f
|
||||
}
|
||||
|
||||
r[0] = (t[0] >> 0) | (t[1] << 5)
|
||||
r[1] = (t[1] >> 3) | (t[2] << 2) | (t[3] << 7)
|
||||
r[2] = (t[3] >> 1) | (t[4] << 4)
|
||||
r[3] = (t[4] >> 4) | (t[5] << 1) | (t[6] << 6)
|
||||
r[4] = (t[6] >> 2) | (t[7] << 3)
|
||||
r = r[5:]
|
||||
}
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES")
|
||||
}
|
||||
}
|
||||
|
||||
poly_decompress :: proc "contextless" (r: ^Poly, a: []byte) {
|
||||
a := a
|
||||
switch len(a) {
|
||||
case POLYCOMPRESSEDBYTES_768: // Also covers _512
|
||||
for i in 0..<N/2 {
|
||||
r.coeffs[2*i+0] = i16(((u16(a[0] & 15) * Q) + 8) >> 4)
|
||||
r.coeffs[2*i+1] = i16(((u16(a[0] >> 4) * Q) + 8) >> 4)
|
||||
a = a[1:]
|
||||
}
|
||||
case POLYCOMPRESSEDBYTES_1024:
|
||||
t: [8]byte = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
for i in 0..<N/8 {
|
||||
t[0] = (a[0] >> 0)
|
||||
t[1] = (a[0] >> 5) | (a[1] << 3)
|
||||
t[2] = (a[1] >> 2)
|
||||
t[3] = (a[1] >> 7) | (a[2] << 1)
|
||||
t[4] = (a[2] >> 4) | (a[3] << 4)
|
||||
t[5] = (a[3] >> 1)
|
||||
t[6] = (a[3] >> 6) | (a[4] << 2)
|
||||
t[7] = (a[4] >> 3)
|
||||
a = a[5:]
|
||||
|
||||
for j in 0..<8 {
|
||||
r.coeffs[8*i+j] = i16((u32(t[j] & 31) * Q + 16) >> 5)
|
||||
}
|
||||
}
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid POLYCOMPRESSEDBYTES")
|
||||
}
|
||||
}
|
||||
|
||||
poly_tobytes :: proc "contextless" (r: []byte, a: ^Poly) #no_bounds_check {
|
||||
ensure_contextless(len(r) >= POLYBYTES)
|
||||
|
||||
for i in 0..<N/2 {
|
||||
// map to positive standard representatives
|
||||
t0 := u16(a.coeffs[2*i])
|
||||
t0 += u16((i16(t0) >> 15) & Q)
|
||||
t1 := u16(a.coeffs[2*i+1])
|
||||
t1 += u16((i16(t1) >> 15) & Q)
|
||||
r[3*i+0] = byte(t0 >> 0)
|
||||
r[3*i+1] = byte(t0 >> 8) | byte(t1 << 4)
|
||||
r[3*i+2] = byte(t1 >> 4)
|
||||
}
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
poly_frombytes :: proc "contextless" (r: ^Poly, a: []byte) -> bool #no_bounds_check {
|
||||
ensure_contextless(len(a) >= POLYBYTES)
|
||||
|
||||
ok := true
|
||||
for i in 0..<N/2 {
|
||||
r.coeffs[2*i] = i16(((u16(a[3*i+0]) >> 0) | (u16(a[3*i+1]) << 8)) & 0xFFF)
|
||||
r.coeffs[2*i+1] = i16(((u16(a[3*i+1]) >> 4) | (u16(a[3*i+2]) << 4)) & 0xFFF)
|
||||
ok &= r.coeffs[2*i] < Q && r.coeffs[2*i+1] < Q
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
poly_frommsg :: proc "contextless" (r: ^Poly, msg: []byte) #no_bounds_check {
|
||||
#assert(INDCPA_MSGBYTES == N/8)
|
||||
ensure_contextless(len(msg) == INDCPA_MSGBYTES)
|
||||
|
||||
for i in 0..<N/8 {
|
||||
for j in 0..<8 {
|
||||
r.coeffs[8*i+j] = subtle.csel_i16(0, (Q+1)/2, int(msg[i] >> uint(j))&1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
poly_tomsg :: proc "contextless" (msg: []byte, a: ^Poly) #no_bounds_check {
|
||||
ensure_contextless(len(msg) == INDCPA_MSGBYTES)
|
||||
|
||||
for i in 0..<N/8 {
|
||||
msg[i] = 0
|
||||
for j in uint(0)..<8 {
|
||||
t := u32(a.coeffs[8*i+int(j)])
|
||||
// t += ((int16_t)t >> 15) & Q
|
||||
// t = (((t << 1) + Q/2)/Q) & 1
|
||||
t <<= 1
|
||||
t += 1665
|
||||
t *= 80635
|
||||
t >>= 28
|
||||
t &= 1
|
||||
msg[i] |= byte(t << j)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
poly_getnoise_eta1_512 :: proc(r: ^Poly, seed: []byte, iv: byte) {
|
||||
buf: [ETA1_512*N/4]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
prf(buf[:], seed, iv)
|
||||
poly_cbd_eta1_512(r, &buf)
|
||||
}
|
||||
|
||||
poly_getnoise_eta1 :: proc(r: ^Poly, seed: []byte, iv: byte) {
|
||||
buf: [ETA1*N/4]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
prf(buf[:], seed, iv)
|
||||
poly_cbd_eta1(r, &buf)
|
||||
}
|
||||
|
||||
poly_getnoise_eta2 :: proc(r: ^Poly, seed: []byte, iv: byte) {
|
||||
buf: [ETA2*N/4]byte = ---
|
||||
defer crypto.zero_explicit(&buf, size_of(buf))
|
||||
|
||||
prf(buf[:], seed, iv)
|
||||
poly_cbd_eta2(r, &buf)
|
||||
}
|
||||
|
||||
poly_ntt :: proc "contextless" (r: ^Poly) {
|
||||
ntt(&r.coeffs)
|
||||
poly_reduce(r)
|
||||
}
|
||||
|
||||
poly_invntt_tomont :: proc "contextless" (r: ^Poly) {
|
||||
invntt(&r.coeffs)
|
||||
}
|
||||
|
||||
poly_basemul_montgomery :: proc "contextless" (r, a, b: ^Poly) #no_bounds_check {
|
||||
for i in 0..<N/4 {
|
||||
j := 4 * i
|
||||
r.coeffs[j], r.coeffs[j+1] = base_case_multiply(a.coeffs[j], a.coeffs[j+1], b.coeffs[j], b.coeffs[j+1], ZETAS[64+i])
|
||||
r.coeffs[j+2], r.coeffs[j+3] = base_case_multiply(a.coeffs[j+2], a.coeffs[j+3], b.coeffs[j+2], b.coeffs[j+3], -ZETAS[64+i])
|
||||
}
|
||||
}
|
||||
|
||||
poly_tomont :: proc "contextless" (r: ^Poly) {
|
||||
F : i16 : (1 << 32) % Q
|
||||
for v, i in r.coeffs {
|
||||
r.coeffs[i] = montgomery_reduce(i32(v)*i32(F))
|
||||
}
|
||||
}
|
||||
|
||||
poly_reduce :: proc "contextless" (r: ^Poly) {
|
||||
for v, i in r.coeffs {
|
||||
r.coeffs[i] = barrett_reduce(v)
|
||||
}
|
||||
}
|
||||
|
||||
poly_add :: proc "contextless" (r, a, b: ^Poly) {
|
||||
for i in 0..<N {
|
||||
r.coeffs[i] = a.coeffs[i] + b.coeffs[i]
|
||||
}
|
||||
}
|
||||
|
||||
poly_sub :: proc "contextless" (r, a, b: ^Poly) {
|
||||
for i in 0..<N {
|
||||
r.coeffs[i] = a.coeffs[i] - b.coeffs[i]
|
||||
}
|
||||
}
|
||||
|
||||
poly_clear :: proc "contextless" (a: ..^Poly) {
|
||||
for j in 0..<len(a) {
|
||||
p := a[j]
|
||||
crypto.zero_explicit(p, size_of(Poly))
|
||||
}
|
||||
}
|
||||
|
||||
poly_compressed_bytes :: #force_inline proc "contextless" (k: int) -> int {
|
||||
switch k {
|
||||
case K_512:
|
||||
return POLYCOMPRESSEDBYTES_512
|
||||
case K_768:
|
||||
return POLYCOMPRESSEDBYTES_768
|
||||
case K_1024:
|
||||
return POLYCOMPRESSEDBYTES_1024
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid k")
|
||||
}
|
||||
}
|
||||
224
core/crypto/_mlkem/polyvec.odin
Normal file
224
core/crypto/_mlkem/polyvec.odin
Normal file
@@ -0,0 +1,224 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
import "core:crypto"
|
||||
|
||||
Polyvec :: struct {
|
||||
vec: [K_MAX]Poly,
|
||||
}
|
||||
|
||||
polyvec_compress :: proc "contextless" (r: []byte, a: ^Polyvec, kay: int) #no_bounds_check {
|
||||
d0: u64
|
||||
|
||||
r := r
|
||||
switch len(r) {
|
||||
case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768:
|
||||
ensure_contextless(kay == K_512 || kay == K_768)
|
||||
|
||||
t: [4]u16 = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
for i in 0..<kay {
|
||||
for j in 0..<N/4 {
|
||||
for k in 0..<4 {
|
||||
t[k] = u16(a.vec[i].coeffs[4*j+k])
|
||||
t[k] += u16((i16(t[k]) >> 15) & Q)
|
||||
// t[k] = ((((uint32_t)t[k] << 10) + Q/2)/Q) & 0x3ff
|
||||
d0 = u64(t[k])
|
||||
d0 <<= 10
|
||||
d0 += 1665
|
||||
d0 *= 1290167
|
||||
d0 >>= 32
|
||||
t[k] = u16(d0 & 0x3ff)
|
||||
}
|
||||
|
||||
r[0] = byte(t[0] >> 0)
|
||||
r[1] = byte((t[0] >> 8) | (t[1] << 2))
|
||||
r[2] = byte((t[1] >> 6) | (t[2] << 4))
|
||||
r[3] = byte((t[2] >> 4) | (t[3] << 6))
|
||||
r[4] = byte(t[3] >> 2)
|
||||
r = r[5:]
|
||||
}
|
||||
}
|
||||
case POLYVECCOMPRESSEDBYTES_1024:
|
||||
ensure_contextless(kay == K_1024)
|
||||
|
||||
t: [8]u16 = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
for i in 0..<K_1024 {
|
||||
for j in 0..<N/8 {
|
||||
for k in 0..<8 {
|
||||
t[k] = u16(a.vec[i].coeffs[8*j+k])
|
||||
t[k] += u16((i16(t[k]) >> 15) & Q)
|
||||
// t[k] = ((((uint32_t)t[k] << 11) + Q/2)/Q) & 0x7ff
|
||||
d0 = u64(t[k])
|
||||
d0 <<= 11
|
||||
d0 += 1664
|
||||
d0 *= 645084
|
||||
d0 >>= 31
|
||||
t[k] = u16(d0 & 0x7ff)
|
||||
}
|
||||
|
||||
r[0] = byte(t[0] >> 0)
|
||||
r[1] = byte((t[0] >> 8) | (t[1] << 3))
|
||||
r[2] = byte((t[1] >> 5) | (t[2] << 6))
|
||||
r[3] = byte(t[2] >> 2)
|
||||
r[4] = byte((t[2] >> 10) | (t[3] << 1))
|
||||
r[5] = byte((t[3] >> 7) | (t[4] << 4))
|
||||
r[6] = byte((t[4] >> 4) | (t[5] << 7))
|
||||
r[7] = byte(t[5] >> 1)
|
||||
r[8] = byte((t[5] >> 9) | (t[6] << 2))
|
||||
r[9] = byte((t[6] >> 6) | (t[7] << 5))
|
||||
r[10] = byte(t[7] >> 3)
|
||||
r = r[11:]
|
||||
}
|
||||
}
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES")
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_decompress :: proc "contextless" (r: ^Polyvec, a: []byte, kay: int) #no_bounds_check {
|
||||
a := a
|
||||
switch len(a) {
|
||||
case POLYVECCOMPRESSEDBYTES_512, POLYVECCOMPRESSEDBYTES_768:
|
||||
ensure_contextless(kay == K_512 || kay == K_768)
|
||||
|
||||
t: [4]u16 = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
for i in 0..<kay {
|
||||
for j in 0..<N/4 {
|
||||
t[0] = u16(a[0] >> 0) | (u16(a[1]) << 8)
|
||||
t[1] = u16(a[1] >> 2) | (u16(a[2]) << 6)
|
||||
t[2] = u16(a[2] >> 4) | (u16(a[3]) << 4)
|
||||
t[3] = u16(a[3] >> 6) | (u16(a[4]) << 2)
|
||||
a = a[5:]
|
||||
|
||||
for k in 0..<4 {
|
||||
r.vec[i].coeffs[4*j+k] = i16((u32(t[k] & 0x3FF) * Q + 512) >> 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
case POLYVECCOMPRESSEDBYTES_1024:
|
||||
t: [8]u16 = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
for i in 0..<K_1024 {
|
||||
for j in 0..<N/8 {
|
||||
t[0] = u16(a[0] >> 0) | (u16(a[1]) << 8)
|
||||
t[1] = u16(a[1] >> 3) | (u16(a[2]) << 5)
|
||||
t[2] = u16(a[2] >> 6) | (u16(a[3]) << 2) | (u16(a[4]) << 10)
|
||||
t[3] = u16(a[4] >> 1) | (u16(a[5]) << 7)
|
||||
t[4] = u16(a[5] >> 4) | (u16(a[6]) << 4)
|
||||
t[5] = u16(a[6] >> 7) | (u16(a[7]) << 1) | (u16(a[8]) << 9)
|
||||
t[6] = u16(a[8] >> 2) | (u16(a[9]) << 6)
|
||||
t[7] = u16(a[9] >> 5) | (u16(a[10]) << 3)
|
||||
a = a[11:]
|
||||
|
||||
for k in 0..<8 {
|
||||
r.vec[i].coeffs[8*j+k] = i16((u32(t[k] & 0x7FF) * Q + 1024) >> 11)
|
||||
}
|
||||
}
|
||||
}
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid POLYVECCOMPRESSEDBYTES")
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_tobytes :: proc "contextless" (r: []byte, a: ^Polyvec, k: int) #no_bounds_check {
|
||||
ensure_contextless(len(r) == k * POLYBYTES, "crypto/mlkem: invalid buffer")
|
||||
|
||||
r := r
|
||||
for i in 0..<k {
|
||||
poly_tobytes(r, &a.vec[i])
|
||||
r = r[POLYBYTES:]
|
||||
}
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
polyvec_frombytes :: proc "contextless" (r: ^Polyvec, a: []byte, k: int) -> bool #no_bounds_check {
|
||||
switch k {
|
||||
case K_512, K_768, K_1024:
|
||||
case:
|
||||
panic_contextless("crypto/mlkem: invalid POLYVECBYTES")
|
||||
}
|
||||
ensure_contextless(len(a) == k * POLYBYTES, "crypto/mlkem: invalid buffer")
|
||||
|
||||
a := a
|
||||
ok := true
|
||||
for i in 0..<k {
|
||||
ok &= poly_frombytes(&r.vec[i], a)
|
||||
a = a[POLYBYTES:]
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
polyvec_byte_size :: #force_inline proc "contextless" (k: int) -> int {
|
||||
switch k {
|
||||
case K_512, K_768, K_1024:
|
||||
return k * POLYBYTES
|
||||
case:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
polyvec_compressed_byte_size :: #force_inline proc "contextless" (k: int) -> int {
|
||||
switch k {
|
||||
case K_512:
|
||||
return POLYVECCOMPRESSEDBYTES_512
|
||||
case K_768:
|
||||
return POLYVECCOMPRESSEDBYTES_768
|
||||
case K_1024:
|
||||
return POLYVECCOMPRESSEDBYTES_1024
|
||||
case:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_ntt :: proc "contextless" (r: ^Polyvec, k: int) {
|
||||
for i in 0..<k {
|
||||
poly_ntt(&r.vec[i])
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_invntt_tomont :: proc "contextless" (r: ^Polyvec, k: int) {
|
||||
for i in 0..<k {
|
||||
poly_invntt_tomont(&r.vec[i])
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_basemul_acc_montgomery :: proc "contextless" (r: ^Poly, a, b: ^Polyvec, k: int) {
|
||||
t: Poly = ---
|
||||
defer crypto.zero_explicit(&t, size_of(t))
|
||||
|
||||
poly_basemul_montgomery(r, &a.vec[0], &b.vec[0])
|
||||
for i in 1..<k {
|
||||
poly_basemul_montgomery(&t, &a.vec[i], &b.vec[i])
|
||||
poly_add(r, r, &t)
|
||||
}
|
||||
|
||||
poly_reduce(r)
|
||||
}
|
||||
|
||||
polyvec_reduce :: proc "contextless" (r: ^Polyvec, k: int) {
|
||||
for i in 0..<k {
|
||||
poly_reduce(&r.vec[i])
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_add :: proc "contextless" (r, a, b: ^Polyvec, k: int) {
|
||||
for i in 0..<k {
|
||||
poly_add(&r.vec[i], &a.vec[i], &b.vec[i])
|
||||
}
|
||||
}
|
||||
|
||||
polyvec_clear :: proc "contextless" (rs: ..^Polyvec) {
|
||||
for j in 0..<len(rs) {
|
||||
r := rs[j]
|
||||
crypto.zero_explicit(r, size_of(Polyvec))
|
||||
}
|
||||
}
|
||||
19
core/crypto/_mlkem/reduce.odin
Normal file
19
core/crypto/_mlkem/reduce.odin
Normal file
@@ -0,0 +1,19 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
@(require_results)
|
||||
montgomery_reduce :: #force_inline proc "contextless" (a: i32) -> i16 {
|
||||
QINV :: -3327 // q^-1 mod 2^16
|
||||
|
||||
t := i16(a) * QINV
|
||||
return i16((a - i32(t) * Q) >> 16)
|
||||
}
|
||||
|
||||
@(require_results)
|
||||
barrett_reduce :: #force_inline proc "contextless" (a: i16) -> i16 {
|
||||
V : i16 : ((1<<26) + Q / 2) / Q
|
||||
|
||||
t := i16((i32(V)*i32(a) + (1<<25)) >> 26)
|
||||
t *= Q
|
||||
return a - t
|
||||
}
|
||||
61
core/crypto/_mlkem/symmetric_shake.odin
Normal file
61
core/crypto/_mlkem/symmetric_shake.odin
Normal file
@@ -0,0 +1,61 @@
|
||||
#+private
|
||||
package _mlkem
|
||||
|
||||
import "core:crypto"
|
||||
import "core:crypto/_sha3"
|
||||
import "core:crypto/sha3"
|
||||
import "core:crypto/shake"
|
||||
|
||||
XOF_BLOCKBYTES :: _sha3.RATE_128
|
||||
#assert(XOF_BLOCKBYTES % 3 == 0)
|
||||
|
||||
prf :: proc(out, key: []byte, iv: byte) {
|
||||
ctx: shake.Context = ---
|
||||
defer shake.reset(&ctx)
|
||||
|
||||
shake.init_256(&ctx)
|
||||
shake.write(&ctx, key)
|
||||
shake.write(&ctx, []byte{iv})
|
||||
shake.read(&ctx, out)
|
||||
}
|
||||
|
||||
rkprf :: proc(out, key, input: []byte) {
|
||||
ctx: shake.Context = ---
|
||||
defer shake.reset(&ctx)
|
||||
|
||||
shake.init_256(&ctx)
|
||||
shake.write(&ctx, key)
|
||||
shake.write(&ctx, input)
|
||||
shake.read(&ctx, out)
|
||||
}
|
||||
|
||||
xof_absorb :: proc(ctx: ^shake.Context, seed: []byte, x, y: byte) {
|
||||
shake.init_128(ctx)
|
||||
|
||||
extseed: [SYMBYTES+2]byte = ---
|
||||
defer crypto.zero_explicit(&extseed, size_of(extseed))
|
||||
|
||||
copy(extseed[:], seed)
|
||||
extseed[SYMBYTES+0] = x
|
||||
extseed[SYMBYTES+1] = y
|
||||
|
||||
shake.write(ctx, extseed[:])
|
||||
}
|
||||
|
||||
hash_h :: proc(dst, src: []byte) {
|
||||
ctx: sha3.Context = ---
|
||||
|
||||
sha3.init_256(&ctx)
|
||||
sha3.update(&ctx, src)
|
||||
sha3.final(&ctx, dst)
|
||||
}
|
||||
|
||||
hash_g :: proc(dst: []byte, srcs: ..[]byte) {
|
||||
ctx: sha3.Context = ---
|
||||
|
||||
sha3.init_512(&ctx)
|
||||
for src in srcs {
|
||||
sha3.update(&ctx, src)
|
||||
}
|
||||
sha3.final(&ctx, dst)
|
||||
}
|
||||
@@ -3,6 +3,7 @@ Various useful bit operations in constant time.
|
||||
*/
|
||||
package _subtle
|
||||
|
||||
import "core:crypto/_fiat"
|
||||
import "core:math/bits"
|
||||
|
||||
// byte_eq returns 1 if and only if (⟺) a == b, 0 otherwise.
|
||||
@@ -40,3 +41,41 @@ u64_is_non_zero :: proc "contextless" (a: u64) -> u64 {
|
||||
is_zero := u64_is_zero(a)
|
||||
return (~is_zero) & 1
|
||||
}
|
||||
|
||||
@(optimization_mode="none")
|
||||
cmov_bytes :: proc "contextless" (dst, src: []byte, ctrl: int) {
|
||||
s_len := len(src)
|
||||
ensure_contextless(s_len == len(dst), "crypto: cmov length mismatch")
|
||||
|
||||
c := -(byte)(ctrl)
|
||||
for i in 0..<s_len {
|
||||
dst[i] ~= c & (dst[i] ~ src[i])
|
||||
}
|
||||
}
|
||||
|
||||
@(optimization_mode="none")
|
||||
csel_i16 :: proc "contextless" (a, b: i16, ctrl: int) -> i16 {
|
||||
c := -(u16)(ctrl)
|
||||
return a ~ i16(c & u16(a ~ b))
|
||||
}
|
||||
|
||||
@(optimization_mode="none")
|
||||
csel_u16 :: proc "contextless" (a, b: u16, ctrl: int) -> u16 {
|
||||
c := -(u16)(ctrl)
|
||||
return a ~ (c & (a ~ b))
|
||||
}
|
||||
|
||||
csel_u32 :: proc "contextless" (a, b: u32, ctrl: int) -> u32 {
|
||||
return _fiat.cmovznz_u32(_fiat.u1(ctrl), a, b)
|
||||
}
|
||||
|
||||
csel_u64 :: proc "contextless" (a, b: u64, ctrl: int) -> u64 {
|
||||
return _fiat.cmovznz_u64(_fiat.u1(ctrl), a, b)
|
||||
}
|
||||
|
||||
csel :: proc {
|
||||
csel_i16,
|
||||
csel_u16,
|
||||
csel_u32,
|
||||
csel_u64,
|
||||
}
|
||||
|
||||
@@ -85,18 +85,6 @@ zero_explicit :: proc "contextless" (data: rawptr, len: int) -> rawptr {
|
||||
return data
|
||||
}
|
||||
|
||||
/*
|
||||
Set each byte of a memory range to a specific value.
|
||||
|
||||
This procedure copies value specified by the `value` parameter into each of the
|
||||
`len` bytes of a memory range, located at address `data`.
|
||||
|
||||
This procedure returns the pointer to `data`.
|
||||
*/
|
||||
set :: proc "contextless" (data: rawptr, value: byte, len: int) -> rawptr {
|
||||
return runtime.memset(data, i32(value), len)
|
||||
}
|
||||
|
||||
// rand_bytes fills the dst buffer with cryptographic entropy taken from
|
||||
// the system entropy source. This routine will block if the system entropy
|
||||
// source is not ready yet. All system entropy source failures are treated
|
||||
|
||||
@@ -106,6 +106,7 @@ Public_Key :: struct {
|
||||
// private_key_generate uses the system entropy source to generate a new
|
||||
// Private_Key. This will only fail if and only if (⟺) the system entropy source is
|
||||
// missing or broken.
|
||||
@(require_results)
|
||||
private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
|
||||
private_key_clear(priv_key)
|
||||
|
||||
@@ -143,6 +144,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
|
||||
|
||||
// private_key_set_bytes decodes a byte-encoded private key, and returns
|
||||
// true if and only if (⟺) the operation was successful.
|
||||
@(require_results)
|
||||
private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool {
|
||||
private_key_clear(priv_key)
|
||||
|
||||
@@ -281,6 +283,7 @@ private_key_bytes :: proc(priv_key: ^Private_Key, dst: []byte) {
|
||||
|
||||
// private_key_equal returns true if and only if (⟺) the private keys are equal,
|
||||
// in constant time.
|
||||
@(require_results)
|
||||
private_key_equal :: proc(p, q: ^Private_Key) -> bool {
|
||||
if p._curve != q._curve {
|
||||
return false
|
||||
@@ -311,6 +314,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) {
|
||||
|
||||
// public_key_set_bytes decodes a byte-encoded public key, and returns
|
||||
// true if and only if (⟺) the operation was successful.
|
||||
@(require_results)
|
||||
public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool {
|
||||
public_key_clear(pub_key)
|
||||
|
||||
@@ -411,6 +415,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) {
|
||||
|
||||
// public_key_equal returns true if and only if (⟺) the public keys are equal,
|
||||
// in constant time.
|
||||
@(require_results)
|
||||
public_key_equal :: proc(p, q: ^Public_Key) -> bool {
|
||||
if p._curve != q._curve {
|
||||
return false
|
||||
@@ -479,11 +484,13 @@ ecdh :: proc(priv_key: ^Private_Key, pub_key: ^Public_Key, dst: []byte) -> bool
|
||||
}
|
||||
|
||||
// curve returns the Curve used by a Private_Key or Public_Key instance.
|
||||
@(require_results)
|
||||
curve :: proc(k: ^$T) -> Curve where(T == Private_Key || T == Public_Key) {
|
||||
return k._curve
|
||||
}
|
||||
|
||||
// key_size returns the key size of a Private_Key or Public_Key in bytes.
|
||||
@(require_results)
|
||||
key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
|
||||
when T == Private_Key {
|
||||
return PRIVATE_KEY_SIZES[k._curve]
|
||||
@@ -494,6 +501,7 @@ key_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
|
||||
|
||||
// shared_secret_size returns the shared secret size of a key exchange
|
||||
// in bytes.
|
||||
@(require_results)
|
||||
shared_secret_size :: proc(k: ^$T) -> int where(T == Private_Key || T == Public_Key) {
|
||||
return SHARED_SECRET_SIZES[k._curve]
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ Public_Key :: struct {
|
||||
// private_key_generate uses the system entropy source to generate a new
|
||||
// Private_Key. This will only fail if and only if (⟺) the system entropy source is
|
||||
// missing or broken.
|
||||
@(require_results)
|
||||
private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
|
||||
private_key_clear(priv_key)
|
||||
|
||||
@@ -112,6 +113,7 @@ private_key_generate :: proc(priv_key: ^Private_Key, curve: Curve) -> bool {
|
||||
|
||||
// private_key_set_bytes decodes a byte-encoded private key, and returns
|
||||
// true if and only if (⟺) the operation was successful.
|
||||
@(require_results)
|
||||
private_key_set_bytes :: proc(priv_key: ^Private_Key, curve: Curve, b: []byte) -> bool {
|
||||
private_key_clear(priv_key)
|
||||
|
||||
@@ -222,6 +224,7 @@ private_key_set :: proc(priv_key, src: ^Private_Key) {
|
||||
|
||||
// private_key_equal returns true if and only if (⟺) the private keys are equal,
|
||||
// in constant time.
|
||||
@(require_results)
|
||||
private_key_equal :: proc(p, q: ^Private_Key) -> bool {
|
||||
if p._curve != q._curve {
|
||||
return false
|
||||
@@ -246,6 +249,7 @@ private_key_clear :: proc "contextless" (priv_key: ^Private_Key) {
|
||||
|
||||
// public_key_set_bytes decodes a byte-encoded public key, and returns
|
||||
// true if and only if (⟺) the operation was successful.
|
||||
@(require_results)
|
||||
public_key_set_bytes :: proc(pub_key: ^Public_Key, curve: Curve, b: []byte) -> bool {
|
||||
public_key_clear(pub_key)
|
||||
|
||||
@@ -334,6 +338,7 @@ public_key_bytes :: proc(pub_key: ^Public_Key, dst: []byte) {
|
||||
|
||||
// public_key_equal returns true if and only if (⟺) the public keys are equal,
|
||||
// in constant time.
|
||||
@(require_results)
|
||||
public_key_equal :: proc(p, q: ^Public_Key) -> bool {
|
||||
if p._curve != q._curve {
|
||||
return false
|
||||
|
||||
@@ -18,6 +18,7 @@ package md5
|
||||
zhibog, dotbmp: Initial implementation.
|
||||
*/
|
||||
|
||||
import "base:intrinsics"
|
||||
import "core:crypto"
|
||||
import "core:encoding/endian"
|
||||
import "core:math/bits"
|
||||
@@ -100,7 +101,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) {
|
||||
i += 1
|
||||
}
|
||||
transform(ctx, ctx.data[:])
|
||||
crypto.set(&ctx.data, 0, 56)
|
||||
intrinsics.mem_zero(&ctx.data, 56)
|
||||
}
|
||||
|
||||
ctx.bitlen += u64(ctx.datalen * 8)
|
||||
|
||||
@@ -19,6 +19,7 @@ package sha1
|
||||
zhibog, dotbmp: Initial implementation.
|
||||
*/
|
||||
|
||||
import "base:intrinsics"
|
||||
import "core:crypto"
|
||||
import "core:encoding/endian"
|
||||
import "core:math/bits"
|
||||
@@ -107,7 +108,7 @@ final :: proc(ctx: ^Context, hash: []byte, finalize_clone: bool = false) {
|
||||
i += 1
|
||||
}
|
||||
transform(ctx, ctx.data[:])
|
||||
crypto.set(&ctx.data, 0, 56)
|
||||
intrinsics.mem_zero(&ctx.data, 56)
|
||||
}
|
||||
|
||||
ctx.bitlen += u64(ctx.datalen * 8)
|
||||
|
||||
268
core/crypto/mlkem/api.odin
Normal file
268
core/crypto/mlkem/api.odin
Normal file
@@ -0,0 +1,268 @@
|
||||
package mlkem
|
||||
|
||||
import "core:crypto"
|
||||
import "core:crypto/_mlkem"
|
||||
|
||||
// Parameters are the supported ML-KEM parameter sets.
|
||||
Parameters :: enum {
|
||||
Invalid,
|
||||
ML_KEM_512,
|
||||
ML_KEM_768,
|
||||
ML_KEM_1024,
|
||||
}
|
||||
|
||||
// DECAPSULATION_KEY_SEED_SIZE is the size of a Decapsulation key in bytes.
|
||||
DECAPSULATION_KEY_SEED_SIZE :: 64 // (d, z) in NIST terms.
|
||||
|
||||
// DECAPSULATION_KEY_EXPANDED_SIZES are the per-parameter sizes of the
|
||||
// decapsulation key in bytes.
|
||||
DECAPSULATION_KEY_EXPANDED_SIZES := [Parameters]int {
|
||||
.Invalid = 0,
|
||||
.ML_KEM_512 = _mlkem.DECAPSKEYBYTES_512, // 1632-bytes
|
||||
.ML_KEM_768 = _mlkem.DECAPSKEYBYTES_768, // 2400-bytes
|
||||
.ML_KEM_1024 = _mlkem.DECAPSKEYBYTES_1024, // 3168-bytes
|
||||
}
|
||||
|
||||
// ENCAPSULATION_KEY_SIZES are the per-parameter sizes of the encapsulation
|
||||
// key in bytes.
|
||||
ENCAPSULATION_KEY_SIZES := [Parameters]int {
|
||||
.Invalid = 0,
|
||||
.ML_KEM_512 = _mlkem.ENCAPSKEYBYTES_512, // 800-bytes
|
||||
.ML_KEM_768 = _mlkem.ENCAPSKEYBYTES_768, // 1184-bytes
|
||||
.ML_KEM_1024 = _mlkem.ENCAPSKEYBYTES_1024, // 1568-bytes
|
||||
}
|
||||
|
||||
// CIPHERTEXT_SIZES are the per-parameter set sizes of the ciphertext
|
||||
// in bytes.
|
||||
CIPHERTEXT_SIZES := [Parameters]int {
|
||||
.Invalid = 0,
|
||||
.ML_KEM_512 = _mlkem.CIPHERTEXTBYTES_512, // 768-bytes
|
||||
.ML_KEM_768 = _mlkem.CIPHERTEXTBYTES_768, // 1088-bytes
|
||||
.ML_KEM_1024 = _mlkem.CIPHERTEXTBYTES_1024, // 1568-bytes
|
||||
}
|
||||
|
||||
// SHARED_SECRET_SIZE is the size of the final shared secret in bytes.
|
||||
SHARED_SECRET_SIZE :: 32
|
||||
|
||||
// Decapsulation_Key is a ML-KEM decapsulation (aka "private") key.
|
||||
// This implementation opts to include the encapsulation (aka "public")
|
||||
// key as well for cases where the decapsulation key is reused (eg: HPKE
|
||||
// with X-Wing).
|
||||
Decapsulation_Key :: _mlkem.Decapsulation_Key
|
||||
|
||||
// Encapsulation_Key is a ML-KEM encapsulation (aka "public") key.
|
||||
Encapsulation_Key :: _mlkem.Encapsulation_Key
|
||||
|
||||
// decapsulation_key_generate uses the system entropy source to generate
|
||||
// a decapsulation key. This will only fail if and only if (⟺) the system
|
||||
// entropy source is missing or broken.
|
||||
@(require_results)
|
||||
decapsulation_key_generate :: proc(dk: ^Decapsulation_Key, params: Parameters) -> bool {
|
||||
decapsulation_key_clear(dk)
|
||||
|
||||
if !crypto.HAS_RAND_BYTES {
|
||||
return false
|
||||
}
|
||||
|
||||
k := params_to_k(params)
|
||||
if k == 0 {
|
||||
panic("crypto/mlkem: invalid parameter set")
|
||||
}
|
||||
|
||||
seed: [DECAPSULATION_KEY_SEED_SIZE]byte = ---
|
||||
defer crypto.zero_explicit(&seed, size_of(seed))
|
||||
|
||||
crypto.rand_bytes(seed[:])
|
||||
_mlkem.kem_keygen_internal(dk, seed[:], k)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// decapsulation_key_set_bytes decodes a byte-encoded decapsulation key
|
||||
// in (d, z) "seed" format, and returns true if and only if (⟺) the
|
||||
// operation was successful.
|
||||
@(require_results)
|
||||
decapsulation_key_set_bytes :: proc(dk: ^Decapsulation_Key, params: Parameters, seed: []byte) -> bool {
|
||||
k := params_to_k(params)
|
||||
if k == 0 {
|
||||
return false
|
||||
}
|
||||
if len(seed) != DECAPSULATION_KEY_SEED_SIZE {
|
||||
return false
|
||||
}
|
||||
|
||||
_mlkem.kem_keygen_internal(dk, seed, k)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// decapsulation_key_bytes sets dst to byte-encoding of dk in the (d, z)
|
||||
// "seed" format.
|
||||
decapsulation_key_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
|
||||
ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key")
|
||||
ensure(len(dst) == DECAPSULATION_KEY_SEED_SIZE, "crypto/mlkem: invalid destination size")
|
||||
|
||||
copy(dst, dk.seed[:])
|
||||
}
|
||||
|
||||
// decapsulation_key_expanded_bytes sets dst to the byte-encoding of dk.
|
||||
// in the expanded FIPS 203 format. This primarily exists for export
|
||||
// purposes.
|
||||
decapsulation_key_expanded_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
|
||||
dk_len: int
|
||||
switch dk.pke_dk.k {
|
||||
case _mlkem.K_512:
|
||||
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_512]
|
||||
case _mlkem.K_768:
|
||||
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_768]
|
||||
case _mlkem.K_1024:
|
||||
dk_len = DECAPSULATION_KEY_EXPANDED_SIZES[.ML_KEM_1024]
|
||||
case:
|
||||
panic("crypto/mlkem: uninitialized Decapsulation_Key")
|
||||
}
|
||||
ensure(len(dst) == dk_len, "crypto/mlkem: invalid destination size")
|
||||
|
||||
_mlkem.decapsulation_key_expanded_bytes(dk, dst)
|
||||
}
|
||||
|
||||
// decapsulation_key_encaps_bytes sets dst to the byte-encoding of the
|
||||
// encasulation key corresponding to dk.
|
||||
decapsulation_key_encaps_bytes :: proc(dk: ^Decapsulation_Key, dst: []byte) {
|
||||
encapsulation_key_bytes(&dk.ek, dst)
|
||||
}
|
||||
|
||||
// decapsulation_key_clear clears dk to the uninitialized state.
|
||||
decapsulation_key_clear :: proc(dk: ^Decapsulation_Key) {
|
||||
crypto.zero_explicit(dk, size_of(Decapsulation_Key))
|
||||
}
|
||||
|
||||
// encapsulation_key_set_bytes decodes a byte-encoded encapsulation key,
|
||||
// and returns true if and only if (⟺) the operation was successful.
|
||||
@(require_results)
|
||||
encapsulation_key_set_bytes :: proc(ek: ^Encapsulation_Key, params: Parameters, b: []byte) -> bool {
|
||||
k := params_to_k(params)
|
||||
if k == 0 {
|
||||
return false
|
||||
}
|
||||
if len(b) != ENCAPSULATION_KEY_SIZES[params] {
|
||||
return false
|
||||
}
|
||||
|
||||
return _mlkem.encapsulation_key_set_bytes(ek, k, b)
|
||||
}
|
||||
|
||||
// encapsulation_key_set_decaps sets ek to the encapsulation key corresponding
|
||||
// to dk.
|
||||
encapsulation_key_set_decaps :: proc(ek: ^Encapsulation_Key, dk: ^Decapsulation_Key) {
|
||||
ensure(dk.pke_dk.k != 0, "crypto/mlkem: uninitialized Decapsulation_Key")
|
||||
_mlkem.encapsulation_key_set_decaps(ek, dk)
|
||||
}
|
||||
|
||||
// encapsulation_key_encaps_bytes sets dst to the byte-encoding of ek.
|
||||
encapsulation_key_bytes :: proc(ek: ^Encapsulation_Key, dst: []byte) {
|
||||
ensure(ek.pke_ek.k != 0, "crypto/mlkem: uninitialized Encapsulation_Key")
|
||||
|
||||
k_len: int
|
||||
switch ek.pke_ek.k {
|
||||
case _mlkem.K_512:
|
||||
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_512]
|
||||
case _mlkem.K_768:
|
||||
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_768]
|
||||
case _mlkem.K_1024:
|
||||
k_len = ENCAPSULATION_KEY_SIZES[.ML_KEM_1024]
|
||||
case:
|
||||
panic("crypto/mlkem: invalid destination size")
|
||||
}
|
||||
|
||||
copy(dst, ek.raw_bytes[:k_len])
|
||||
}
|
||||
|
||||
// encapsulation_key_clear clears ek to the uninitialized state.
|
||||
encapsulation_key_clear :: proc(ek: ^Encapsulation_Key) {
|
||||
crypto.zero_explicit(ek, size_of(Encapsulation_Key))
|
||||
}
|
||||
|
||||
// encaps_raw_ek_bytes uses the byte encoded encapsulation key to generate
|
||||
// a shared secret and an associated ciphertext. This routine will fail
|
||||
// if the system entropy source is unavailable, or of the encapsulation key
|
||||
// is invalid.
|
||||
@(require_results)
|
||||
encaps_ek_raw_bytes :: proc(params: Parameters, raw_ek, shared_secret, ciphertext: []byte) -> bool {
|
||||
ek: Encapsulation_Key = ---
|
||||
if !encapsulation_key_set_bytes(&ek, params, raw_ek) {
|
||||
return false
|
||||
}
|
||||
defer encapsulation_key_clear(&ek)
|
||||
|
||||
return encaps_ek(&ek, shared_secret, ciphertext)
|
||||
}
|
||||
|
||||
// encaps_ek uses the encapsulation key to generate a shared secret and an
|
||||
// associated ciphertext. This routine will fail if the system entropy source
|
||||
// is unavailable.
|
||||
@(require_results)
|
||||
encaps_ek :: proc(ek: ^Encapsulation_Key, shared_secret, ciphertext: []byte) -> bool {
|
||||
ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size")
|
||||
|
||||
if !crypto.HAS_RAND_BYTES {
|
||||
return false
|
||||
}
|
||||
|
||||
m: [_mlkem.SYMBYTES]byte = ---
|
||||
defer crypto.zero_explicit(&m, size_of(m))
|
||||
|
||||
crypto.rand_bytes(m[:])
|
||||
_mlkem.kem_encaps_internal(shared_secret, ciphertext, ek, m[:])
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
encaps :: proc {
|
||||
encaps_ek,
|
||||
encaps_ek_raw_bytes,
|
||||
}
|
||||
|
||||
// decaps uses the decapsulation key to generate a shared secret from a
|
||||
// ciphertext. Due to ML-KEM's implicit rejection mechanism, this function
|
||||
// will only return false if and only if (⟺) the lengths of the inputs
|
||||
// are invalid or the decapsulation key is uninitialized.
|
||||
//
|
||||
// This routine returning true does not guarantee that the shared secret
|
||||
// matches that generated by the peer.
|
||||
@(require_results)
|
||||
decaps :: proc(dk: ^Decapsulation_Key, ciphertext, shared_secret: []byte) -> bool {
|
||||
ensure(len(shared_secret) == SHARED_SECRET_SIZE, "crypto/mlkem: invalid shared_seret size")
|
||||
|
||||
ct_len: int
|
||||
switch dk.pke_dk.k {
|
||||
case _mlkem.K_512:
|
||||
ct_len = CIPHERTEXT_SIZES[.ML_KEM_512]
|
||||
case _mlkem.K_768:
|
||||
ct_len = CIPHERTEXT_SIZES[.ML_KEM_768]
|
||||
case _mlkem.K_1024:
|
||||
ct_len = CIPHERTEXT_SIZES[.ML_KEM_1024]
|
||||
case:
|
||||
return false
|
||||
}
|
||||
if len(ciphertext) != ct_len {
|
||||
return false
|
||||
}
|
||||
|
||||
_mlkem.kem_decaps_internal(shared_secret, dk, ciphertext)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@(private="file")
|
||||
params_to_k :: #force_inline proc "contextless" (params: Parameters) -> int {
|
||||
#partial switch params {
|
||||
case .ML_KEM_512:
|
||||
return _mlkem.K_512
|
||||
case .ML_KEM_768:
|
||||
return _mlkem.K_768
|
||||
case .ML_KEM_1024:
|
||||
return _mlkem.K_1024
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
7
core/crypto/mlkem/doc.odin
Normal file
7
core/crypto/mlkem/doc.odin
Normal file
@@ -0,0 +1,7 @@
|
||||
/*
|
||||
ML-KEM Module-Lattice-Based Key-Encapsulation Mechanism.
|
||||
|
||||
See:
|
||||
- [[ https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.pdf ]]
|
||||
*/
|
||||
package mlkem
|
||||
@@ -58,7 +58,9 @@ generate_keypair :: proc(protocol: ^Protocol, private_key: ^ecdh.Private_Key) {
|
||||
case: panic("crypto/noise: unsupported DH curve in protocol")
|
||||
}
|
||||
|
||||
ecdh.private_key_generate(private_key, protocol.dh)
|
||||
if !ecdh.private_key_generate(private_key, protocol.dh) {
|
||||
panic("crypto/noise: entropy source unavailable")
|
||||
}
|
||||
}
|
||||
|
||||
// Performs a Diffie-Hellman calculation between the private key in key_pair
|
||||
@@ -837,7 +839,9 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte
|
||||
panic("crypto/noise: re was not empty when processing token 'e' during ReadMessage")
|
||||
}
|
||||
|
||||
ecdh.public_key_set_bytes(&self.re, protocol.dh, re)
|
||||
if !ecdh.public_key_set_bytes(&self.re, protocol.dh, re) {
|
||||
return nil, .Invalid_Handshake_Message
|
||||
}
|
||||
symmetricstate_mix_hash(&self.symmetric_state, re)
|
||||
if self.message_pattern.is_psk {
|
||||
symmetricstate_mix_key(&self.symmetric_state, re)
|
||||
@@ -864,7 +868,10 @@ handshakestate_read_message :: proc(self: ^Handshake_State, message, dst: []byte
|
||||
panic("crypto/noise: rs was not empty when processing token 's' during ReadMessage")
|
||||
}
|
||||
|
||||
ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs)
|
||||
if !ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs) {
|
||||
self.status = .Handshake_Failed
|
||||
return nil, .Invalid_Handshake_Message
|
||||
}
|
||||
msg = msg[rs_len:]
|
||||
|
||||
case .ee:
|
||||
|
||||
@@ -43,6 +43,7 @@ package all
|
||||
@(require) import "core:crypto/legacy/keccak"
|
||||
@(require) import "core:crypto/legacy/md5"
|
||||
@(require) import "core:crypto/legacy/sha1"
|
||||
@(require) import "core:crypto/mlkem"
|
||||
@(require) import cnoise "core:crypto/noise"
|
||||
@(require) import "core:crypto/pbkdf2"
|
||||
@(require) import "core:crypto/poly1305"
|
||||
|
||||
@@ -48,6 +48,7 @@ package all
|
||||
@(require) import "core:crypto/legacy/keccak"
|
||||
@(require) import "core:crypto/legacy/md5"
|
||||
@(require) import "core:crypto/legacy/sha1"
|
||||
@(require) import "core:crypto/mlkem"
|
||||
@(require) import cnoise "core:crypto/noise"
|
||||
@(require) import "core:crypto/pbkdf2"
|
||||
@(require) import "core:crypto/poly1305"
|
||||
|
||||
@@ -2,6 +2,7 @@ package benchmark_core_crypto
|
||||
|
||||
import "base:runtime"
|
||||
import "core:encoding/hex"
|
||||
import "core:mem"
|
||||
import "core:log"
|
||||
import "core:testing"
|
||||
import "core:text/table"
|
||||
@@ -161,7 +162,7 @@ bench_ed25519 :: proc() -> (sk, sig, verif: time.Duration) {
|
||||
@(private="file")
|
||||
bench_ecdsa :: proc(curve: ecdsa.Curve, hash: hash.Algorithm) -> (sk, sig, verif: time.Duration) {
|
||||
priv_bytes := make([]byte, ecdsa.PRIVATE_KEY_SIZES[curve], context.temp_allocator)
|
||||
crypto.set(raw_data(priv_bytes), 0x69, len(priv_bytes))
|
||||
mem.set(raw_data(priv_bytes), 0x69, len(priv_bytes))
|
||||
priv_key: ecdsa.Private_Key
|
||||
start := time.tick_now()
|
||||
for _ in 0 ..< DSA_ITERS {
|
||||
|
||||
84
tests/benchmark/crypto/benchmark_pqc.odin
Normal file
84
tests/benchmark/crypto/benchmark_pqc.odin
Normal file
@@ -0,0 +1,84 @@
|
||||
package benchmark_core_crypto
|
||||
|
||||
import "core:log"
|
||||
import "core:testing"
|
||||
import "core:text/table"
|
||||
import "core:time"
|
||||
|
||||
import "core:crypto"
|
||||
import "core:crypto/mlkem"
|
||||
|
||||
@(private = "file")
|
||||
MLKEM_ITERS :: 50000
|
||||
|
||||
@(test)
|
||||
benchmark_crypto_mlkem :: proc(t: ^testing.T) {
|
||||
if !crypto.HAS_RAND_BYTES {
|
||||
log.warnf("ML-KEM benchmarks skipped, no system entropy source")
|
||||
}
|
||||
|
||||
tbl: table.Table
|
||||
table.init(&tbl)
|
||||
defer table.destroy(&tbl)
|
||||
|
||||
table.caption(&tbl, "ML-KEM")
|
||||
table.aligned_header_of_values(&tbl, .Right, "Parameters", "Keygen", "Encaps", "Decaps")
|
||||
|
||||
append_tbl := proc(tbl: ^table.Table, algo_name: string, keygen, encaps, decaps: time.Duration) {
|
||||
table.aligned_row_of_values(
|
||||
tbl,
|
||||
.Right,
|
||||
algo_name,
|
||||
table.format(tbl, "%8M", keygen),
|
||||
table.format(tbl, "%8M", encaps),
|
||||
table.format(tbl, "%8M", decaps),
|
||||
)
|
||||
}
|
||||
|
||||
for params in mlkem.Parameters {
|
||||
if params == .Invalid {
|
||||
continue
|
||||
}
|
||||
param_name := MLKEM_PARAMS_NAMES[params]
|
||||
|
||||
decaps_key: mlkem.Decapsulation_Key
|
||||
start := time.tick_now()
|
||||
for _ in 0 ..< MLKEM_ITERS {
|
||||
_ = mlkem.decapsulation_key_generate(&decaps_key, params)
|
||||
}
|
||||
keygen := time.tick_since(start) / MLKEM_ITERS
|
||||
|
||||
encaps_key := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
|
||||
defer delete(encaps_key)
|
||||
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
|
||||
defer delete(ciphertext)
|
||||
|
||||
mlkem.decapsulation_key_encaps_bytes(&decaps_key, encaps_key)
|
||||
|
||||
bob_shared: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
start = time.tick_now()
|
||||
for _ in 0 ..< MLKEM_ITERS {
|
||||
_ = mlkem.encaps(params, encaps_key, bob_shared[:], ciphertext)
|
||||
}
|
||||
encaps := time.tick_since(start) / MLKEM_ITERS
|
||||
|
||||
alice_shared: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
start = time.tick_now()
|
||||
for _ in 0 ..< MLKEM_ITERS {
|
||||
_ = mlkem.decaps(&decaps_key, ciphertext, alice_shared[:])
|
||||
}
|
||||
decaps := time.tick_since(start) / MLKEM_ITERS
|
||||
|
||||
append_tbl(&tbl, param_name, keygen, encaps, decaps)
|
||||
}
|
||||
|
||||
log_table(&tbl)
|
||||
}
|
||||
|
||||
@(private="file")
|
||||
MLKEM_PARAMS_NAMES := [mlkem.Parameters]string {
|
||||
.Invalid = "invalid",
|
||||
.ML_KEM_512 = "ML-KEM-512",
|
||||
.ML_KEM_768 = "ML-KEM-768",
|
||||
.ML_KEM_1024 = "ML-KEM-1024",
|
||||
}
|
||||
72
tests/core/crypto/test_core_crypto_pqc.odin
Normal file
72
tests/core/crypto/test_core_crypto_pqc.odin
Normal file
@@ -0,0 +1,72 @@
|
||||
package test_core_crypto
|
||||
|
||||
import "core:bytes"
|
||||
import "core:log"
|
||||
import "core:testing"
|
||||
|
||||
import "core:crypto"
|
||||
import "core:crypto/mlkem"
|
||||
|
||||
@(test)
|
||||
test_mlkem :: proc(t: ^testing.T) {
|
||||
if !crypto.HAS_RAND_BYTES {
|
||||
log.info("rand_bytes not supported - skipping")
|
||||
return
|
||||
}
|
||||
|
||||
// Test vectors are huge, and are covered by the wycheproof corpus,
|
||||
// so just test a full key exchange with all supported parameter
|
||||
// sets.
|
||||
for params in mlkem.Parameters {
|
||||
if params == .Invalid {
|
||||
continue
|
||||
}
|
||||
|
||||
// Alice
|
||||
decaps_key: mlkem.Decapsulation_Key
|
||||
if !testing.expectf(
|
||||
t,
|
||||
mlkem.decapsulation_key_generate(&decaps_key, params),
|
||||
"%v: decapsulation_key_generate",
|
||||
params,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
defer mlkem.decapsulation_key_clear(&decaps_key)
|
||||
|
||||
ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
|
||||
defer delete(ek_bytes)
|
||||
mlkem.decapsulation_key_encaps_bytes(&decaps_key, ek_bytes)
|
||||
|
||||
// Bob
|
||||
bob_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
|
||||
defer delete(ciphertext)
|
||||
if !testing.expectf(
|
||||
t,
|
||||
mlkem.encaps(params, ek_bytes, bob_shared_secret[:], ciphertext),
|
||||
"%v: encaps",
|
||||
params,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Alice
|
||||
alice_shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
if !testing.expectf(
|
||||
t,
|
||||
mlkem.decaps(&decaps_key, ciphertext, alice_shared_secret[:]),
|
||||
"%v: decaps",
|
||||
params,
|
||||
) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(
|
||||
t,
|
||||
bytes.equal(alice_shared_secret[:], bob_shared_secret[:]),
|
||||
"%v: shared secret mismatch",
|
||||
params,
|
||||
)
|
||||
}
|
||||
}
|
||||
514
tests/core/crypto/wycheproof/aead.odin
Normal file
514
tests/core/crypto/wycheproof/aead.odin
Normal file
@@ -0,0 +1,514 @@
|
||||
package test_wycheproof
|
||||
|
||||
import "core:encoding/hex"
|
||||
import "core:log"
|
||||
import "core:mem"
|
||||
import "core:os"
|
||||
import "core:slice"
|
||||
import "core:testing"
|
||||
|
||||
import chacha_simd128 "core:crypto/_chacha20/simd128"
|
||||
import chacha_simd256 "core:crypto/_chacha20/simd256"
|
||||
import "core:crypto/aegis"
|
||||
import "core:crypto/aes"
|
||||
import "core:crypto/chacha20"
|
||||
import "core:crypto/chacha20poly1305"
|
||||
|
||||
import "../common"
|
||||
|
||||
supported_aegis_impls :: proc() -> [dynamic]aes.Implementation {
|
||||
impls := make([dynamic]aes.Implementation, 0, 2, context.temp_allocator)
|
||||
append(&impls, aes.Implementation.Portable)
|
||||
if aegis.is_hardware_accelerated() {
|
||||
append(&impls, aes.Implementation.Hardware)
|
||||
}
|
||||
|
||||
return impls
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_aead_aegis :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
files := []string {
|
||||
"aegis128L_test.json",
|
||||
"aegis256_test.json",
|
||||
}
|
||||
|
||||
log.debug("aead/aegis: starting")
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Aead_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
for impl in supported_aegis_impls() {
|
||||
testing.expectf(t, test_aead_aegis_impl(&test_vectors, impl), "impl {} failed", impl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_aead_aegis_impl :: proc(
|
||||
test_vectors: ^Test_Vectors(Aead_Test_Group),
|
||||
impl: aes.Implementation,
|
||||
) -> bool {
|
||||
log.debug("aead/aegis/%v: starting", impl)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"aead/aegis/%v/%d: %s: %+v",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("aead/aegis/%v/%d: %+v",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.flags,
|
||||
)
|
||||
}
|
||||
|
||||
key := common.hexbytes_decode(test_vector.key)
|
||||
iv := common.hexbytes_decode(test_vector.iv)
|
||||
aad := common.hexbytes_decode(test_vector.aad)
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
ct := common.hexbytes_decode(test_vector.ct)
|
||||
tag := common.hexbytes_decode(test_vector.tag)
|
||||
|
||||
if len(iv) == 0 {
|
||||
log.infof(
|
||||
"aead/aegis/%v/%d: skipped, invalid IVs panic",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ctx: aegis.Context
|
||||
aegis.init(&ctx, key, impl)
|
||||
|
||||
if result_is_valid(test_vector.result) {
|
||||
ct_ := make([]byte, len(ct))
|
||||
tag_ := make([]byte, len(tag))
|
||||
aegis.seal(&ctx, ct_, tag_, iv, aad, msg)
|
||||
|
||||
ok := common.hexbytes_compare(test_vector.ct, ct_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(ct_))
|
||||
log.errorf(
|
||||
"aead/aegis/%v/%d: ciphertext: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.ct,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.tag, tag_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(tag_))
|
||||
log.errorf(
|
||||
"aead/aegis/%v/%d: tag: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.tag,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
msg_ := make([]byte, len(msg))
|
||||
ok := aegis.open(&ctx, msg_, iv, aad, ct, tag)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
log.errorf("aead/aegis/%v/%d: decrypt failed", impl, test_vector.tc_id)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
|
||||
x := transmute(string)(hex.encode(msg_))
|
||||
log.errorf(
|
||||
"aead/aegis/%v/%d: decrypt msg: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.msg,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"aead/aegis: ran %d, passed %d, failed %d, skipped %d",
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
supported_aes_impls :: proc() -> [dynamic]aes.Implementation {
|
||||
impls := make([dynamic]aes.Implementation, 0, 2)
|
||||
append(&impls, aes.Implementation.Portable)
|
||||
if aes.is_hardware_accelerated() {
|
||||
append(&impls, aes.Implementation.Hardware)
|
||||
}
|
||||
|
||||
return impls
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_aead_aes_gcm :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, "aes_gcm_test.json"}, context.allocator)
|
||||
|
||||
log.debug("aead/aes-gcm: starting")
|
||||
|
||||
test_vectors: Test_Vectors(Aead_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) {
|
||||
return
|
||||
}
|
||||
|
||||
for impl in supported_aes_impls() {
|
||||
testing.expectf(t, test_aead_aes_gcm_impl(&test_vectors, impl), "impl {} failed", impl)
|
||||
}
|
||||
}
|
||||
|
||||
test_aead_aes_gcm_impl :: proc(
|
||||
test_vectors: ^Test_Vectors(Aead_Test_Group),
|
||||
impl: aes.Implementation,
|
||||
) -> bool {
|
||||
log.debug("aead/aes-gcm/%v: starting", impl)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"aead/aes-gcm/%v/%d: %s: %+v",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("aead/aes-gcm/%v/%d: %+v",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.flags,
|
||||
)
|
||||
}
|
||||
|
||||
key := common.hexbytes_decode(test_vector.key)
|
||||
iv := common.hexbytes_decode(test_vector.iv)
|
||||
aad := common.hexbytes_decode(test_vector.aad)
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
ct := common.hexbytes_decode(test_vector.ct)
|
||||
tag := common.hexbytes_decode(test_vector.tag)
|
||||
|
||||
if len(iv) == 0 {
|
||||
log.infof(
|
||||
"aead/aes-gcm/%v/%d: skipped, invalid IVs panic",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ctx: aes.Context_GCM
|
||||
aes.init_gcm(&ctx, key, impl)
|
||||
|
||||
if result_is_valid(test_vector.result) {
|
||||
ct_ := make([]byte, len(ct))
|
||||
tag_ := make([]byte, len(tag))
|
||||
aes.seal_gcm(&ctx, ct_, tag_, iv, aad, msg)
|
||||
|
||||
ok := common.hexbytes_compare(test_vector.ct, ct_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(ct_))
|
||||
log.errorf(
|
||||
"aead/aes-gcm/%v/%d: ciphertext: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.ct,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.tag, tag_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(tag_))
|
||||
log.errorf(
|
||||
"aead/aes-gcm/%v/%d: tag: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.tag,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
msg_ := make([]byte, len(msg))
|
||||
ok := aes.open_gcm(&ctx, msg_, iv, aad, ct, tag)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
log.errorf("aead/aes-gcm/%v/%d: decrypt failed", impl, test_vector.tc_id)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
|
||||
x := transmute(string)(hex.encode(msg_))
|
||||
log.errorf(
|
||||
"aead/aes-gcm/%v/%d: decrypt msg: expected %s actual %s",
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.msg,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"aead/aes-gcm: ran %d, passed %d, failed %d, skipped %d",
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
supported_chacha_impls :: proc() -> [dynamic]chacha20.Implementation {
|
||||
impls := make([dynamic]chacha20.Implementation, 0, 3)
|
||||
append(&impls, chacha20.Implementation.Portable)
|
||||
if chacha_simd128.is_performant() {
|
||||
append(&impls, chacha20.Implementation.Simd128)
|
||||
}
|
||||
if chacha_simd256.is_performant() {
|
||||
append(&impls, chacha20.Implementation.Simd256)
|
||||
}
|
||||
|
||||
return impls
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_aead_chacha20_poly1305 :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
files := []string {
|
||||
"chacha20_poly1305_test.json",
|
||||
"xchacha20_poly1305_test.json",
|
||||
}
|
||||
|
||||
log.debug("aead/(x)chacha20poly1305: starting")
|
||||
|
||||
for f, i in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Aead_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
for impl in supported_chacha_impls() {
|
||||
testing.expectf(t, test_aead_chacha20_poly1305_impl(&test_vectors, i == 1, impl), "impl {} failed", impl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
test_aead_chacha20_poly1305_impl :: proc(
|
||||
test_vectors: ^Test_Vectors(Aead_Test_Group),
|
||||
is_xchacha: bool,
|
||||
impl: chacha20.Implementation,
|
||||
) -> bool {
|
||||
FLAG_INVALID_NONCE_SIZE :: "InvalidNonceSize"
|
||||
|
||||
alg_str := is_xchacha ? "xchacha20poly1305" : "chacha20poly1305"
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"aead/%s/%v/%d: %s: %+v",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("aead/%s/%v/%d: %+v",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.flags,
|
||||
)
|
||||
}
|
||||
|
||||
key := common.hexbytes_decode(test_vector.key)
|
||||
iv := common.hexbytes_decode(test_vector.iv)
|
||||
aad := common.hexbytes_decode(test_vector.aad)
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
ct := common.hexbytes_decode(test_vector.ct)
|
||||
tag := common.hexbytes_decode(test_vector.tag)
|
||||
|
||||
if slice.contains(test_vector.flags, FLAG_INVALID_NONCE_SIZE) {
|
||||
log.infof(
|
||||
"aead/%s/%v/%d: skipped, invalid nonces panic",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ctx: chacha20poly1305.Context
|
||||
switch is_xchacha {
|
||||
case true:
|
||||
chacha20poly1305.init_xchacha(&ctx, key, impl)
|
||||
case false:
|
||||
chacha20poly1305.init(&ctx, key, impl)
|
||||
}
|
||||
|
||||
if result_is_valid(test_vector.result) {
|
||||
ct_ := make([]byte, len(ct))
|
||||
tag_ := make([]byte, len(tag))
|
||||
chacha20poly1305.seal(&ctx, ct_, tag_, iv, aad, msg)
|
||||
|
||||
ok := common.hexbytes_compare(test_vector.ct, ct_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(ct_))
|
||||
log.errorf(
|
||||
"aead/%s/%v/%d: ciphertext: expected %s actual %s",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.ct,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.tag, tag_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(tag_))
|
||||
log.errorf(
|
||||
"aead/%s/%v/%d: tag: expected %s actual %s",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.tag,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
msg_ := make([]byte, len(msg))
|
||||
ok := chacha20poly1305.open(&ctx, msg_, iv, aad, ct, tag)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
log.errorf("aead/%s/%v/%d: decrypt failed",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
if ok && !common.hexbytes_compare(test_vector.msg, msg_) {
|
||||
x := transmute(string)(hex.encode(msg_))
|
||||
log.errorf(
|
||||
"aead/%s/%v/%d: decrypt msg: expected %s actual %s",
|
||||
alg_str,
|
||||
impl,
|
||||
test_vector.tc_id,
|
||||
test_vector.msg,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"aead/%s/%v: ran %d, passed %d, failed %d, skipped %d",
|
||||
alg_str,
|
||||
impl,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
427
tests/core/crypto/wycheproof/ecc.odin
Normal file
427
tests/core/crypto/wycheproof/ecc.odin
Normal file
@@ -0,0 +1,427 @@
|
||||
package test_wycheproof
|
||||
|
||||
import "core:encoding/hex"
|
||||
import "core:log"
|
||||
import "core:mem"
|
||||
import "core:os"
|
||||
import "core:slice"
|
||||
import "core:strings"
|
||||
import "core:testing"
|
||||
|
||||
import "core:crypto/hash"
|
||||
import "core:crypto/ecdh"
|
||||
import "core:crypto/ecdsa"
|
||||
import "core:crypto/ed25519"
|
||||
|
||||
import "../common"
|
||||
|
||||
@(test)
|
||||
test_eddsa_ed25519 :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, "ed25519_test.json"}, context.allocator)
|
||||
|
||||
log.debug("eddsa/ed25519: starting")
|
||||
|
||||
test_vectors: Test_Vectors(Eddsa_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", fn) {
|
||||
return
|
||||
}
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group, i in test_vectors.test_groups {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
pk_bytes := common.hexbytes_decode(test_group.public_key.pk)
|
||||
|
||||
pk: ed25519.Public_Key
|
||||
pk_ok := ed25519.public_key_set_bytes(&pk, pk_bytes)
|
||||
if !testing.expectf(t, pk_ok, "eddsa/ed25519/%d: invalid public key: %s", i, test_group.public_key.pk) {
|
||||
num_failed += len(test_group.tests)
|
||||
continue
|
||||
}
|
||||
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"eddsa/ed25519/%d: %s: %+v",
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("eddsa/ed25519/%d: %+v", test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
sig := common.hexbytes_decode(test_vector.sig)
|
||||
|
||||
verify_ok := ed25519.verify(&pk, msg, sig)
|
||||
if !testing.expectf(
|
||||
t,
|
||||
result_check(test_vector.result, verify_ok),
|
||||
"eddsa/ed25519/%d: verify failed: expected %s actual %v",
|
||||
test_vector.tc_id,
|
||||
test_vector.result,
|
||||
verify_ok,
|
||||
) {
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"eddsa/ed25519: ran %d, passed %d, failed %d, skipped %d",
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_ecdsa :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
log.debug("ecdsa: starting")
|
||||
|
||||
files := []string {
|
||||
"ecdsa_secp256r1_sha256_test.json",
|
||||
"ecdsa_secp256r1_sha512_test.json",
|
||||
"ecdsa_secp384r1_sha384_test.json",
|
||||
}
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Ecdsa_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_ecdsa_impl(t, &test_vectors), "ecdsa failed")
|
||||
}
|
||||
}
|
||||
|
||||
test_ecdsa_impl :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Ecdsa_Test_Group)) -> bool {
|
||||
curve_str := test_vectors.test_groups[0].public_key.curve
|
||||
hash_str := test_vectors.test_groups[0].sha
|
||||
|
||||
curve_alg: ecdsa.Curve
|
||||
switch curve_str {
|
||||
case "secp256r1":
|
||||
curve_alg = .SECP256R1
|
||||
case "secp384r1":
|
||||
curve_alg = .SECP384R1
|
||||
case:
|
||||
log.errorf("ecdsa: unsupported curve: %s", curve_str)
|
||||
}
|
||||
|
||||
hash_alg: hash.Algorithm
|
||||
switch hash_str {
|
||||
case "SHA-256":
|
||||
hash_alg = .SHA256
|
||||
case "SHA-384":
|
||||
hash_alg = .SHA384
|
||||
case "SHA-512":
|
||||
hash_alg = .SHA512
|
||||
case:
|
||||
log.errorf("ecdsa: unsupported hash: %s", hash_str)
|
||||
}
|
||||
|
||||
log.debugf("ecdsa/%s/%s: starting", curve_str, hash_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group, i in test_vectors.test_groups {
|
||||
pk_bytes := common.hexbytes_decode(test_group.public_key.uncompressed)
|
||||
|
||||
pk: ecdsa.Public_Key
|
||||
pk_ok := ecdsa.public_key_set_bytes(&pk, curve_alg, pk_bytes)
|
||||
if !testing.expectf(t, pk_ok, "ecdsa/%s/%s/%d: invalid public key: %s", curve_str, hash_str, i, test_group.public_key.uncompressed) {
|
||||
num_failed += len(test_group.tests)
|
||||
continue
|
||||
}
|
||||
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"ecda/%s/%s/%d: %s: %+v",
|
||||
curve_str,
|
||||
hash_str,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("ecdsa/%s/%s/%d: %+v", curve_str, hash_str, test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
sig := common.hexbytes_decode(test_vector.sig)
|
||||
|
||||
verify_ok := ecdsa.verify_asn1(&pk, hash_alg, msg, sig)
|
||||
if !testing.expectf(
|
||||
t,
|
||||
result_check(test_vector.result, verify_ok),
|
||||
"ecdsa/%s/%s/%d: verify failed: expected %s actual %v",
|
||||
curve_str,
|
||||
hash_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.result,
|
||||
verify_ok,
|
||||
) {
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"ecdsa/%s/%s: ran %d, passed %d, failed %d, skipped %d",
|
||||
curve_str,
|
||||
hash_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_ecdh :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
PREFIX_TEST_ECDH :: "ecdh_"
|
||||
SUFFIX_TEST_ECPOINT :: "_ecpoint"
|
||||
|
||||
files := []string {
|
||||
"ecdh_secp256r1_ecpoint_test.json",
|
||||
"ecdh_secp384r1_ecpoint_test.json",
|
||||
"x25519_test.json",
|
||||
"x448_test.json",
|
||||
}
|
||||
|
||||
log.debug("ecdh: starting")
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Ecdh_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
alg_str := strings.trim_suffix(f, SUFFIX_TEST_JSON)
|
||||
alg_str = strings.trim_suffix(alg_str, SUFFIX_TEST_ECPOINT)
|
||||
alg_str = strings.trim_prefix(alg_str, PREFIX_TEST_ECDH)
|
||||
testing.expectf(t, test_ecdh_impl(&test_vectors, alg_str), "alg {} failed", alg_str)
|
||||
}
|
||||
}
|
||||
|
||||
test_ecdh_impl :: proc(
|
||||
test_vectors: ^Test_Vectors(Ecdh_Test_Group),
|
||||
alg_str: string,
|
||||
) -> bool {
|
||||
ALG_P256 :: "secp256r1"
|
||||
ALG_P384 :: "secp384r1"
|
||||
ALG_X25519 :: "x25519"
|
||||
ALG_X448 :: "x448"
|
||||
|
||||
// XDH exceptions
|
||||
FLAG_PUBLIC_KEY_TOO_LONG :: "PublicKeyTooLong"
|
||||
FLAG_ZERO_SHARED_SECRET :: "ZeroSharedSecret"
|
||||
|
||||
// ECDH exceptions
|
||||
FLAG_COMPRESSED_POINT :: "CompressedPoint"
|
||||
FLAG_INVALID_CURVE :: "InvalidCurveAttack"
|
||||
FLAG_INVALID_ENCODING :: "InvalidEncoding"
|
||||
|
||||
log.debugf("ecdh/%s: starting", alg_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf("ecdh/%s/%d: %s: %+v", alg_str, test_vector.tc_id, comment, test_vector.flags)
|
||||
} else {
|
||||
log.debugf("ecdh/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
raw_pub := common.hexbytes_decode(test_vector.public)
|
||||
raw_priv := common.hexbytes_decode(test_vector.private)
|
||||
|
||||
curve: ecdh.Curve
|
||||
priv_key: ecdh.Private_Key
|
||||
pub_key: ecdh.Public_Key
|
||||
|
||||
is_nist, is_xdh: bool
|
||||
switch alg_str {
|
||||
case ALG_P256:
|
||||
curve = .SECP256R1
|
||||
// Ugh, ASN.1 :(
|
||||
l := len(raw_priv)
|
||||
if l == 33 {
|
||||
if raw_priv[0] == 0 {
|
||||
raw_priv = raw_priv[1:]
|
||||
}
|
||||
} else if l < 32 {
|
||||
// left-pad.odin
|
||||
tmp := make([]byte, 32)
|
||||
copy(tmp[32-l:], raw_priv)
|
||||
raw_priv = tmp
|
||||
}
|
||||
is_nist = true
|
||||
case ALG_P384:
|
||||
curve = .SECP384R1
|
||||
// Ugh, ASN.1 :(
|
||||
l := len(raw_priv)
|
||||
if l == 49 {
|
||||
if raw_priv[0] == 0 {
|
||||
raw_priv = raw_priv[1:]
|
||||
}
|
||||
} else if l < 48 {
|
||||
// left-pad.odin
|
||||
tmp := make([]byte, 48)
|
||||
copy(tmp[48-l:], raw_priv)
|
||||
raw_priv = tmp
|
||||
}
|
||||
is_nist = true
|
||||
case ALG_X25519:
|
||||
curve = .X25519
|
||||
is_xdh = true
|
||||
case ALG_X448:
|
||||
curve = .X448
|
||||
is_xdh = true
|
||||
case:
|
||||
log.errorf("ecdh: unsupported algorithm: %s", alg_str)
|
||||
return false
|
||||
}
|
||||
|
||||
if ok := ecdh.private_key_set_bytes(&priv_key, curve, raw_priv); !ok {
|
||||
log.errorf(
|
||||
"ecdh/%s/%d: failed to deserialize private_key: %s %d %x",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.private,
|
||||
len(raw_priv),
|
||||
raw_priv,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
if ok := ecdh.public_key_set_bytes(&pub_key, curve, raw_pub); !ok {
|
||||
if is_nist {
|
||||
if slice.contains(test_vector.flags, FLAG_COMPRESSED_POINT) {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
if slice.contains(test_vector.flags, FLAG_INVALID_CURVE) {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
if slice.contains(test_vector.flags, FLAG_INVALID_ENCODING) {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
if slice.contains(test_vector.flags, FLAG_PUBLIC_KEY_TOO_LONG) {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
log.errorf(
|
||||
"ecdh/%s/%d: failed to deserialize public_key: %s",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.public,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
shared := make([]byte, ecdh.SHARED_SECRET_SIZES[curve])
|
||||
|
||||
ok := ecdh.ecdh(&priv_key, &pub_key, shared)
|
||||
if !ok {
|
||||
if is_xdh && slice.contains(test_vector.flags, FLAG_ZERO_SHARED_SECRET) {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
// unused: x := transmute(string)(hex.encode(shared))
|
||||
log.errorf(
|
||||
"ecdh/%s/%d: ecdh failed",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.shared, shared)
|
||||
// "acceptable" results are fine from here because we have
|
||||
// checked for the all-zero shared secret XDH case already.
|
||||
if !result_check(test_vector.result, ok, false) {
|
||||
x := transmute(string)(hex.encode(shared))
|
||||
log.errorf(
|
||||
"ecdh/%s/%d: shared: expected %s actual %s",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.shared,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"ecdh/%s: ran %d, passed %d, failed %d, skipped %d",
|
||||
alg_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
238
tests/core/crypto/wycheproof/kdf.odin
Normal file
238
tests/core/crypto/wycheproof/kdf.odin
Normal file
@@ -0,0 +1,238 @@
|
||||
package test_wycheproof
|
||||
|
||||
import "core:encoding/hex"
|
||||
import "core:log"
|
||||
import "core:mem"
|
||||
import "core:os"
|
||||
import "core:slice"
|
||||
import "core:strings"
|
||||
import "core:testing"
|
||||
|
||||
import "core:crypto/hkdf"
|
||||
import "core:crypto/pbkdf2"
|
||||
|
||||
import "../common"
|
||||
|
||||
@(test)
|
||||
test_hkdf :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
log.debug("hkdf: starting")
|
||||
|
||||
files := []string {
|
||||
"hkdf_sha1_test.json",
|
||||
"hkdf_sha256_test.json",
|
||||
"hkdf_sha384_test.json",
|
||||
"hkdf_sha512_test.json",
|
||||
}
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Hkdf_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_hkdf_impl(&test_vectors), "hkdf failed")
|
||||
}
|
||||
}
|
||||
|
||||
test_hkdf_impl :: proc(test_vectors: ^Test_Vectors(Hkdf_Test_Group)) -> bool {
|
||||
PREFIX_HKDF :: "HKDF-"
|
||||
FLAG_SIZE_TOO_LARGE :: "SizeTooLarge"
|
||||
|
||||
alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_HKDF)
|
||||
alg, ok := hash_name_to_algorithm(alg_str)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
alg_str = strings.to_lower(alg_str)
|
||||
|
||||
log.debugf("hkdf/%s: starting", alg_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"hkdf/%s/%d: %s: %+v",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("hkdf/%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
ikm := common.hexbytes_decode(test_vector.ikm)
|
||||
salt := common.hexbytes_decode(test_vector.salt)
|
||||
info := common.hexbytes_decode(test_vector.info)
|
||||
|
||||
if slice.contains(test_vector.flags, FLAG_SIZE_TOO_LARGE) {
|
||||
log.infof(
|
||||
"hkdf/%s/%d: skipped, oversized outputs panic",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
}
|
||||
|
||||
okm_ := make([]byte, test_vector.size)
|
||||
hkdf.extract_and_expand(alg, salt, ikm, info, okm_)
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.okm, okm_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(okm_))
|
||||
log.errorf(
|
||||
"hkdf/%s/%d: shared: expected %s actual %s",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.okm,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"hkdf/%s: ran %d, passed %d, failed %d, skipped %d",
|
||||
alg_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
@(test)
|
||||
test_pbkdf2 :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
log.debug("pbkdf2: starting")
|
||||
|
||||
files := []string {
|
||||
"pbkdf2_hmacsha1_test.json",
|
||||
"pbkdf2_hmacsha224_test.json",
|
||||
"pbkdf2_hmacsha256_test.json",
|
||||
"pbkdf2_hmacsha384_test.json",
|
||||
"pbkdf2_hmacsha512_test.json",
|
||||
}
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Pbkdf_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_pbkdf2_impl(&test_vectors), "pbkdf2 failed")
|
||||
}
|
||||
}
|
||||
|
||||
test_pbkdf2_impl :: proc(
|
||||
test_vectors: ^Test_Vectors(Pbkdf_Test_Group),
|
||||
) -> bool {
|
||||
PREFIX_PBKDF_HMAC :: "PBKDF2-HMAC"
|
||||
FLAG_LARGE_ITERATION_COUNT :: "LargeIterationCount"
|
||||
|
||||
alg_str := strings.trim_prefix(test_vectors.algorithm, PREFIX_PBKDF_HMAC)
|
||||
alg, ok := hash_name_to_algorithm(alg_str)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
alg_str = strings.to_lower(alg_str)
|
||||
|
||||
log.debugf("pbkdf2/hmac-%s: starting", alg_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"pbkdf2/hmac-%s/%d: %s: %+v",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("pbkdf2/hmac-%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
if slice.contains(test_vector.flags, FLAG_LARGE_ITERATION_COUNT) {
|
||||
log.infof(
|
||||
"pbkdf2/hmac-%s/%d: skipped, takes fucking forever",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_skipped += 1
|
||||
continue
|
||||
}
|
||||
|
||||
password := common.hexbytes_decode(test_vector.password)
|
||||
salt := common.hexbytes_decode(test_vector.salt)
|
||||
|
||||
dk_ := make([]byte, test_vector.dk_len)
|
||||
pbkdf2.derive(alg, password, salt, test_vector.iteration_count, dk_)
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.dk, dk_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(dk_))
|
||||
log.errorf(
|
||||
"pbkdf2/hmac-%s/%d: shared: expected %s actual %s",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.dk,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"pbkdf2/%s: ran %d, passed %d, failed %d, skipped %d",
|
||||
alg_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
156
tests/core/crypto/wycheproof/mac.odin
Normal file
156
tests/core/crypto/wycheproof/mac.odin
Normal file
@@ -0,0 +1,156 @@
|
||||
package test_wycheproof
|
||||
|
||||
import "core:encoding/hex"
|
||||
|
||||
import "core:log"
|
||||
import "core:mem"
|
||||
import "core:os"
|
||||
import "core:testing"
|
||||
|
||||
import "core:crypto/hmac"
|
||||
import "core:crypto/kmac"
|
||||
import "core:crypto/siphash"
|
||||
|
||||
import "../common"
|
||||
|
||||
@(test)
|
||||
test_mac :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
log.debug("mac: starting")
|
||||
|
||||
files := []string {
|
||||
"hmac_sha1_test.json",
|
||||
"hmac_sha224_test.json",
|
||||
"hmac_sha256_test.json",
|
||||
"hmac_sha3_224_test.json",
|
||||
"hmac_sha3_256_test.json",
|
||||
"hmac_sha3_384_test.json",
|
||||
"hmac_sha3_512_test.json",
|
||||
"hmac_sha384_test.json",
|
||||
// "hmac_sha512_224_test.json",
|
||||
"hmac_sha512_256_test.json",
|
||||
"hmac_sha512_test.json",
|
||||
"hmac_sm3_test.json",
|
||||
"kmac128_no_customization_test.json",
|
||||
"kmac256_no_customization_test.json",
|
||||
"siphash_1_3_test.json",
|
||||
"siphash_2_4_test.json",
|
||||
"siphash_4_8_test.json",
|
||||
}
|
||||
|
||||
for f in files {
|
||||
mem.free_all() // Probably don't need this, but be safe.
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Mac_Test_Group)
|
||||
if !testing.expectf(t, load(&test_vectors, fn), "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_mac_impl(&test_vectors), "hkdf failed")
|
||||
}
|
||||
}
|
||||
|
||||
test_mac_impl :: proc(test_vectors: ^Test_Vectors(Mac_Test_Group)) -> bool {
|
||||
PREFIX_HMAC :: "HMAC"
|
||||
PREFIX_KMAC :: "KMAC"
|
||||
|
||||
mac_alg, hmac_alg, alg_str, ok := mac_algorithm(test_vectors.algorithm)
|
||||
if !ok {
|
||||
log.errorf("mac: unsupported algorith: %s", test_vectors.algorithm)
|
||||
return false
|
||||
}
|
||||
|
||||
log.debugf("%s: starting", alg_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
if comment := test_vector.comment; comment != "" {
|
||||
log.debugf(
|
||||
"%s/%d: %s: %+v",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
comment,
|
||||
test_vector.flags,
|
||||
)
|
||||
} else {
|
||||
log.debugf("%s/%d: %+v", alg_str, test_vector.tc_id, test_vector.flags)
|
||||
}
|
||||
|
||||
key := common.hexbytes_decode(test_vector.key)
|
||||
msg := common.hexbytes_decode(test_vector.msg)
|
||||
|
||||
tag_ := make([]byte, len(test_vector.tag) / 2)
|
||||
|
||||
#partial switch mac_alg {
|
||||
case .HMAC:
|
||||
ctx: hmac.Context
|
||||
hmac.init(&ctx, hmac_alg, key)
|
||||
hmac.update(&ctx, msg)
|
||||
if l := hmac.tag_size(&ctx); l == len(tag_) {
|
||||
hmac.final(&ctx, tag_)
|
||||
} else {
|
||||
// Our hmac package does not support truncation.
|
||||
tmp := make([]byte, l)
|
||||
hmac.final(&ctx, tmp)
|
||||
copy(tag_, tmp)
|
||||
}
|
||||
case .KMAC128, .KMAC256:
|
||||
ctx: kmac.Context
|
||||
#partial switch mac_alg {
|
||||
case .KMAC128:
|
||||
kmac.init_128(&ctx, key, nil)
|
||||
case .KMAC256:
|
||||
kmac.init_256(&ctx, key, nil)
|
||||
}
|
||||
kmac.update(&ctx, msg)
|
||||
kmac.final(&ctx, tag_)
|
||||
case .SIPHASH_1_3:
|
||||
siphash.sum_1_3(msg, key, tag_)
|
||||
case .SIPHASH_2_4:
|
||||
siphash.sum_2_4(msg, key, tag_)
|
||||
case .SIPHASH_4_8:
|
||||
siphash.sum_4_8(msg, key, tag_)
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.tag, tag_)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(tag_))
|
||||
log.errorf(
|
||||
"%s/%d: tag: expected %s actual %s",
|
||||
alg_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.tag,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"%s: ran %d, passed %d, failed %d, skipped %d",
|
||||
alg_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
401
tests/core/crypto/wycheproof/pqc.odin
Normal file
401
tests/core/crypto/wycheproof/pqc.odin
Normal file
@@ -0,0 +1,401 @@
|
||||
package test_wycheproof
|
||||
|
||||
import "core:encoding/hex"
|
||||
import "core:log"
|
||||
import "core:mem"
|
||||
import "core:os"
|
||||
import "core:testing"
|
||||
|
||||
import "core:crypto/_mlkem"
|
||||
import "core:crypto/mlkem"
|
||||
|
||||
import "../common"
|
||||
|
||||
@(test)
|
||||
test_mlkem :: proc(t: ^testing.T) {
|
||||
arena: mem.Arena
|
||||
arena_backing := make([]byte, ARENA_SIZE)
|
||||
defer delete(arena_backing)
|
||||
mem.arena_init(&arena, arena_backing)
|
||||
context.allocator = mem.arena_allocator(&arena)
|
||||
|
||||
log.debug("mlkem: starting")
|
||||
|
||||
files_keygen := []string {
|
||||
"mlkem_512_keygen_seed_test.json",
|
||||
"mlkem_768_keygen_seed_test.json",
|
||||
"mlkem_1024_keygen_seed_test.json",
|
||||
}
|
||||
for f in files_keygen {
|
||||
mem.free_all()
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Kem_Test_Group)
|
||||
load_ok := load(&test_vectors, fn)
|
||||
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_mlkem_keygen(t, &test_vectors), "ML-KEM KeyGen failed")
|
||||
}
|
||||
|
||||
files_encaps := []string {
|
||||
"mlkem_512_encaps_test.json",
|
||||
"mlkem_768_encaps_test.json",
|
||||
"mlkem_1024_encaps_test.json",
|
||||
}
|
||||
for f in files_encaps {
|
||||
mem.free_all()
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Kem_Test_Group)
|
||||
load_ok := load(&test_vectors, fn)
|
||||
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_mlkem_encaps(t, &test_vectors), "ML-KEM Encaps failed")
|
||||
}
|
||||
|
||||
files_decaps := []string {
|
||||
"mlkem_512_test.json",
|
||||
"mlkem_768_test.json",
|
||||
"mlkem_1024_test.json",
|
||||
}
|
||||
for f in files_decaps {
|
||||
mem.free_all()
|
||||
|
||||
fn, _ := os.join_path([]string{BASE_PATH, f}, context.allocator)
|
||||
|
||||
test_vectors: Test_Vectors(Kem_Test_Group)
|
||||
load_ok := load(&test_vectors, fn)
|
||||
if !testing.expectf(t, load_ok, "Unable to load {}", f) {
|
||||
continue
|
||||
}
|
||||
|
||||
testing.expectf(t, test_mlkem_decaps(t, &test_vectors), "ML-KEM Decaps failed")
|
||||
}
|
||||
}
|
||||
|
||||
test_mlkem_keygen :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
|
||||
params_str := test_vectors.test_groups[0].parameter_set
|
||||
params := parameter_set_to_params(params_str)
|
||||
if params == .Invalid {
|
||||
return false
|
||||
}
|
||||
|
||||
log.debugf("%s: KeyGen starting", params_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group, tg_id in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
seed := common.hexbytes_decode(test_vector.seed)
|
||||
|
||||
dk: mlkem.Decapsulation_Key
|
||||
if !testing.expectf(
|
||||
t,
|
||||
mlkem.decapsulation_key_set_bytes(&dk, params, seed),
|
||||
"%s/KeyGen/%d/%d: failed to set decapsulation key from seed",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.seed,
|
||||
) {
|
||||
num_failed *= 1
|
||||
continue
|
||||
}
|
||||
|
||||
ek_bytes := make([]byte, mlkem.ENCAPSULATION_KEY_SIZES[params])
|
||||
mlkem.decapsulation_key_encaps_bytes(&dk, ek_bytes)
|
||||
|
||||
ok := common.hexbytes_compare(test_vector.ek, ek_bytes)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(ek_bytes))
|
||||
log.errorf(
|
||||
"%s/KeyGen/%d/%d: ek: expected %s actual %s",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.ek,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
dk_bytes := make([]byte, mlkem.DECAPSULATION_KEY_EXPANDED_SIZES[params])
|
||||
mlkem.decapsulation_key_expanded_bytes(&dk, dk_bytes)
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.dk, dk_bytes)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(dk_bytes))
|
||||
log.errorf(
|
||||
"%s/KeyGen/%d/%d: dk: expected %s actual %s",
|
||||
tg_id,
|
||||
params_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.dk,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
seed_bytes: [mlkem.DECAPSULATION_KEY_SEED_SIZE]byte
|
||||
mlkem.decapsulation_key_bytes(&dk, seed_bytes[:])
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.seed, seed_bytes[:])
|
||||
if !result_check(test_vector.result, ok) {
|
||||
x := transmute(string)(hex.encode(seed_bytes[:]))
|
||||
log.errorf(
|
||||
"%s/KeyGen/%d/%d: seed: expected %s actual %s",
|
||||
tg_id,
|
||||
params_str,
|
||||
test_vector.tc_id,
|
||||
test_vector.seed,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"%s/KeyGen: ran %d, passed %d, failed %d, skipped %d",
|
||||
params_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
test_mlkem_encaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
|
||||
params_str := test_vectors.test_groups[0].parameter_set
|
||||
params := parameter_set_to_params(params_str)
|
||||
if params == .Invalid {
|
||||
return false
|
||||
}
|
||||
|
||||
log.debugf("%s: Encaps starting", params_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group, tg_id in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
ek: mlkem.Encapsulation_Key
|
||||
ok := mlkem.encapsulation_key_set_bytes(
|
||||
&ek,
|
||||
params,
|
||||
common.hexbytes_decode(test_vector.ek),
|
||||
)
|
||||
|
||||
// The current corpus can only fail if the encapsulation key
|
||||
// is malformed in some way.
|
||||
if !result_check(test_vector.result, ok) {
|
||||
log.errorf(
|
||||
"%s/Encaps/%d/%d: unexpected set encapsulation key from bytes: %s (%v != %v)",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.ek,
|
||||
test_vector.result,
|
||||
ok,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
ciphertext := make([]byte, mlkem.CIPHERTEXT_SIZES[params])
|
||||
|
||||
_mlkem.kem_encaps_internal(
|
||||
shared_secret[:],
|
||||
ciphertext,
|
||||
&ek,
|
||||
common.hexbytes_decode(test_vector.m),
|
||||
)
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.c, ciphertext)
|
||||
if !ok {
|
||||
x := transmute(string)(hex.encode(ciphertext))
|
||||
log.errorf(
|
||||
"%s/Encaps/%d/%d: ciphertext: expected: %s actual: %s",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.c,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.k, shared_secret[:])
|
||||
if !ok {
|
||||
x := transmute(string)(hex.encode(shared_secret[:]))
|
||||
log.errorf(
|
||||
"%s/Encaps/%d/%d: shared_secret: expected: %s actual: %s",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.k,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"%s/Encaps: ran %d, passed %d, failed %d, skipped %d",
|
||||
params_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
test_mlkem_decaps :: proc(t: ^testing.T, test_vectors: ^Test_Vectors(Kem_Test_Group)) -> bool {
|
||||
params_str := test_vectors.test_groups[0].parameter_set
|
||||
params := parameter_set_to_params(params_str)
|
||||
if params == .Invalid {
|
||||
return false
|
||||
}
|
||||
|
||||
log.debugf("%s: Decaps starting", params_str)
|
||||
|
||||
num_ran, num_passed, num_failed, num_skipped: int
|
||||
for &test_group, tg_id in test_vectors.test_groups {
|
||||
for &test_vector in test_group.tests {
|
||||
num_ran += 1
|
||||
|
||||
// We do not have an API for decaps with raw seed.
|
||||
seed := common.hexbytes_decode(test_vector.seed)
|
||||
switch len(seed) {
|
||||
case mlkem.DECAPSULATION_KEY_SEED_SIZE:
|
||||
case:
|
||||
if testing.expectf(
|
||||
t,
|
||||
result_is_invalid(test_vector.result),
|
||||
"%s/Decaps/%d/%d: test vector expects success with invalid seed",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
) {
|
||||
num_passed += 1
|
||||
} else {
|
||||
num_failed += 1
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
dk: mlkem.Decapsulation_Key
|
||||
if !testing.expectf(
|
||||
t,
|
||||
mlkem.decapsulation_key_set_bytes(&dk, params, seed),
|
||||
"%s/Decaps/%d/%d: failed to set decapsulation key from seed",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.seed,
|
||||
) {
|
||||
num_failed *= 1
|
||||
continue
|
||||
}
|
||||
|
||||
shared_secret: [mlkem.SHARED_SECRET_SIZE]byte
|
||||
|
||||
ok := mlkem.decaps(
|
||||
&dk,
|
||||
common.hexbytes_decode(test_vector.c),
|
||||
shared_secret[:],
|
||||
)
|
||||
if !result_check(test_vector.result, ok) {
|
||||
log.errorf(
|
||||
"%s/Decaps/%d/%d: unexpected decapsulation failure",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
if !ok {
|
||||
num_passed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
ok = common.hexbytes_compare(test_vector.k, shared_secret[:])
|
||||
if !ok {
|
||||
x := transmute(string)(hex.encode(shared_secret[:]))
|
||||
log.errorf(
|
||||
"%s/Decaps/%d/%d: shared_secret: expected: %s actual: %s",
|
||||
params_str,
|
||||
tg_id,
|
||||
test_vector.tc_id,
|
||||
test_vector.k,
|
||||
x,
|
||||
)
|
||||
num_failed += 1
|
||||
continue
|
||||
}
|
||||
|
||||
num_passed += 1
|
||||
}
|
||||
}
|
||||
|
||||
assert(num_ran == test_vectors.number_of_tests)
|
||||
assert(num_passed + num_failed + num_skipped == num_ran)
|
||||
|
||||
log.infof(
|
||||
"%s/Decaps: ran %d, passed %d, failed %d, skipped %d",
|
||||
params_str,
|
||||
num_ran,
|
||||
num_passed,
|
||||
num_failed,
|
||||
num_skipped,
|
||||
)
|
||||
|
||||
return num_failed == 0
|
||||
}
|
||||
|
||||
@(require_results, private="file")
|
||||
parameter_set_to_params :: proc(s: string) -> mlkem.Parameters {
|
||||
switch s {
|
||||
case "ML-KEM-512":
|
||||
return .ML_KEM_512
|
||||
case "ML-KEM-768":
|
||||
return .ML_KEM_768
|
||||
case "ML-KEM-1024":
|
||||
return .ML_KEM_1024
|
||||
case:
|
||||
return .Invalid
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,11 @@ Test_Vectors_Note :: struct {
|
||||
links: []string `json:"links"`,
|
||||
}
|
||||
|
||||
Test_Group_Source :: struct {
|
||||
name: string `json:"name"`,
|
||||
version: string `json:"version"`,
|
||||
}
|
||||
|
||||
Aead_Test_Group :: struct {
|
||||
iv_size: int `json:"ivSize"`,
|
||||
key_size: int `json:"keySize"`,
|
||||
@@ -198,3 +203,23 @@ Pbkdf_Test_Vector :: struct {
|
||||
result: Result `json:"result"`,
|
||||
flags: []string `json:"flags"`,
|
||||
}
|
||||
|
||||
Kem_Test_Group :: struct {
|
||||
type: string `json:"type"`,
|
||||
source: Test_Group_Source `json:"source"`,
|
||||
parameter_set: string `json:"parameterSet"`,
|
||||
tests: []Kem_Test_Vector `json:"tests"`,
|
||||
}
|
||||
|
||||
Kem_Test_Vector :: struct {
|
||||
tc_id: int `json:"tcId"`,
|
||||
flags: []string `json:"flags"`,
|
||||
comment: string `json:"comment"`,
|
||||
seed: common.Hex_Bytes `json:"seed"`,
|
||||
m: common.Hex_Bytes `json:"m"`,
|
||||
ek: common.Hex_Bytes `json:"ek"`,
|
||||
dk: common.Hex_Bytes `json:"dk"`,
|
||||
c: common.Hex_Bytes `json:"c"`,
|
||||
k: common.Hex_Bytes `json:"K"`,
|
||||
result: Result `json:"result"`,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user