From b71afdc3ee8648def5de5b8df72e8e25790217f6 Mon Sep 17 00:00:00 2001 From: Yawning Angel Date: Fri, 17 Nov 2023 01:13:27 +0900 Subject: [PATCH] core/crypto/sha2: Refactor update/final This is largely modeled off the SM3 versions of these routines, since the relevant parts of the code are the same between SHA-256 and SM3, and the alterations required to support SHA-512 are relatively simple. The prior versions of update and the transform would leak memory, and doing things this way also reduces the context buffer sizes by 1 block. --- core/crypto/sha2/sha2.odin | 154 +++++++++++++++++++------------------ 1 file changed, 78 insertions(+), 76 deletions(-) diff --git a/core/crypto/sha2/sha2.odin b/core/crypto/sha2/sha2.odin index d4b6b87bb..024e52623 100644 --- a/core/crypto/sha2/sha2.odin +++ b/core/crypto/sha2/sha2.odin @@ -14,7 +14,6 @@ package sha2 import "core:encoding/endian" import "core:io" import "core:math/bits" -import "core:mem" import "core:os" /* @@ -482,8 +481,8 @@ init :: proc(ctx: ^$T) { } } - ctx.tot_len = 0 ctx.length = 0 + ctx.bitlength = 0 ctx.is_initialized = true } @@ -491,65 +490,72 @@ init :: proc(ctx: ^$T) { update :: proc(ctx: ^$T, data: []byte) { assert(ctx.is_initialized) - length := uint(len(data)) - block_nb: uint - new_len, rem_len, tmp_len: uint - shifted_message := make([]byte, length) - when T == Sha256_Context { CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE } else when T == Sha512_Context { CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE } - tmp_len = CURR_BLOCK_SIZE - ctx.length - rem_len = length < tmp_len ? length : tmp_len - copy(ctx.block[ctx.length:], data[:rem_len]) + data := data + ctx.length += u64(len(data)) - if ctx.length + length < CURR_BLOCK_SIZE { - ctx.length += length - return + if ctx.bitlength > 0 { + n := copy(ctx.block[ctx.bitlength:], data[:]) + ctx.bitlength += u64(n) + if ctx.bitlength == CURR_BLOCK_SIZE { + sha2_transf(ctx, ctx.block[:]) + ctx.bitlength = 0 + } + data = data[n:] } - - new_len = length - rem_len - block_nb = new_len / CURR_BLOCK_SIZE - shifted_message = data[rem_len:] - - sha2_transf(ctx, ctx.block[:], 1) - sha2_transf(ctx, shifted_message, block_nb) - - rem_len = new_len % CURR_BLOCK_SIZE - if rem_len > 0 { - when T == Sha256_Context {copy(ctx.block[:], shifted_message[block_nb << 6:rem_len])} else when T == Sha512_Context {copy(ctx.block[:], shifted_message[block_nb << 7:rem_len])} + if len(data) >= CURR_BLOCK_SIZE { + n := len(data) &~ (CURR_BLOCK_SIZE - 1) + sha2_transf(ctx, data[:n]) + data = data[n:] + } + if len(data) > 0 { + ctx.bitlength = u64(copy(ctx.block[:], data[:])) } - - ctx.length = rem_len - when T == Sha256_Context {ctx.tot_len += (block_nb + 1) << 6} else when T == Sha512_Context {ctx.tot_len += (block_nb + 1) << 7} } final :: proc(ctx: ^$T, hash: []byte) { assert(ctx.is_initialized) - block_nb, pm_len: uint - len_b: u64 - if len(hash) * 8 < ctx.md_bits { panic("crypto/sha2: invalid destination digest size") } - when T == Sha256_Context {CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE} else when T == Sha512_Context {CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE} + length := ctx.length - when T == Sha256_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 9) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)} else when T == Sha512_Context {block_nb = 1 + ((CURR_BLOCK_SIZE - 17) < (ctx.length % CURR_BLOCK_SIZE) ? 1 : 0)} + raw_pad: [SHA512_BLOCK_SIZE]byte + when T == Sha256_Context { + CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE + pm_len := 8 // 64-bits for length + } else when T == Sha512_Context { + CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE + pm_len := 16 // 128-bits for length + } + pad := raw_pad[:CURR_BLOCK_SIZE] + pad_len := u64(CURR_BLOCK_SIZE - pm_len) - len_b = u64(ctx.tot_len + ctx.length) << 3 - when T == Sha256_Context {pm_len = block_nb << 6} else when T == Sha512_Context {pm_len = block_nb << 7} + pad[0] = 0x80 + if length % CURR_BLOCK_SIZE < pad_len { + update(ctx, pad[0:pad_len - length % CURR_BLOCK_SIZE]) + } else { + update(ctx, pad[0:CURR_BLOCK_SIZE + pad_len - length % CURR_BLOCK_SIZE]) + } - mem.set(rawptr(&(ctx.block[ctx.length:])[0]), 0, int(pm_len - ctx.length)) - ctx.block[ctx.length] = 0x80 - - endian.unchecked_put_u64be(ctx.block[pm_len - 8:], len_b) - - sha2_transf(ctx, ctx.block[:], block_nb) + length_hi, length_lo := bits.mul_u64(length, 8) // Length in bits + when T == Sha256_Context { + _ = length_hi + endian.unchecked_put_u64be(pad[:], length_lo) + update(ctx, pad[:8]) + } else when T == Sha512_Context { + endian.unchecked_put_u64be(pad[:], length_hi) + endian.unchecked_put_u64be(pad[8:], length_lo) + update(ctx, pad[0:16]) + } + assert(ctx.bitlength == 0) when T == Sha256_Context { for i := 0; i < ctx.md_bits / 32; i += 1 { @@ -572,21 +578,21 @@ SHA256_BLOCK_SIZE :: 64 SHA512_BLOCK_SIZE :: 128 Sha256_Context :: struct { - tot_len: uint, - length: uint, - block: [128]byte, - h: [8]u32, - md_bits: int, + block: [SHA256_BLOCK_SIZE]byte, + h: [8]u32, + bitlength: u64, + length: u64, + md_bits: int, is_initialized: bool, } Sha512_Context :: struct { - tot_len: uint, - length: uint, - block: [256]byte, - h: [8]u64, - md_bits: int, + block: [SHA512_BLOCK_SIZE]byte, + h: [8]u64, + bitlength: u64, + length: u64, + md_bits: int, is_initialized: bool, } @@ -716,52 +722,46 @@ SHA512_F4 :: #force_inline proc "contextless" (x: u64) -> u64 { } @(private) -sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) { +sha2_transf :: proc "contextless" (ctx: ^$T, data: []byte) { when T == Sha256_Context { w: [64]u32 wv: [8]u32 t1, t2: u32 + CURR_BLOCK_SIZE :: SHA256_BLOCK_SIZE } else when T == Sha512_Context { w: [80]u64 wv: [8]u64 t1, t2: u64 + CURR_BLOCK_SIZE :: SHA512_BLOCK_SIZE } - sub_block := make([]byte, len(data)) - i, j: i32 - - for i = 0; i < i32(block_nb); i += 1 { - when T == Sha256_Context { - sub_block = data[i << 6:] - } else when T == Sha512_Context { - sub_block = data[i << 7:] - } - - for j = 0; j < 16; j += 1 { + data := data + for len(data) >= CURR_BLOCK_SIZE { + for i := 0; i < 16; i += 1 { when T == Sha256_Context { - w[j] = endian.unchecked_get_u32be(sub_block[j << 2:]) + w[i] = endian.unchecked_get_u32be(data[i * 4:]) } else when T == Sha512_Context { - w[j] = endian.unchecked_get_u64be(sub_block[j << 3:]) + w[i] = endian.unchecked_get_u64be(data[i * 8:]) } } when T == Sha256_Context { - for j = 16; j < 64; j += 1 { - w[j] = SHA256_F4(w[j - 2]) + w[j - 7] + SHA256_F3(w[j - 15]) + w[j - 16] + for i := 16; i < 64; i += 1 { + w[i] = SHA256_F4(w[i - 2]) + w[i - 7] + SHA256_F3(w[i - 15]) + w[i - 16] } } else when T == Sha512_Context { - for j = 16; j < 80; j += 1 { - w[j] = SHA512_F4(w[j - 2]) + w[j - 7] + SHA512_F3(w[j - 15]) + w[j - 16] + for i := 16; i < 80; i += 1 { + w[i] = SHA512_F4(w[i - 2]) + w[i - 7] + SHA512_F3(w[i - 15]) + w[i - 16] } } - for j = 0; j < 8; j += 1 { - wv[j] = ctx.h[j] + for i := 0; i < 8; i += 1 { + wv[i] = ctx.h[i] } when T == Sha256_Context { - for j = 0; j < 64; j += 1 { - t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[j] + w[j] + for i := 0; i < 64; i += 1 { + t1 = wv[7] + SHA256_F2(wv[4]) + SHA256_CH(wv[4], wv[5], wv[6]) + sha256_k[i] + w[i] t2 = SHA256_F1(wv[0]) + SHA256_MAJ(wv[0], wv[1], wv[2]) wv[7] = wv[6] wv[6] = wv[5] @@ -773,8 +773,8 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) { wv[0] = t1 + t2 } } else when T == Sha512_Context { - for j = 0; j < 80; j += 1 { - t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[j] + w[j] + for i := 0; i < 80; i += 1 { + t1 = wv[7] + SHA512_F2(wv[4]) + SHA512_CH(wv[4], wv[5], wv[6]) + sha512_k[i] + w[i] t2 = SHA512_F1(wv[0]) + SHA512_MAJ(wv[0], wv[1], wv[2]) wv[7] = wv[6] wv[6] = wv[5] @@ -787,8 +787,10 @@ sha2_transf :: proc(ctx: ^$T, data: []byte, block_nb: uint) { } } - for j = 0; j < 8; j += 1 { - ctx.h[j] += wv[j] + for i := 0; i < 8; i += 1 { + ctx.h[i] += wv[i] } + + data = data[CURR_BLOCK_SIZE:] } }