core/crypto/sha2: Refactor update/final

This is largely modeled off the SM3 versions of these routines, since
the relevant parts of the code are the same between SHA-256 and SM3,
and the alterations required to support SHA-512 are relatively simple.

The prior versions of update and the transform would leak memory, and
doing things this way also reduces the context buffer sizes by 1 block.
This commit is contained in:
Yawning Angel
2023-11-17 01:13:27 +09:00
parent bc139ba6c6
commit b71afdc3ee

View File

@@ -14,7 +14,6 @@ package sha2
import "core:encoding/endian"
import "core:io"
import "core:math/bits"
import "core:mem"
import "core:os"
/*
@@ -482,8 +481,8 @@ init :: proc(ctx: ^$T) {
}
}
ctx.tot_len = 0
ctx.length = 0
ctx.bitlength = 0
ctx.is_initialized = true
}
@@ -491,65 +490,72 @@ init :: proc(ctx: ^$T) {
update :: proc(ctx: ^$T, data: []byte) {
assert(ctx.is_initialized)
length := uint(len(data))
block_nb: uint
new_len, rem_len, tmp_len: uint
shifted_message := make([]byte, length)
when T == Sha256_Context {
CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
} else when T == Sha512_Context {
CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
}
tmp_len = CURR_BLOCK_SIZE - ctx.length
rem_len = length < tmp_len ? length : tmp_len
copy(ctx.block[ctx.length:], data[:rem_len])
data := data
ctx.length += u64(len(data))
if ctx.length + length < CURR_BLOCK_SIZE {
ctx.length += length
return
if ctx.bitlength > 0 {
n := copy(ctx.block[ctx.bitlength:], data[:])
ctx.bitlength += u64(n)
if ctx.bitlength == CURR_BLOCK_SIZE {
sha2_transf(ctx, ctx.block[:])
ctx.bitlength = 0
}
data = data[n:]
}
new_len = length - rem_len
block_nb = new_len / CURR_BLOCK_SIZE
shifted_message = data[rem_len:]
sha2_transf(ctx, ctx.block[:], 1)
sha2_transf(ctx, shifted_message, block_nb)
rem_len = new_len % CURR_BLOCK_SIZE
if rem_len > 0 {
when T == Sha256_Context {copy(ctx.block[:], shifted_message[block_nb << 6:rem_len])} else when T == Sha512_Context {copy(ctx.block[:], shifted_message[block_nb << 7:rem_len])}
if len(data) >= CURR_BLOCK_SIZE {
n := len(data) &~ (CURR_BLOCK_SIZE - 1)
sha2_transf(ctx, data[:n])
data = data[n:]
}
if len(data) > 0 {
ctx.bitlength = u64(copy(ctx.block[:], data[:]))
}
ctx.length = rem_len
when T == Sha256_Context {ctx.tot_len += (block_nb + 1) << 6} else when T == Sha512_Context {ctx.tot_len += (block_nb + 1) << 7}
}
final :: proc(ctx: ^$T, hash: []byte) {
assert(ctx.is_initialized)
block_nb, pm_len: uint
len_b: u64
if len(hash) * 8 < ctx.md_bits {
panic("crypto/sha2: invalid destination digest size")
}
when T == Sha256_Context {CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE} else when T == Sha512_Context {CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE}
length := ctx.length
when T == Sha256_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 9) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)} else when T == Sha512_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 17) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)}
raw_pad: [SHA512_BLOCK_SIZE]byte
when T == Sha256_Context {
CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
pm_len := 8 // 64-bits for length
} else when T == Sha512_Context {
CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
pm_len := 16 // 128-bits for length
}
pad := raw_pad[:CURR_BLOCK_SIZE]
pad_len := u64(CURR_BLOCK_SIZE - pm_len)
len_b = u64(ctx.tot_len + ctx.length) << 3
when T == Sha256_Context {pm_len = block_nb << 6} else when T == Sha512_Context {pm_len = block_nb << 7}
pad[0] = 0x80
if length % CURR_BLOCK_SIZE < pad_len {
update(ctx, pad[0:pad_len - length % CURR_BLOCK_SIZE])
} else {
update(ctx, pad[0:CURR_BLOCK_SIZE + pad_len - length % CURR_BLOCK_SIZE])
}
mem.set(rawptr(&(ctx.block[ctx.length:])[0]), 0, int(pm_len - ctx.length))
ctx.block[ctx.length] = 0x80
endian.unchecked_put_u64be(ctx.block[pm_len - 8:], len_b)
sha2_transf(ctx, ctx.block[:], block_nb)
length_hi, length_lo := bits.mul_u64(length, 8) // Length in bits
when T == Sha256_Context {
_ = length_hi
endian.unchecked_put_u64be(pad[:], length_lo)
update(ctx, pad[:8])
} else when T == Sha512_Context {
endian.unchecked_put_u64be(pad[:], length_hi)
endian.unchecked_put_u64be(pad[8:], length_lo)
update(ctx, pad[0:16])
}
assert(ctx.bitlength == 0)
when T == Sha256_Context {
for i := 0; i < ctx.md_bits / 32; i += 1 {
@@ -572,21 +578,21 @@ SHA256_BLOCK_SIZE :: 64
SHA512_BLOCK_SIZE :: 128
Sha256_Context :: struct {
tot_len: uint,
length: uint,
block: [128]byte,
h: [8]u32,
md_bits: int,
block: [SHA256_BLOCK_SIZE]byte,
h: [8]u32,
bitlength: u64,
length: u64,
md_bits: int,
is_initialized: bool,
}
Sha512_Context :: struct {
tot_len: uint,
length: uint,
block: [256]byte,
h: [8]u64,
md_bits: int,
block: [SHA512_BLOCK_SIZE]byte,
h: [8]u64,
bitlength: u64,
length: u64,
md_bits: int,
is_initialized: bool,
}
@@ -716,52 +722,46 @@ SHA512_F4 :: #force_inline proc "contextless" (x: u64) -> u64 {
}
@(private)
sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
sha2_transf :: proc "contextless" (ctx: ^$T, data: []byte) {
when T == Sha256_Context {
w: [64]u32
wv: [8]u32
t1, t2: u32
CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE
} else when T == Sha512_Context {
w: [80]u64
wv: [8]u64
t1, t2: u64
CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE
}
sub_block := make([]byte, len(data))
i, j: i32
for i = 0; i < i32(block_nb); i += 1 {
when T == Sha256_Context {
sub_block = data[i << 6:]
} else when T == Sha512_Context {
sub_block = data[i << 7:]
}
for j = 0; j < 16; j += 1 {
data := data
for len(data) >= CURR_BLOCK_SIZE {
for i := 0; i < 16; i += 1 {
when T == Sha256_Context {
w[j] = endian.unchecked_get_u32be(sub_block[j << 2:])
w[i] = endian.unchecked_get_u32be(data[i * 4:])
} else when T == Sha512_Context {
w[j] = endian.unchecked_get_u64be(sub_block[j << 3:])
w[i] = endian.unchecked_get_u64be(data[i * 8:])
}
}
when T == Sha256_Context {
for j = 16; j < 64; j += 1 {
w[j] = SHA256_F4(w[j - 2]) + w[j - 7] + SHA256_F3(w[j - 15]) + w[j - 16]
for i := 16; i < 64; i += 1 {
w[i] = SHA256_F4(w[i - 2]) + w[i - 7] + SHA256_F3(w[i - 15]) + w[i - 16]
}
} else when T == Sha512_Context {
for j = 16; j < 80; j += 1 {
w[j] = SHA512_F4(w[j - 2]) + w[j - 7] + SHA512_F3(w[j - 15]) + w[j - 16]
for i := 16; i < 80; i += 1 {
w[i] = SHA512_F4(w[i - 2]) + w[i - 7] + SHA512_F3(w[i - 15]) + w[i - 16]
}
}
for j = 0; j < 8; j += 1 {
wv[j] = ctx.h[j]
for i := 0; i < 8; i += 1 {
wv[i] = ctx.h[i]
}
when T == Sha256_Context {
for j = 0; j < 64; j += 1 {
t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[j] + w[j]
for i := 0; i < 64; i += 1 {
t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[i] + w[i]
t2 = SHA256_F1(wv[0]) + SHA256_MAJ(wv[0], wv[1], wv[2])
wv[7] = wv[6]
wv[6] = wv[5]
@@ -773,8 +773,8 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
wv[0] = t1 + t2
}
} else when T == Sha512_Context {
for j = 0; j < 80; j += 1 {
t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[j] + w[j]
for i := 0; i < 80; i += 1 {
t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[i] + w[i]
t2 = SHA512_F1(wv[0]) + SHA512_MAJ(wv[0], wv[1], wv[2])
wv[7] = wv[6]
wv[6] = wv[5]
@@ -787,8 +787,10 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) {
}
}
for j = 0; j < 8; j += 1 {
ctx.h[j] += wv[j]
for i := 0; i < 8; i += 1 {
ctx.h[i] += wv[i]
}
data = data[CURR_BLOCK_SIZE:]
}
}