core/crypto/_chacha20: Use the precomputation trick for ref

Might as well bring this in sync with the runtime chacha8 version of the
code since this is faster.
This commit is contained in:
Yawning Angel
2026-01-28 23:04:38 +09:00
parent f5b7274a77
commit d438f27efb

View File

@@ -4,133 +4,68 @@ import "core:crypto/_chacha20"
import "core:encoding/endian"
import "core:math/bits"
// At least with LLVM21 force_inline produces identical perf to
// manual inlining, yay.
@(private)
quarter_round :: #force_inline proc "contextless" (a, b, c, d: u32) -> (u32, u32, u32, u32) {
a, b, c, d := a, b, c, d
a += b
d ~= a
d = bits.rotate_left32(d, 16)
c += d
b ~= c
b = bits.rotate_left32(b, 12)
a += b
d ~= a
d = bits.rotate_left32(d, 8)
c += d
b ~= c
b = bits.rotate_left32(b, 7)
return a, b, c, d
}
stream_blocks :: proc(ctx: ^_chacha20.Context, dst, src: []byte, nr_blocks: int) {
// Enforce the maximum consumed keystream per IV.
_chacha20.check_counter_limit(ctx, nr_blocks)
dst, src := dst, src
x := &ctx._s
// Filippo Valsorda made an observation that only one of the column
// round depends on the counter (s12), so it is worth precomputing
// and reusing across multiple blocks. As far as I know, only Go's
// chacha implementation does this.
p1, p5, p9, p13 := quarter_round(_chacha20.SIGMA_1, x[5], x[9], x[13])
p2, p6, p10, p14 := quarter_round(_chacha20.SIGMA_2, x[6], x[10], x[14])
p3, p7, p11, p15 := quarter_round(_chacha20.SIGMA_3, x[7], x[11], x[15])
for n := 0; n < nr_blocks; n = n + 1 {
x0, x1, x2, x3 :=
_chacha20.SIGMA_0, _chacha20.SIGMA_1, _chacha20.SIGMA_2, _chacha20.SIGMA_3
x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 :=
x[4], x[5], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]
// First column round that depends on the counter
p0, p4, p8, p12 := quarter_round(_chacha20.SIGMA_0, x[4], x[8], x[12])
for i := _chacha20.ROUNDS; i > 0; i = i - 2 {
// Even when forcing inlining manually inlining all of
// these is decently faster.
// First diagonal round
x0, x5, x10, x15 := quarter_round(p0, p5, p10, p15)
x1, x6, x11, x12 := quarter_round(p1, p6, p11, p12)
x2, x7, x8, x13 := quarter_round(p2, p7, p8, p13)
x3, x4, x9, x14 := quarter_round(p3, p4, p9, p14)
// quarterround(x, 0, 4, 8, 12)
x0 += x4
x12 ~= x0
x12 = bits.rotate_left32(x12, 16)
x8 += x12
x4 ~= x8
x4 = bits.rotate_left32(x4, 12)
x0 += x4
x12 ~= x0
x12 = bits.rotate_left32(x12, 8)
x8 += x12
x4 ~= x8
x4 = bits.rotate_left32(x4, 7)
for i := _chacha20.ROUNDS - 2; i > 0; i = i - 2 {
x0, x4, x8, x12 = quarter_round(x0, x4, x8, x12)
x1, x5, x9, x13 = quarter_round(x1, x5, x9, x13)
x2, x6, x10, x14 = quarter_round(x2, x6, x10, x14)
x3, x7, x11, x15 = quarter_round(x3, x7, x11, x15)
// quarterround(x, 1, 5, 9, 13)
x1 += x5
x13 ~= x1
x13 = bits.rotate_left32(x13, 16)
x9 += x13
x5 ~= x9
x5 = bits.rotate_left32(x5, 12)
x1 += x5
x13 ~= x1
x13 = bits.rotate_left32(x13, 8)
x9 += x13
x5 ~= x9
x5 = bits.rotate_left32(x5, 7)
// quarterround(x, 2, 6, 10, 14)
x2 += x6
x14 ~= x2
x14 = bits.rotate_left32(x14, 16)
x10 += x14
x6 ~= x10
x6 = bits.rotate_left32(x6, 12)
x2 += x6
x14 ~= x2
x14 = bits.rotate_left32(x14, 8)
x10 += x14
x6 ~= x10
x6 = bits.rotate_left32(x6, 7)
// quarterround(x, 3, 7, 11, 15)
x3 += x7
x15 ~= x3
x15 = bits.rotate_left32(x15, 16)
x11 += x15
x7 ~= x11
x7 = bits.rotate_left32(x7, 12)
x3 += x7
x15 ~= x3
x15 = bits.rotate_left32(x15, 8)
x11 += x15
x7 ~= x11
x7 = bits.rotate_left32(x7, 7)
// quarterround(x, 0, 5, 10, 15)
x0 += x5
x15 ~= x0
x15 = bits.rotate_left32(x15, 16)
x10 += x15
x5 ~= x10
x5 = bits.rotate_left32(x5, 12)
x0 += x5
x15 ~= x0
x15 = bits.rotate_left32(x15, 8)
x10 += x15
x5 ~= x10
x5 = bits.rotate_left32(x5, 7)
// quarterround(x, 1, 6, 11, 12)
x1 += x6
x12 ~= x1
x12 = bits.rotate_left32(x12, 16)
x11 += x12
x6 ~= x11
x6 = bits.rotate_left32(x6, 12)
x1 += x6
x12 ~= x1
x12 = bits.rotate_left32(x12, 8)
x11 += x12
x6 ~= x11
x6 = bits.rotate_left32(x6, 7)
// quarterround(x, 2, 7, 8, 13)
x2 += x7
x13 ~= x2
x13 = bits.rotate_left32(x13, 16)
x8 += x13
x7 ~= x8
x7 = bits.rotate_left32(x7, 12)
x2 += x7
x13 ~= x2
x13 = bits.rotate_left32(x13, 8)
x8 += x13
x7 ~= x8
x7 = bits.rotate_left32(x7, 7)
// quarterround(x, 3, 4, 9, 14)
x3 += x4
x14 ~= x3
x14 = bits.rotate_left32(x14, 16)
x9 += x14
x4 ~= x9
x4 = bits.rotate_left32(x4, 12)
x3 += x4
x14 ~= x3
x14 = bits.rotate_left32(x14, 8)
x9 += x14
x4 ~= x9
x4 = bits.rotate_left32(x4, 7)
x0, x5, x10, x15 = quarter_round(x0, x5, x10, x15)
x1, x6, x11, x12 = quarter_round(x1, x6, x11, x12)
x2, x7, x8, x13 = quarter_round(x2, x7, x8, x13)
x3, x4, x9, x14 = quarter_round(x3, x4, x9, x14)
}
x0 += _chacha20.SIGMA_0
@@ -236,117 +171,15 @@ hchacha20 :: proc "contextless" (dst, key, iv: []byte) {
x15 := endian.unchecked_get_u32le(iv[12:16])
for i := _chacha20.ROUNDS; i > 0; i = i - 2 {
// quarterround(x, 0, 4, 8, 12)
x0 += x4
x12 ~= x0
x12 = bits.rotate_left32(x12, 16)
x8 += x12
x4 ~= x8
x4 = bits.rotate_left32(x4, 12)
x0 += x4
x12 ~= x0
x12 = bits.rotate_left32(x12, 8)
x8 += x12
x4 ~= x8
x4 = bits.rotate_left32(x4, 7)
x0, x4, x8, x12 = quarter_round(x0, x4, x8, x12)
x1, x5, x9, x13 = quarter_round(x1, x5, x9, x13)
x2, x6, x10, x14 = quarter_round(x2, x6, x10, x14)
x3, x7, x11, x15 = quarter_round(x3, x7, x11, x15)
// quarterround(x, 1, 5, 9, 13)
x1 += x5
x13 ~= x1
x13 = bits.rotate_left32(x13, 16)
x9 += x13
x5 ~= x9
x5 = bits.rotate_left32(x5, 12)
x1 += x5
x13 ~= x1
x13 = bits.rotate_left32(x13, 8)
x9 += x13
x5 ~= x9
x5 = bits.rotate_left32(x5, 7)
// quarterround(x, 2, 6, 10, 14)
x2 += x6
x14 ~= x2
x14 = bits.rotate_left32(x14, 16)
x10 += x14
x6 ~= x10
x6 = bits.rotate_left32(x6, 12)
x2 += x6
x14 ~= x2
x14 = bits.rotate_left32(x14, 8)
x10 += x14
x6 ~= x10
x6 = bits.rotate_left32(x6, 7)
// quarterround(x, 3, 7, 11, 15)
x3 += x7
x15 ~= x3
x15 = bits.rotate_left32(x15, 16)
x11 += x15
x7 ~= x11
x7 = bits.rotate_left32(x7, 12)
x3 += x7
x15 ~= x3
x15 = bits.rotate_left32(x15, 8)
x11 += x15
x7 ~= x11
x7 = bits.rotate_left32(x7, 7)
// quarterround(x, 0, 5, 10, 15)
x0 += x5
x15 ~= x0
x15 = bits.rotate_left32(x15, 16)
x10 += x15
x5 ~= x10
x5 = bits.rotate_left32(x5, 12)
x0 += x5
x15 ~= x0
x15 = bits.rotate_left32(x15, 8)
x10 += x15
x5 ~= x10
x5 = bits.rotate_left32(x5, 7)
// quarterround(x, 1, 6, 11, 12)
x1 += x6
x12 ~= x1
x12 = bits.rotate_left32(x12, 16)
x11 += x12
x6 ~= x11
x6 = bits.rotate_left32(x6, 12)
x1 += x6
x12 ~= x1
x12 = bits.rotate_left32(x12, 8)
x11 += x12
x6 ~= x11
x6 = bits.rotate_left32(x6, 7)
// quarterround(x, 2, 7, 8, 13)
x2 += x7
x13 ~= x2
x13 = bits.rotate_left32(x13, 16)
x8 += x13
x7 ~= x8
x7 = bits.rotate_left32(x7, 12)
x2 += x7
x13 ~= x2
x13 = bits.rotate_left32(x13, 8)
x8 += x13
x7 ~= x8
x7 = bits.rotate_left32(x7, 7)
// quarterround(x, 3, 4, 9, 14)
x3 += x4
x14 ~= x3
x14 = bits.rotate_left32(x14, 16)
x9 += x14
x4 ~= x9
x4 = bits.rotate_left32(x4, 12)
x3 += x4
x14 ~= x3
x14 = bits.rotate_left32(x14, 8)
x9 += x14
x4 ~= x9
x4 = bits.rotate_left32(x4, 7)
x0, x5, x10, x15 = quarter_round(x0, x5, x10, x15)
x1, x6, x11, x12 = quarter_round(x1, x6, x11, x12)
x2, x7, x8, x13 = quarter_round(x2, x7, x8, x13)
x3, x4, x9, x14 = quarter_round(x3, x4, x9, x14)
}
endian.unchecked_put_u32le(dst[0:4], x0)