From 2623a8fa9fda57938610785b2cc963f6cb6cacd8 Mon Sep 17 00:00:00 2001 From: Lord_Hellgrim Date: Tue, 14 Apr 2026 22:44:47 +0000 Subject: [PATCH] Added noise package to core:crypto Special thanks to Yawning for review, guidance, and massive updates. --- core/crypto/noise/api.odin | 429 +++++++++ core/crypto/noise/doc.odin | 35 + core/crypto/noise/patterns.odin | 686 ++++++++++++++ core/crypto/noise/protocol.odin | 1076 ++++++++++++++++++++++ core/crypto/noise/test_crypto_noise.odin | 305 ++++++ 5 files changed, 2531 insertions(+) create mode 100644 core/crypto/noise/api.odin create mode 100644 core/crypto/noise/doc.odin create mode 100644 core/crypto/noise/patterns.odin create mode 100644 core/crypto/noise/protocol.odin create mode 100644 core/crypto/noise/test_crypto_noise.odin diff --git a/core/crypto/noise/api.odin b/core/crypto/noise/api.odin new file mode 100644 index 000000000..da890e3c6 --- /dev/null +++ b/core/crypto/noise/api.odin @@ -0,0 +1,429 @@ +package noise + +import "base:runtime" +import "core:crypto/ecdh" + +// MAX_PACKET_SIZE is the maximum Noise message size, including TAG_SIZE +// if relevant (`seal_message`, `open_message`). +MAX_PACKET_SIZE :: 65535 + +// PSK_SIZE is the size of an optional handshake pre-shared symmetric key. +PSK_SIZE :: 32 +// TAG_SIZE is the size of the AEAD authentication tag. +TAG_SIZE :: 16 +// MAX_STEP_MSG_SIZE is the maximum per-handshake step message size, +// excluding the optional payload. +// +// `e` is DH_LEN, `s` is either DH_LEN or DH_LEN + TAG_SIZE, and there +// is a maximum of one per each message. +MAX_STEP_MSG_SIZE :: (MAX_DH_SIZE*2)+TAG_SIZE + +// Status is the status of Noise protocol operation. +Status :: enum { + Ok, + + // States + Handshake_Pending, + Handshake_Complete, + Handshake_Split, + Handshake_Failed, + + // Errors + Invalid_Protocol_String, + Invalid_Pre_Shared_Key, + Invalid_DH_Key, + No_Self_Identity, + No_Peer_Identity, + Unexpected_Peer_Identity, + Unexpected_Pre_Shared_Key, + + DH_Failure, + Invalid_Handshake_Message, + + Decryption_Failure, + IV_Exhausted, + Invalid_Cipher_State, + Invalid_Destination_Buffer, + Invalid_Payload_Message, + Max_Packet_Size, + + Out_Of_Memory, +} + +// Handshake_State is the per-handshake state. +Handshake_State :: struct { + s: ecdh.Private_Key, + e: ecdh.Private_Key, + rs: ecdh.Public_Key, + re: ecdh.Public_Key, + psk: [PSK_SIZE]byte, + + symmetric_state: Symmetric_State, + message_pattern: ^Message_Pattern, + current_message: int, + + status: Status, + + initiator: bool, + pre_set_e: bool, +} + +// Cipher_States are the keyed AEAD instances and associated state, +// derived from a successful handshake. +Cipher_States :: struct { + c1_i_to_r: Cipher_State, + c2_r_to_i: Cipher_State, + + initiator: bool, +} + +// handshake_init initializes a Handshake_State with the provided parameters. +// The relevant values are copied into the Handshake_State instance, and +// can be discarded/sanitized right after handshake_init returns (eg: psk). +// +// Note: While this implementation supports setting `e`, this is primarily +// intended for testing, or cases where the runtime cryptographic entropy +// source is unavailable. Use of this functionality is STRONGLY +// discouraged. +@(require_results) +handshake_init :: proc( + self: ^Handshake_State, + initiator: bool, + prologue: []byte, + s: ^ecdh.Private_Key, // Our static key + rs: ^ecdh.Public_Key, // Peer static key + protocol_name: string, + psk: []byte = nil, + _e: ^ecdh.Private_Key = nil, // Our ephemeral key (for testing/RNG-less systems) +) -> Status { + return handshakestate_Initialize( + self, + initiator, + prologue, + s, + _e, + rs, + nil, + protocol_name, + psk, + ) +} + +// handshake_initiator_step takes an input_message received from the responder +// if any and an optional payload to be sent to the responder, and performs +// one step of the Noise handshake process, returning the message to be sent +// to the responder if any, the payload received from the responder if any, +// and the status of the handshake. +// +// The output message MUST be sent to the responder even if the status code +// returned is .Handshake_Complete. +// +// If the dst parameter is provided, the message and payload will be written +// to dst, otherwise new buffers will be allocated. +@(require_results) +handshake_initiator_step :: proc( + self: ^Handshake_State, + input_message: []byte, + payload: []byte = nil, + dst: []byte = nil, + allocator := context.allocator, +) -> ([]byte, []byte, Status) { + output_message: []byte + payload_buffer: []byte + status: Status + + dst := dst + if input_message == nil { + output_message, status = handshakestate_WriteMessage(self, payload, dst, allocator) + } else { + payload_buffer, status = handshakestate_ReadMessage(self, input_message, dst, allocator) + if status == .Handshake_Pending { + if dst != nil { + dst = dst[len(payload_buffer):] + } + output_message, status = handshakestate_WriteMessage(self, payload, dst, allocator) + } + } + + return output_message, payload_buffer, status +} + +// handshake_responder_step takes a input_message received from the initiator, +// and and an optional payload to be sent to the initiator, and performs +// one step of the Noise handshake process, returning the message to be sent +// to the initiator if any, the payload received from the initiator if any, +// and the status of the handshake. +// +// The output message MUST be sent to the initiator even if the status code +// returned is .Handshake_Complete. +// +// If the dst parameter is provided, the message and payload will be written +// to dst, otherwise new buffers will be allocated. +@(require_results) +handshake_responder_step :: proc( + self: ^Handshake_State, + input_message: []byte, + payload: []byte = nil, + dst: []byte = nil, + allocator := context.allocator, +) -> ([]byte, []byte, Status) { + output_message: []byte + + if input_message == nil { + return nil, nil, .Invalid_Handshake_Message + } + + dst := dst + payload_buffer, status := handshakestate_ReadMessage(self, input_message, dst, allocator) + if status == .Handshake_Pending { + if dst != nil { + dst = dst[len(payload_buffer):] + } + output_message, status = handshakestate_WriteMessage(self, payload, dst, allocator) + } + + return output_message, payload_buffer, status +} + +// handshake_split initializes a Cipher_States instance from a completed +// handshake. This can be called once and only once per Handshake_State +// instance. +@(require_results) +handshake_split :: proc(self: ^Handshake_State, cipher_states: ^Cipher_States) -> Status { + if self.status != .Handshake_Complete { + return self.status + } + + symmetricstate_Split(&self.symmetric_state, cipher_states) + if self.message_pattern.is_one_way { + cipherstate_reset(&cipher_states.c2_r_to_i) + cipher_states.c2_r_to_i.is_invalid = true + } + cipher_states.initiator = self.initiator + self.status = .Handshake_Split + + return .Ok +} + +// handshake_peer_identity returns the peer's static DH key used by +// a completed handshake. +// +// This returns a pointer to the Handshake_State's copy of the peer's +// public key, that will get wiped by handshake_reset. If the key is +// needed after a call to handshake_reset, it must be copied. +@(require_results) +handshake_peer_identity :: proc(self: ^Handshake_State) -> (^ecdh.Public_Key, Status) { + #partial switch self.status { + case .Handshake_Complete, .Handshake_Split: + case: + return nil, self.status + } + + if ecdh.curve(&self.rs) == .Invalid { + return nil, .No_Peer_Identity + } + + return &self.rs, .Ok +} + +// handshake_hash returns the handshake transcript hash of a completed +// handshake, for the purposes of channel binding. See 11.2 of the +// specification for details on usage. +// +// This returns a slice to an internal buffer that will get wiped by +// handshake_reset. If the hash is needed after a call to handshake_reset, +// the slice must be copied. +@(require_results) +handshake_hash :: proc(self: ^Handshake_State) -> ([]byte, Status) { + #partial switch self.status { + case .Handshake_Complete, .Handshake_Split: + case: + return nil, self.status + } + + return symmetricstate_GetHandshakeHash(&self.symmetric_state), .Ok +} + +// handshake_reset sanitizes the Handshake_State. It is both safe and +// recommended to call this as soon as practical (after any calls to +// handshake_peer_identity, handshake_hash, and handshake_split are +// complete). +handshake_reset :: proc(self: ^Handshake_State) { + handshakestate_reset(self) +} + +// seal_message encrypts the provided data, authenticates the aad and +// ciphertext, and returns the resulting ciphertext. The ciphertext +// will ALWAYS be `len(plaintext) + TAG_SIZE` bytes in length. +// +// If the dst parameter is provided, the ciphertext will be written +// to dst, otherwise a new buffer will be allocated. +@(require_results) +seal_message :: proc(self: ^Cipher_States, aad, plaintext: []byte, dst: []byte = nil, allocator := context.allocator) -> ([]byte, Status) { + data_len := len(plaintext) + + dst := dst + did_alloc: bool + switch { + case dst == nil: + err: runtime.Allocator_Error + dst, err = make([]byte, data_len + TAG_SIZE, allocator) + if err != nil { + return nil, .Out_Of_Memory + } + did_alloc = true + case: + if len(dst) != data_len + TAG_SIZE { + return nil, .Invalid_Destination_Buffer + } + } + + status: Status + switch self.initiator { + case true: + dst, status = cipherstate_EncryptWithAd(&self.c1_i_to_r, aad, plaintext, dst) + case false: + dst, status = cipherstate_EncryptWithAd(&self.c2_r_to_i, aad, plaintext, dst) + } + if status != .Ok && did_alloc { + delete(dst, allocator) + dst = nil + } + + return dst, status +} + +// open_message authenticates the aad and ciphertext, decrypts the +// ciphertext and returns the resulting plaintext. The plaintext will +// ALWAYS be `len(ciphertext) - TAG_SIZE` bytes in length. +// +// If the dst parameter is provided, the plaintext will be written to +// dst, otherwise a new buffer will be allocated. +@(require_results) +open_message :: proc(self: ^Cipher_States, aad, ciphertext: []byte, dst: []byte = nil, allocator := context.allocator) -> ([]byte, Status) { + if len(ciphertext) < TAG_SIZE { + return nil, .Invalid_Payload_Message + } + + data_len := len(ciphertext) - TAG_SIZE + + dst := dst + did_alloc: bool + switch { + case dst == nil: + if data_len > 0 { + err: runtime.Allocator_Error + dst, err = make([]byte, data_len, allocator) + if err != nil { + return nil, .Out_Of_Memory + } + did_alloc = true + } + case: + if len(dst) != data_len { + return nil, .Invalid_Destination_Buffer + } + } + + status: Status + switch self.initiator { + case true: + dst, status = cipherstate_DecryptWithAd(&self.c2_r_to_i, aad, ciphertext, dst) + case false: + dst, status = cipherstate_DecryptWithAd(&self.c1_i_to_r, aad, ciphertext, dst) + } + if status != .Ok && did_alloc { + delete(dst, allocator) + dst = nil + } + + return dst, status +} + +// cipherstates_rekey updates the selected AEAD key, using a one way function. +// See 11.3 of the specification for examples of usage. +// +// Note: If one side updates the seal_key, the other side must update +// the non-seal_key and vice versa. +@(require_results) +cipherstates_rekey :: proc(self: ^Cipher_States, seal_key: bool) -> Status { + cs := cipherstates_cs(self, seal_key) + if cs.is_invalid { + return .Invalid_Cipher_State + } + if !cipherstate_HasKey(cs) { + return .Handshake_Pending + } + + cipherstate_Rekey(cs) + + return .Ok +} + +// cipherstates_set_n sets the interal counter used to generate the AEAD +// IV to an explicit value. This can be used to deal with out-of-order +// transport messages. See 11.4 of the specification. +// +// WARNING: Reusing n across different aad/messages with the same Cipher_States +// will result in catastrophic loss of security. +@(require_results) +cipherstates_set_n :: proc(self: ^Cipher_States, seal_key: bool, n: u64) -> Status { + cs := cipherstates_cs(self, seal_key) + if cs.is_invalid { + return .Invalid_Cipher_State + } + if !cipherstate_HasKey(cs) { + return .Handshake_Pending + } + + cs.n = n + + return .Ok +} + +// cipherstates_n returns the interal counter used to generate the AEAD +// IV. This can be used to deal with out-of-order transport messages. +// See 11.4 of the specification. +// +// WARNING: Reusing n across different aad/messages with the same Cipher_States +// will result in catastrophic loss of security. +@(require_results) +cipherstates_n :: proc(self: ^Cipher_States, seal_key: bool, n: u64) -> (u64, Status) { + cs := cipherstates_cs(self, seal_key) + if cs.is_invalid { + return 0, .Invalid_Cipher_State + } + if !cipherstate_HasKey(cs) { + return 0, .Handshake_Pending + } + + return cs.n, .Ok +} + +// cipherstates_reset sanitizes the Cipher_States. +cipherstates_reset :: proc(self: ^Cipher_States) { + self.initiator = false + cipherstate_reset(&self.c1_i_to_r) + cipherstate_reset(&self.c2_r_to_i) +} + +@(private = "file") +cipherstates_cs :: proc(self: ^Cipher_States, seal_key: bool) -> ^Cipher_State { + switch self.initiator { + case true: + switch seal_key { + case true: + return &self.c1_i_to_r + case false: + return &self.c2_r_to_i + } + case false: + switch seal_key { + case true: + return &self.c2_r_to_i + case false: + return &self.c1_i_to_r + } + } + unreachable() +} diff --git a/core/crypto/noise/doc.odin b/core/crypto/noise/doc.odin new file mode 100644 index 000000000..49416c6f0 --- /dev/null +++ b/core/crypto/noise/doc.odin @@ -0,0 +1,35 @@ +/* +An implementation of the Noise Protocol Framework (Revision 34). + +The `fallback` modifier and deferred/multi-PSK patterns are not supported +for the sake of simplicity. + +See: +- [[ https://noiseprotocol.org/ ]] +*/ +package noise + +// In general, to complete a noise handshake you must: +// +// - If you are initiating the connection, call `handshake_initiator_step` +// passing `nil` as the `input_message` parameter. +// +// - Send the resulting `[]byte` to the responder (generally a server) via +// the method of your choice. This MUST be done even if the status code +// returned is `.Handshake_Complete`. +// +// - If the status code returned by `handshake_initiator_step` was +// `.Handshake_Complete`, the handshake completed successfully, +// and it is now possible to validate the peer identity, obtain the +// handshake transcript hash, and most usefully call `handshake_split` +// to populate the `Cipher_States` struct that will be used to +// encrypt/decrypt data. +// +// Otherwise, read the response from the responder and feed the response +// data as the `input_message` to the next `handshake_initiator_step` +// until it returns `.Handshake_Complete`. +// +// - If you are the responder, the method is much the same, except you +// must pass a valid `input_message` received from an initiator to the +// first call to `handshake_responder_step`. Repeat until the returned +// status is `.Handshake_Complete`. diff --git a/core/crypto/noise/patterns.odin b/core/crypto/noise/patterns.odin new file mode 100644 index 000000000..73a4221f0 --- /dev/null +++ b/core/crypto/noise/patterns.odin @@ -0,0 +1,686 @@ +package noise + +import "core:slice" + +@(private) +Pre_Token :: enum { + res_s, + ini_s, +} + +@(private) +Token :: enum { + e, + s, + ee, + es, + se, + ss, + psk, +} + +@(private) +Message_Pattern :: struct { + pre_messages: []Pre_Token, + messages: [][]Token, + is_psk: bool, + is_one_way: bool, +} + +// Handshake_Pattern is the list of currently supported Noise Handshake +// Patterns. +Handshake_Pattern :: enum { + Invalid, + + // One way patterns + N, + K, + X, + + // Fundamental patterns + XX, + NK, + NN, + KN, + KK, + NX, + KX, + XN, + IN, + XK, + IK, + IX, + + // Recommended PSK patterns + Npsk0, + Kpsk0, + Xpsk1, + NNpsk0, + NNpsk2, + NKpsk0, + NKpsk2, + NXpsk2, + XNpsk3, + XKpsk3, + XXpsk3, + KNpsk0, + KNpsk2, + KKpsk0, + KKpsk2, + KXpsk2, + INpsk1, + INpsk2, + IKpsk1, + IKpsk2, + IXpsk2, +} + +@(require_results) +pattern_requires_initiator_s :: proc(pattern: Handshake_Pattern) -> (pre: bool, hs: bool) { + p := HANDSHAKE_PATTERNS[pattern] + if slice.contains(p.pre_messages, Pre_Token.ini_s) { + pre = true + } + for msg, i in p.messages { + if i & 1 != 0 { + continue + } + if slice.contains(msg, Token.s) { + hs = true + break + } + } + return pre, hs +} + +@(require_results) +pattern_requires_responder_s :: proc(pattern: Handshake_Pattern) -> (pre: bool, hs: bool) { + p := HANDSHAKE_PATTERNS[pattern] + if slice.contains(p.pre_messages, Pre_Token.res_s) { + pre = true + } + for msg, i in p.messages { + if i & 1 == 0 { + continue + } + if slice.contains(msg, Token.s) { + hs = true + break + } + } + return pre, hs +} + +@(require_results) +pattern_is_psk :: proc(pattern: Handshake_Pattern) -> bool { + return HANDSHAKE_PATTERNS[pattern].is_psk +} + +@(require_results) +pattern_is_one_way :: proc(pattern: Handshake_Pattern) -> bool { + return HANDSHAKE_PATTERNS[pattern].is_one_way +} + +@(require_results) +pattern_num_messages :: proc(pattern: Handshake_Pattern) -> int { + return len(HANDSHAKE_PATTERNS[pattern].messages) +} + +@(private) +HANDSHAKE_PATTERNS := [Handshake_Pattern]^Message_Pattern { + .Invalid = nil, + .N = &PATTERN_N, + .K = &PATTERN_K, + .X = &PATTERN_X, + .XX = &PATTERN_XX, + .NK = &PATTERN_NK, + .NN = &PATTERN_NN, + .KN = &PATTERN_KN, + .KK = &PATTERN_KK, + .NX = &PATTERN_NX, + .KX = &PATTERN_KX, + .XN = &PATTERN_XN, + .IN = &PATTERN_IN, + .XK = &PATTERN_XK, + .IK = &PATTERN_IK, + .IX = &PATTERN_IX, + .Npsk0 = &PATTERN_Npsk0, + .Kpsk0 = &PATTERN_Kpsk0, + .Xpsk1 = &PATTERN_Xpsk1, + .NNpsk0 = &PATTERN_NNpsk0, + .NNpsk2 = &PATTERN_NNpsk2, + .NKpsk0 = &PATTERN_NKpsk0, + .NKpsk2 = &PATTERN_NKpsk2, + .NXpsk2 = &PATTERN_NXpsk2, + .XNpsk3 = &PATTERN_XNpsk3, + .XKpsk3 = &PATTERN_XKpsk3, + .XXpsk3 = &PATTERN_XXpsk3, + .KNpsk0 = &PATTERN_KNpsk0, + .KNpsk2 = &PATTERN_KNpsk2, + .KKpsk0 = &PATTERN_KKpsk0, + .KKpsk2 = &PATTERN_KKpsk2, + .KXpsk2 = &PATTERN_KXpsk2, + .INpsk1 = &PATTERN_INpsk1, + .INpsk2 = &PATTERN_INpsk2, + .IKpsk1 = &PATTERN_IKpsk1, + .IKpsk2 = &PATTERN_IKpsk2, + .IXpsk2 = &PATTERN_IXpsk2, +} + +// ------------- ONE WAY PATTERNS --------------------------------------------------------- + +// N: +// <- s +// ... +// -> e, es +@(private,rodata) +PATTERN_N : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es}, + }, + is_one_way = true, +} + +// K: +// -> s +// <- s +// ... +// -> e, es, ss +@(private,rodata) +PATTERN_K : Message_Pattern = { + pre_messages = {.ini_s, .res_s}, + messages = { + {.e, .es, .ss}, + }, + is_one_way = true, +} + +// X: +// <- s +// ... +// -> e, es, s, ss +@(private,rodata) +PATTERN_X : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es, .s, .ss}, + }, + is_one_way = true, +} + +// ---------------------------------------------------------------------------------------- + +// ------------- FUNDAMENTAL PATTERNS ----------------------------------------------------- + +// XX: +// -> e +// <- e, ee, s, es +// -> s, se +@(private,rodata) +PATTERN_XX : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee, .s, .es}, + {.s, .se}, + }, +} + +// NK: +// <- s +// ... +// -> e, es +// <- e, ee +@(private,rodata) +PATTERN_NK : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es}, + {.e, .ee}, + }, +} + +// NN: +// -> e +// <- e, ee +@(private,rodata) +PATTERN_NN : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee}, + }, +} + +// KN: +// -> s +// ... +// -> e +// <- e, ee, se +@(private,rodata) +PATTERN_KN : Message_Pattern = { + pre_messages = {.ini_s}, + messages = { + {.e,}, + {.e, .ee, .se}, + }, +} + +// KK: +// -> s +// <- s +// ... +// -> e, es, ss +// <- e, ee, se +@(private,rodata) +PATTERN_KK : Message_Pattern = { + pre_messages = {.ini_s, .res_s}, + messages = { + {.e, .es, .ss}, + {.e, .ee, .se}, + }, +} + +// NX: +// -> e +// <- e, ee, s, es +@(private,rodata) +PATTERN_NX : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee, .s, .es}, + }, +} + +// KX: +// -> s +// ... +// -> e +// <- e, ee, se, s, es +@(private,rodata) +PATTERN_KX : Message_Pattern = { + pre_messages = {.ini_s}, + messages = { + {.e}, + {.e, .ee, .se, .s, .es}, + }, +} + +// XN: +// -> e +// <- e, ee +// -> s, se +@(private,rodata) +PATTERN_XN : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee}, + {.s, .se}, + }, +} + +// IN: +// -> e, s +// <- e, ee, se +@(private,rodata) +PATTERN_IN : Message_Pattern = { + pre_messages = nil, + messages = { + {.e, .s}, + {.e, .ee, .se}, + }, +} + +// XK: +// <- s +// ... +// -> e, es +// <- e, ee +// -> s, se +@(private,rodata) +PATTERN_XK : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es}, + {.e, .ee}, + {.s, .se}, + }, +} + +// IK: +// <- s +// ... +// -> e, es, s, ss +// <- e, ee, se +@(private,rodata) +PATTERN_IK : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es, .s, .ss}, + {.e, .ee, .se}, + }, +} + +// IX: +// -> e, s +// <- e, ee, se, s, es +@(private,rodata) +PATTERN_IX : Message_Pattern = { + pre_messages = nil, + messages = { + {.e, .s}, + {.e, .ee, .se, .s, .es}, + }, +} + +// ---------------------------------------------------------------------------------------- + +// ------------- PSK PATTERNS ------------------------------------------------------------- + +// Npsk0: +// <- s +// ... +// -> psk, e, es +@(private,rodata) +PATTERN_Npsk0 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.psk, .e, .es}, + }, + is_psk = true, + is_one_way = true, +} + +// K: +// -> s +// <- s +// ... +// -> psk, e, es, ss +@(private,rodata) +PATTERN_Kpsk0 : Message_Pattern = { + pre_messages = {.ini_s, .res_s}, + messages = { + {.psk, .e, .es, .ss}, + }, + is_psk = true, + is_one_way = true, +} + +// X: +// <- s +// ... +// -> e, es, s, ss, psk +@(private,rodata) +PATTERN_Xpsk1 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es, .s, .ss, .psk}, + }, + is_psk = true, + is_one_way = true, +} + +// NNpsk0: +// -> psk, e +// <- e, ee +@(private,rodata) +PATTERN_NNpsk0 : Message_Pattern = { + pre_messages = nil, + messages = { + {.psk, .e}, + {.e, .ee}, + }, + is_psk = true, +} + +// NNpsk2: +// -> e +// <- e, ee, psk +@(private,rodata) +PATTERN_NNpsk2 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee, .psk}, + }, + is_psk = true, +} + +// NKpsk0: +// <- s +// ... +// -> psk, e, es +// <- e, ee +@(private,rodata) +PATTERN_NKpsk0 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.psk, .e, .es}, + {.e, .ee}, + }, + is_psk = true, +} + +// NKpsk2: +// <- s +// ... +// -> e, es +// <- e, ee, psk +@(private,rodata) +PATTERN_NKpsk2 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es}, + {.e, .ee, .psk}, + }, + is_psk = true, +} + +// NXpsk2: +// -> e +// <- e, ee, s, es, psk +@(private,rodata) +PATTERN_NXpsk2 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee, .s, .es, .psk}, + }, + is_psk = true, +} + +// XNpsk3: +// -> e +// <- e, ee +// -> s, se, psk +@(private,rodata) +PATTERN_XNpsk3 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee}, + {.s, .se, .psk}, + }, + is_psk = true, +} + +// XKpsk3: +// <- s +// ... +// -> e, es +// <- e, ee +// -> s, se, psk +@(private,rodata) +PATTERN_XKpsk3 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es}, + {.e, .ee}, + {.s, .se, .psk}, + }, + is_psk = true, +} + +// XXpsk3: +// -> e +// <- e, ee, s, es +// -> s, se, psk +@(private,rodata) +PATTERN_XXpsk3 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e}, + {.e, .ee, .s, .es}, + {.s, .se, .psk}, + }, + is_psk = true, +} + +// KNpsk0: +// -> s +// ... +// -> psk, e +// <- e, ee, se +@(private,rodata) +PATTERN_KNpsk0 : Message_Pattern = { + pre_messages = {.ini_s}, + messages = { + {.psk, .e}, + {.e, .ee, .se}, + }, + is_psk = true, +} + +// KNpsk2: +// -> s +// ... +// -> e +// <- e, ee, se, psk +@(private,rodata) +PATTERN_KNpsk2 : Message_Pattern = { + pre_messages = {.ini_s}, + messages = { + {.e}, + {.e, .ee, .se, .psk}, + }, + is_psk = true, +} + +// KKpsk0: +// -> s +// <- s +// ... +// -> psk, e, es, ss +// <- e, ee, se +@(private,rodata) +PATTERN_KKpsk0 : Message_Pattern = { + pre_messages = {.ini_s, .res_s}, + messages = { + {.psk, .e, .es, .ss}, + {.e, .ee, .se}, + }, + is_psk = true, +} + +// KKpsk2: +// -> s +// <- s +// ... +// -> e, es, ss +// <- e, ee, se, psk +@(private,rodata) +PATTERN_KKpsk2 : Message_Pattern = { + pre_messages = {.ini_s, .res_s}, + messages = { + {.e, .es, .ss}, + {.e, .ee, .se, .psk}, + }, + is_psk = true, +} + +// KXpsk2: +// -> s +// ... +// -> e +// <- e, ee, se, s, es, psk +@(private,rodata) +PATTERN_KXpsk2 : Message_Pattern = { + pre_messages = {.ini_s}, + messages = { + {.e}, + {.e, .ee, .se, .s, .es, .psk}, + }, + is_psk = true, +} + +// INpsk1: +// -> e, s, psk +// <- e, ee, se +@(private,rodata) +PATTERN_INpsk1 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e, .s, .psk}, + {.e, .ee, .se}, + }, + is_psk = true, +} + +// INpsk2: +// -> e, s +// <- e, ee, se, psk +@(private,rodata) +PATTERN_INpsk2 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e, .s}, + {.e, .ee, .se, .psk}, + }, + is_psk = true, +} + +// IKpsk1: +// <- s +// ... +// -> e, es, s, ss, psk +// <- e, ee, se +@(private,rodata) +PATTERN_IKpsk1 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es, .s, .ss, .psk}, + {.e, .ee, .se}, + }, + is_psk = true, +} + +// IKpsk2: +// <- s +// ... +// -> e, es, s, ss +// <- e, ee, se, psk +@(private,rodata) +PATTERN_IKpsk2 : Message_Pattern = { + pre_messages = {.res_s}, + messages = { + {.e, .es, .s, .ss}, + {.e, .ee, .se, .psk}, + }, + is_psk = true, +} + +// IXpsk2: +// -> e, s +// <- e, ee, se, s, es, psk +@(private,rodata) +PATTERN_IXpsk2 : Message_Pattern = { + pre_messages = nil, + messages = { + {.e, .s}, + {.e, .ee, .se, .s, .es, .psk}, + }, + is_psk = true, +} diff --git a/core/crypto/noise/protocol.odin b/core/crypto/noise/protocol.odin new file mode 100644 index 000000000..b96343817 --- /dev/null +++ b/core/crypto/noise/protocol.odin @@ -0,0 +1,1076 @@ +#+private +package noise + +import "base:runtime" +import "core:crypto" +import "core:crypto/aead" +import "core:crypto/ecdh" +import "core:crypto/hash" +import "core:crypto/hkdf" +import "core:encoding/endian" +import "core:slice" +import "core:strings" + +AEAD_KEY_SIZE :: 32 + +MIN_DH_SIZE :: 32 +MAX_DH_SIZE :: 56 +MAX_HASH_SIZE :: 64 + +Protocol :: struct { + handshake_pattern: Handshake_Pattern, + dh: ecdh.Curve, + cipher: aead.Algorithm, + hash: hash.Algorithm, +} + +Symmetric_State :: struct { + protocol: Protocol, + cipher_state: Cipher_State, + + _ck: [MAX_HASH_SIZE]byte, + _h: [MAX_HASH_SIZE]byte, +} + +Cipher_State :: struct { + ctx: aead.Context, + n: u64, + n_exhausted: bool, + is_invalid: bool, +} + +@(require_results) +dh_len :: proc(protocol: ^Protocol) -> int { + return ecdh.PUBLIC_KEY_SIZES[protocol.dh] +} + +@(require_results) +hash_len :: proc(protocol: ^Protocol) -> int { + return hash.DIGEST_SIZES[protocol.hash] +} + +// Generates a new Diffie-Hellman key pair. A DH key pair consists of +// public_key and private_key elements. public_key represents an encoding +// of a DH public key into a byte sequence of length DHLEN. The public_key +// encoding details are specific to each set of DH functions. +GENERATE_KEYPAIR :: proc(protocol: ^Protocol, private_key: ^ecdh.Private_Key) { + #partial switch protocol.dh { + case .X25519, .X448: + case: panic("crypto/noise: unsupported DH curve in protocol") + } + + ecdh.private_key_generate(private_key, protocol.dh) +} + +// Performs a Diffie-Hellman calculation between the private key in key_pair +// and the public_key and returns an output sequence of bytes of length DHLEN. +// For security, the Gap-DH problem based on this function must be unsolvable +// by any practical cryptanalytic adversary [2]. +// +// The public_key either encodes some value which is a generator in a +// large prime-order group (which value may have multiple equivalent +// encodings), or is an invalid value. Implementations must handle invalid +// public keys either by returning some output which is purely a function +// of the public key and does not depend on the private key, or by signaling +// an error to the caller. +// +// The DH function may define more specific rules for handling invalid values. +@(require_results) +DH :: proc(our_private_key: ^ecdh.Private_Key, their_public_key: ^ecdh.Public_Key, dst: []byte) -> Status { + if ok := ecdh.ecdh(our_private_key, their_public_key, dst); !ok { + return .DH_Failure + } + return .Ok +} + +// Encrypts plaintext using the cipher key k of 32 bytes and an 8-byte +// unsigned integer nonce n which must be unique for the key k. +// Returns the ciphertext. Encryption must be done with an "AEAD" +// encryption mode with the associated data(AD) (using the terminology +// from [1]) and returns a ciphertext that is the same size as the plaintext +// plus 16 bytes for authentication data. The entire ciphertext must be +// indistinguishable from random if the key is secret (note that this is +// an additional requirement that isn't necessarily met by all AEAD schemes). +ENCRYPT :: proc(ctx: ^aead.Context, n: u64, ad, plaintext, dst: []byte) { + pt_len := len(plaintext) + ensure(len(dst) == pt_len + TAG_SIZE, "crypto/noise: invalid AEAD encrypt destination") + + iv: [12]byte + #partial switch aead.algorithm(ctx) { + case .AES_GCM_256: endian.unchecked_put_u64be(iv[4:], n) + case .CHACHA20POLY1305: endian.unchecked_put_u64le(iv[4:], n) + } + + ciphertext, tag := dst[:pt_len], dst[pt_len:] + aead.seal_ctx(ctx, ciphertext, tag, iv[:], ad, plaintext) +} + +// Decrypts ciphertext using a cipher key k of 32 bytes, an 8-byte unsigned +// integer nonce n, and associated data ad. Returns the plaintext, unless +// authentication fails, in which case an error is signaled to the caller. +@(require_results) +DECRYPT :: proc(ctx: ^aead.Context, n: u64, ad, ciphertext, dst: []byte) -> Status { + if len(ciphertext) < TAG_SIZE { + return .Decryption_Failure + } + + iv: [12]byte + #partial switch aead.algorithm(ctx) { + case .AES_GCM_256: endian.unchecked_put_u64be(iv[4:], n) + case .CHACHA20POLY1305: endian.unchecked_put_u64le(iv[4:], n) + } + + ct_len := len(ciphertext) - TAG_SIZE + ct, tag := ciphertext[:ct_len], ciphertext[ct_len:] + if ok := aead.open_ctx(ctx, dst, iv[:], ad, ct, tag); !ok { + return .Decryption_Failure + } + + return .Ok +} + +// Hashes some arbitrary-length data with a collision-resistant cryptographic +// hash function and returns an output of HASHLEN bytes. +HASH :: proc(dst: []byte, protocol: ^Protocol, data: ..[]byte) { + ctx: hash.Context + hash.init(&ctx, protocol.hash) + + for datum in data { + hash.update(&ctx, datum) + } + + hash.final(&ctx, dst) +} + +// Takes a chaining_key byte sequence of length HASHLEN, and an +// input_key_material byte sequence with length either zero bytes, +// 32 bytes, or DHLEN bytes. Returns a pair or triple of byte sequences +// each of length HASHLEN, depending on whether num_outputs is two or three: +// - Sets temp_key = HMAC-HASH(chaining_key, input_key_material). +// - Sets output1 = HMAC-HASH(temp_key, byte(0x01)). +// - Sets output2 = HMAC-HASH(temp_key, output1 || byte(0x02)). +// - If num_outputs == 2 then returns the pair (output1, output2). +// - Sets output3 = HMAC-HASH(temp_key, output2 || byte(0x03)). +// - Returns the triple (output1, output2, output3). +// +// Note that temp_key, output1, output2, and output3 are all HASHLEN +// bytes in length. Also note that the HKDF() function is simply HKDF +// from [4] with the chaining_key as HKDF salt, and zero-length HKDF info. +@(require_results) +HKDF :: proc(dst, chaining_key, input_key_material: []byte, protocol: ^Protocol) -> ([]byte, []byte, []byte) { + assert(len(input_key_material) == 0 || len(input_key_material) == 32 || len(input_key_material) == dh_len(protocol)) + + hkdf.extract_and_expand(protocol.hash, chaining_key, input_key_material, nil, dst) + + h_len := hash_len(protocol) + assert(len(dst) == h_len * 2 || len(dst) == h_len * 3) + + r1, r2 := dst[:h_len], dst[h_len:h_len*2] + if len(dst) == h_len * 2 { + return r1, r2, nil + } + return r1, r2, dst[h_len*2:] +} + +// Sets k = key. Sets n = 0. +cipherstate_InitializeKey :: proc(self: ^Cipher_State, key: []byte, protocol: ^Protocol) { + k_len := len(key) + switch { + case k_len == 0: + // k = empty + aead.reset(&self.ctx) + self.n = 0 + case k_len < AEAD_KEY_SIZE: + panic("crypto/noise: invalid AEAD key size") + case: + aead.init(&self.ctx, protocol.cipher, key[:AEAD_KEY_SIZE]) + self.n = 0 + } +} + +// Returns true if k is non-empty, false otherwise. +@(require_results) +cipherstate_HasKey :: proc(self: ^Cipher_State) -> bool { + return aead.algorithm(&self.ctx) != .Invalid +} + +// If k is non-empty returns ENCRYPT(k, n++, ad, plaintext). Otherwise +// returns plaintext. +@(require_results) +cipherstate_EncryptWithAd :: proc(self: ^Cipher_State, ad, plaintext, dst: []byte) -> ([]byte, Status) { + if self.is_invalid { + return nil, .Invalid_Cipher_State + } + if self.n_exhausted { + return nil, .IV_Exhausted + } + + pt_len := len(plaintext) + if pt_len > MAX_PACKET_SIZE - 16 { + return nil, .Max_Packet_Size + } + + if cipherstate_HasKey(self) { + if len(dst) != pt_len + TAG_SIZE { + return nil, .Invalid_Destination_Buffer + } + ENCRYPT(&self.ctx, self.n, ad, plaintext, dst) + self.n += 1 + if self.n == 0 { + self.n_exhausted = true + } + } else { + if len(dst) != pt_len { + return nil, .Invalid_Destination_Buffer + } + if raw_data(dst) != raw_data(plaintext) { + copy(dst, plaintext) + } + } + + return dst, .Ok +} + +// If k is non-empty returns DECRYPT(k, n++, ad, ciphertext). Otherwise +// returns ciphertext. If an authentication failure occurs in DECRYPT() +// then n is not incremented and an error is signaled to the caller. +@(require_results) +cipherstate_DecryptWithAd :: proc(self: ^Cipher_State, ad, ciphertext, dst: []byte) -> ([]byte, Status) { + if self.is_invalid { + return nil, .Invalid_Cipher_State + } + if self.n_exhausted { + return nil, .IV_Exhausted + } + + if cipherstate_HasKey(self) { + if status := DECRYPT(&self.ctx, self.n, ad, ciphertext, dst); status != .Ok { + return nil, status + } + self.n += 1 + if self.n == 0 { + self.n_exhausted = true + } + } else { + if len(dst) != len(ciphertext) { + return nil, .Invalid_Destination_Buffer + } + if raw_data(dst) != raw_data(ciphertext) { + copy(dst, ciphertext) + } + } + + return dst, .Ok +} + +// Sets k = REKEY(k). +cipherstate_Rekey :: proc(self: ^Cipher_State) { + if cipherstate_HasKey(self) { + algorithm := aead.algorithm(&self.ctx) + + // The "sensible" way to implement this is to inlike REKEY(k), + // so we do. + // + // Returns a new 32-byte cipher key as a pseudorandom function + // of k. If this function is not specifically defined for some + // set of cipher functions, then it defaults to returning the + // first 32 bytes from `ENCRYPT(k, maxnonce, zerolen, zeros)`, + // where maxnonce equals (2^64)-1, zerolen is a zero-length + // byte sequence, and zeros is a sequence of 32 bytes filled + // with zeros. + + zeroes: [AEAD_KEY_SIZE + TAG_SIZE]byte + defer crypto.zero_explicit(&zeroes, size_of(zeroes)) + + // 1 2 3 4 5 6 7 8 + n: u64 = 0xFF_FF_FF_FF_FF_FF_FF_FF + ENCRYPT(&self.ctx, n, nil, zeroes[:AEAD_KEY_SIZE], zeroes[:]) + aead.init(&self.ctx, algorithm, zeroes[:AEAD_KEY_SIZE]) + } +} + +cipherstate_reset :: proc(self: ^Cipher_State) { + aead.reset(&self.ctx) + crypto.zero_explicit(self, size_of(Cipher_State)) +} + +// Takes an arbitrary-length protocol_name byte sequence (see Section 8). +// Executes the following steps: +// - If protocol_name is less than or equal to HASHLEN bytes in length, +// sets h equal to protocol_name with zero bytes appended to make +// HASHLEN bytes. +// - Otherwise sets h = HASH(protocol_name). +// - Sets ck = h. +// - Calls InitializeKey(empty). +@(require_results) +symmetricstate_Initialize :: proc(ss: ^Symmetric_State, protocol_name: string) -> Status { + if status := protocol_from_string(&ss.protocol, protocol_name); status != .Ok { + return status + } + + cipherstate_InitializeKey(&ss.cipher_state, nil, &ss.protocol) + + h_len := hash_len(&ss.protocol) + h := ss._h[:h_len] + if len(protocol_name) <= h_len { + copy(h, protocol_name) + } else { + HASH(h, &ss.protocol, transmute([]byte)protocol_name) + } + + copy(ss._ck[:h_len], h) + + return .Ok +} + +// Sets h = HASH(h || data). +symmetricstate_MixHash :: proc(self: ^Symmetric_State, data: ..[]byte) { + h := self._h[:hash_len(&self.protocol)] + if len(data) == 1 { + HASH(h, &self.protocol, h, data[0]) + } else if len(data) == 2 { + HASH(h, &self.protocol, h, data[0], data[1]) + } else if len(data) == 3 { + HASH(h, &self.protocol, h, data[0], data[1], data[2]) + } else { + panic("crypto/noise: invalid MixHash inputs") + } +} + +// Executes the following steps: +// - Sets ck, temp_k = HKDF(ck, input_key_material, 2). +// - If HASHLEN is 64, then truncates temp_k to 32 bytes. +// - Calls InitializeKey(temp_k). +symmetricstate_MixKey :: proc(self: ^Symmetric_State, input_key_material: []byte) { + h_len := hash_len(&self.protocol) + + dst_len := h_len * 2 + dst: [2*MAX_HASH_SIZE]byte = --- + defer crypto.zero_explicit(&dst, dst_len) + + ck, temp_k, _ := HKDF(dst[:dst_len], self._ck[:h_len], input_key_material, &self.protocol) + copy(self._ck[:], ck) + cipherstate_InitializeKey(&self.cipher_state, temp_k, &self.protocol) +} + +// This function is used for handling pre-shared symmetric keys, as described +// in Section 9. It executes the following steps: +// - Sets ck, temp_h, temp_k = HKDF(ck, input_key_material, 3). +// - Calls MixHash(temp_h). +// - If HASHLEN is 64, then truncates temp_k to 32 bytes. +// - Calls InitializeKey(temp_k). +symmetricstate_MixKeyAndHash :: proc(self: ^Symmetric_State, input_key_material: []byte) { + h_len := hash_len(&self.protocol) + + dst_len := h_len * 3 + dst: [3*MAX_HASH_SIZE]byte = --- + defer crypto.zero_explicit(&dst, dst_len) + + ck, temp_h, temp_k := HKDF(dst[:dst_len], self._ck[:h_len], input_key_material, &self.protocol) + copy(self._ck[:], ck) + symmetricstate_MixHash(self, temp_h) + cipherstate_InitializeKey(&self.cipher_state, temp_k, &self.protocol) +} + +// Returns h. This function should only be called at the end of a handshake, +// i.e. after the Split() function has been called. +// +// This function is used for channel binding, as described in Section 11.2 +@(require_results) +symmetricstate_GetHandshakeHash :: proc(self: ^Symmetric_State) -> []byte { + return self._h[:hash_len(&self.protocol)] +} + +// Sets ciphertext = EncryptWithAd(h, plaintext), calls MixHash(ciphertext), +// and returns ciphertext. +// +// Note that if k is empty, the EncryptWithAd() call will set ciphertext +// equal to plaintext. +@(require_results) +symmetricstate_EncryptAndHash :: proc(self: ^Symmetric_State, plaintext, dst: []byte) -> ([]byte, Status) { + ciphertext, status := cipherstate_EncryptWithAd(&self.cipher_state, self._h[:hash_len(&self.protocol)], plaintext, dst) + if status != .Ok { + return nil, status + } + symmetricstate_MixHash(self, ciphertext) + return ciphertext, status +} + +// Sets plaintext = DecryptWithAd(h, ciphertext), calls MixHash(ciphertext), +// and returns plaintext. +// +// Note that if k is empty, the DecryptWithAd() call will set plaintext +// equal to ciphertext. +@(require_results) +symmetricstate_DecryptAndHash :: proc(self: ^Symmetric_State, ciphertext, dst: []byte) -> ([]byte, Status) { + h_len := hash_len(&self.protocol) + + h: [MAX_HASH_SIZE]byte = --- + copy(h[:], self._h[:h_len]) + defer crypto.zero_explicit(&h, size_of(h)) + + // We reverse the order to save having to copy the ciphertext, in + // the case that ciphertext and dst alias. + symmetricstate_MixHash(self, ciphertext) + return cipherstate_DecryptWithAd(&self.cipher_state, h[:h_len], ciphertext, dst) +} + +// Returns a pair of CipherState objects for encrypting transport messages. +// Executes the following steps, where zerolen is a zero-length byte sequence: +// - Sets temp_k1, temp_k2 = HKDF(ck, zerolen, 2). +// - If HASHLEN is 64, then truncates temp_k1 and temp_k2 to 32 bytes. +// - Creates two new CipherState objects c1 and c2. +// - Calls c1.InitializeKey(temp_k1) and c2.InitializeKey(temp_k2). +// - Returns the pair (c1, c2). +symmetricstate_Split :: proc(self: ^Symmetric_State, cipher_states: ^Cipher_States) { + h_len := hash_len(&self.protocol) + + dst_len := h_len * 2 + dst: [2*MAX_HASH_SIZE]byte = --- + defer crypto.zero_explicit(&dst, dst_len) + + temp_k1, temp_k2, _ := HKDF(dst[:dst_len], self._ck[:h_len], nil, &self.protocol) + cipherstate_InitializeKey(&cipher_states.c1_i_to_r, temp_k1, &self.protocol) + cipherstate_InitializeKey(&cipher_states.c2_r_to_i, temp_k2, &self.protocol) +} + +symmetricstate_reset :: proc(self: ^Symmetric_State) { + cipherstate_reset(&self.cipher_state) + + crypto.zero_explicit(self, size_of(Symmetric_State)) +} + +// Takes a valid handshake_pattern (see Section 7) and an initiator boolean +// specifying this party's role as either initiator or responder. +// Takes a prologue byte sequence which may be zero-length, or which may +// contain context information that both parties want to confirm is identical +// (see Section 6). +// +// Takes a set of DH key pairs (s, e) and public keys (rs, re) for +// initializing local variables, any of which may be empty. Public keys +// are only passed in if the handshake_pattern uses pre-messages +// (see Section 7). The ephemeral values (e, re) are typically left empty, +// since they are created and exchanged during the handshake; but there +// are exceptions (see Section 10). +// +// Performs the following steps: +// - Derives a protocol_name byte sequence by combining the names for +// the handshake pattern and crypto functions, as specified in Section 8. +// - Calls InitializeSymmetric(protocol_name). +// - Calls MixHash(prologue). +// - Sets the initiator, s, e, rs, and re variables to the corresponding +// arguments. +// - Calls MixHash() once for each public key listed in the pre-messages +// from handshake_pattern, with the specified public key as input +// (see Section 7 for an explanation of pre-messages). +// - If both initiator and responder have pre-messages, the initiator's +// public keys are hashed first. +// - If multiple public keys are listed in either party's pre-message, +// the public keys are hashed in the order that they are listed. +// - Sets message_pattern to the message patterns from handshake_pattern. +@(require_results) +handshakestate_Initialize :: proc( + handshake_state: ^Handshake_State, + initiator: bool, + prologue: []byte, + s: ^ecdh.Private_Key, + e: ^ecdh.Private_Key, // Only set for testing. + rs: ^ecdh.Public_Key, + re: ^ecdh.Public_Key, // Only set for testing. + protocol_name: string, + psk: []byte = nil, +) -> Status { + crypto.zero_explicit(handshake_state, size_of(Handshake_State)) + + symmetric_state := &handshake_state.symmetric_state + status: Status + do_init: { + if status = symmetricstate_Initialize(symmetric_state, protocol_name); status != .Ok { + break do_init + } + + curve := symmetric_state.protocol.dh + if s != nil && ecdh.curve(s) != curve { + status = .Invalid_DH_Key + break do_init + } + if e != nil && ecdh.curve(e) != curve { + status = .Invalid_DH_Key + break do_init + } + if rs != nil && ecdh.curve(rs) != curve { + status = .Invalid_DH_Key + break do_init + } + if re != nil && ecdh.curve(re) != curve { + status = .Invalid_DH_Key + break do_init + } + + // Check if we will require s later down the line. + s_pre, s_hs: bool + if initiator { + s_pre, s_hs = pattern_requires_initiator_s(symmetric_state.protocol.handshake_pattern) + } else { + s_pre, s_hs = pattern_requires_responder_s(symmetric_state.protocol.handshake_pattern) + } + if (s_pre || s_hs) && s == nil { + status = .No_Self_Identity + break do_init + } + + message_pattern := HANDSHAKE_PATTERNS[symmetric_state.protocol.handshake_pattern] + if message_pattern.pre_messages != nil { + if initiator { + if slice.contains(message_pattern.pre_messages, Pre_Token.res_s) { + if rs == nil { + status = .No_Peer_Identity + break do_init + } + } + } else { + if slice.contains(message_pattern.pre_messages, Pre_Token.ini_s) { + if rs == nil { + status = .No_Peer_Identity + break do_init + } + } + } + } else { + if rs != nil { + status = .Unexpected_Peer_Identity + break do_init + } + } + + symmetricstate_MixHash(symmetric_state, prologue) + + // In all supported patterns, `ini_s` will always precede `res_s`. + if message_pattern.pre_messages != nil { + tmp: [MAX_DH_SIZE]byte = --- + d_len := dh_len(&symmetric_state.protocol) + dst := tmp[:d_len] + + if initiator { + if slice.contains(message_pattern.pre_messages, Pre_Token.ini_s) { + ecdh.public_key_bytes(&s._pub_key, dst) + symmetricstate_MixHash(symmetric_state, dst) + } + if slice.contains(message_pattern.pre_messages, Pre_Token.res_s) { + ecdh.public_key_bytes(rs, dst) + symmetricstate_MixHash(symmetric_state, dst) + } + } else { + if slice.contains(message_pattern.pre_messages, Pre_Token.ini_s) { + ecdh.public_key_bytes(rs, dst) + symmetricstate_MixHash(symmetric_state, dst) + } + if slice.contains(message_pattern.pre_messages, Pre_Token.res_s) { + ecdh.public_key_bytes(&s._pub_key, dst) + symmetricstate_MixHash(symmetric_state, dst) + } + } + } + if message_pattern.is_psk { + if len(psk) != PSK_SIZE { + status = .Invalid_Pre_Shared_Key + break do_init + } + } else if len(psk) != 0 { + status = .Unexpected_Pre_Shared_Key + break do_init + } + } + if status != .Ok { + symmetricstate_reset(symmetric_state) + return status + } + + if s != nil { + ecdh.private_key_set(&handshake_state.s, s) + } + if e != nil { + ecdh.private_key_set(&handshake_state.e, e) + handshake_state.pre_set_e = true + } + if rs != nil { + ecdh.public_key_set(&handshake_state.rs, rs) + } + if re != nil { + ecdh.public_key_set(&handshake_state.re, re) + } + copy(handshake_state.psk[:], psk) + handshake_state.message_pattern = HANDSHAKE_PATTERNS[symmetric_state.protocol.handshake_pattern] + handshake_state.current_message = 0 + handshake_state.status = .Handshake_Pending + handshake_state.initiator = initiator + + return .Ok +} + +handshakestate_reset :: proc(self: ^Handshake_State) { + symmetricstate_reset(&self.symmetric_state) + ecdh.private_key_clear(&self.s) + ecdh.private_key_clear(&self.e) + + crypto.zero_explicit(self, size_of(Handshake_State)) +} + +// Takes a payload byte sequence which may be zero-length, and a +// message_buffer to write the output into. +// Performs the following steps, aborting if any EncryptAndHash() call +// returns an error: +// - Fetches and deletes the next message pattern from message_pattern, +// then sequentially processes each token from the message pattern: +// - For "e": Sets e (which must be empty) to GENERATE_KEYPAIR(). +// Appends e.public_key to the buffer. Calls MixHash(e.public_key). +// - For "s": Appends EncryptAndHash(s.public_key) to the buffer. +// - For "ee": Calls MixKey(DH(e, re)). +// - For "es": Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) +// if responder. +// - For "se": Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) +// if responder. +// - For "ss": Calls MixKey(DH(s, rs)). +// - Appends EncryptAndHash(payload) to the buffer. +// – (SKIPPED) If there are no more message patterns returns two new +// CipherState objects by calling Split(). +// +// Calling Split() is left to a separate function, although it is technically +// part of the specification. +@(require_results) +handshakestate_WriteMessage :: proc(self: ^Handshake_State, payload, dst: []byte, allocator := context.allocator) -> ([]byte, Status) { + ensure(self.status == .Handshake_Pending, "crypto/noise: invalid state for WriteMessage") + + protocol := &self.symmetric_state.protocol + d_len := dh_len(protocol) + + pattern_buf: [dynamic; MAX_STEP_MSG_SIZE]byte + dh_buf: [MAX_DH_SIZE]byte = --- + defer crypto.zero_explicit(&dh_buf, size_of(dh_buf)) + + pattern := self.message_pattern.messages[self.current_message] + for token in pattern { + switch token { + case .e: + switch self.pre_set_e { + case true: + // Note: "which must be empty", but we allow pre-generated `e` + // for testing/rng-less systems. + self.pre_set_e = false + case false: + if ecdh.curve(&self.e) != .Invalid { + panic("crypto/noise: e was not empty when processing token 'e' during WriteMessage") + } + GENERATE_KEYPAIR(protocol, &self.e) + } + e_public := dh_buf[:d_len] + ecdh.public_key_bytes(&self.e._pub_key, e_public) + n := append(&pattern_buf, ..e_public) + ensure(n == d_len, "crypto/noise: truncated append `e`") + + symmetricstate_MixHash(&self.symmetric_state, e_public) + if self.message_pattern.is_psk { + symmetricstate_MixKey(&self.symmetric_state, e_public) + } + + case .s: + s_public := dh_buf[:d_len] + ecdh.public_key_bytes(&self.s._pub_key, s_public) + + tmp: [MAX_DH_SIZE+TAG_SIZE]byte = --- + dh_buf := tmp[:d_len+TAG_SIZE] + if !cipherstate_HasKey(&self.symmetric_state.cipher_state) { + dh_buf = tmp[:d_len] + } + ct, status := symmetricstate_EncryptAndHash(&self.symmetric_state, s_public, dh_buf) + if status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + n := append(&pattern_buf, ..ct) + ensure(n == len(ct), "crypto/noise: truncated append `s`") + + case .ee: + dh := dh_buf[:d_len] + if status := DH(&self.e, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + + case .es: + dh := dh_buf[:d_len] + if self.initiator { + if status := DH(&self.e, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } else { + if status := DH(&self.s, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } + + case .se: + dh := dh_buf[:d_len] + if self.initiator { + if status := DH(&self.s, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } else { + if status := DH(&self.e, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } + + case .ss: + dh := dh_buf[:d_len] + if status := DH(&self.s, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + + case .psk: + symmetricstate_MixKeyAndHash(&self.symmetric_state, self.psk[:]) + } + } + self.current_message += 1 // Advance after the current message is successful. + + pattern_len := len(pattern_buf) + payload_len := len(payload) + msg_len := pattern_len + payload_len + if payload_len != 0 && cipherstate_HasKey(&self.symmetric_state.cipher_state) { + msg_len += TAG_SIZE + } + + msg: []byte + if msg_len != 0 { + did_alloc: bool + if dst != nil { + if len(dst) < msg_len { + self.status = .Handshake_Failed + return nil, .Out_Of_Memory + } + msg = dst[:msg_len] + } else { + err: runtime.Allocator_Error + msg, err = make([]byte, msg_len, allocator) + if err != nil { + self.status = .Handshake_Failed + return nil, .Out_Of_Memory + } + did_alloc = true + } + + copy(msg, pattern_buf[:]) + if payload_len != 0 { + ciphertext := msg[pattern_len:] + if _, status := symmetricstate_EncryptAndHash(&self.symmetric_state, payload, ciphertext); status != .Ok { + if did_alloc { + delete(msg) + } + self.status = .Handshake_Failed + return nil, status + } + } + } + + if self.current_message == len(self.message_pattern.messages) { + self.current_message = -1 + self.status = .Handshake_Complete + } + + return msg, self.status +} + +// Takes a byte sequence containing a Noise handshake message, and a +// payload_buffer to write the message's plaintext payload into. +// Performs the following steps, aborting if any DecryptAndHash() +// call returns an error: +// - Fetches and deletes the next message pattern from message_pattern, +// then sequentially processes each token from the message pattern: +// - For "e": Sets re (which must be empty) to the next DHLEN bytes +// from the message. Calls MixHash(re.public_key). +// - For "s": Sets temp to the next DHLEN + 16 bytes of the message +// if HasKey() == True, or to the next DHLEN bytes otherwise. +// Sets rs (which must be empty) to DecryptAndHash(temp). +// - For "ee": Calls MixKey(DH(e, re)). +// - For "es": Calls MixKey(DH(e, rs)) if initiator, MixKey(DH(s, re)) +// if responder. +// - For "se": Calls MixKey(DH(s, re)) if initiator, MixKey(DH(e, rs)) +// if responder. +// -For "ss": Calls MixKey(DH(s, rs)). +// - Calls DecryptAndHash() on the remaining bytes of the message and stores +// the output into payload_buffer. +// – (SKIPPED) If there are no more message patterns returns two new +// CipherState objects by calling Split(). +// +// Calling Split() is left to a separate function, although it is technically +// part of the specification. +@(require_results) +handshakestate_ReadMessage :: proc(self: ^Handshake_State, message, dst: []byte, allocator := context.allocator) -> ([]byte, Status) { + ensure(self.status == .Handshake_Pending, "crypto/noise: invalid state for ReadMessage") + + if len(message) < MIN_DH_SIZE { + return nil, .Invalid_Handshake_Message + } + + protocol := &self.symmetric_state.protocol + d_len := dh_len(&self.symmetric_state.protocol) + + dh_buf: [MAX_DH_SIZE]byte = --- + defer crypto.zero_explicit(&dh_buf, size_of(dh_buf)) + + msg := message + + pattern := self.message_pattern.messages[self.current_message] + for token in pattern { + switch token { + case .e: + if len(msg) < d_len { + return nil, .Invalid_Handshake_Message + } + re := msg[:d_len] + + if ecdh.curve(&self.re) != .Invalid { + panic("crypto/noise: re was not empty when processing token 'e' during ReadMessage") + } + + ecdh.public_key_set_bytes(&self.re, protocol.dh, re) + symmetricstate_MixHash(&self.symmetric_state, re) + if self.message_pattern.is_psk { + symmetricstate_MixKey(&self.symmetric_state, re) + } + msg = msg[d_len:] + + case .s: + rs_len := d_len + if cipherstate_HasKey(&self.symmetric_state.cipher_state) { + rs_len += TAG_SIZE + } + if len(msg) < rs_len { + self.status = .Handshake_Failed + return nil, .Invalid_Handshake_Message + } + + rs := dh_buf[:d_len] + if _, status := symmetricstate_DecryptAndHash(&self.symmetric_state, msg[:rs_len], rs); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + + if ecdh.curve(&self.rs) != .Invalid { + panic("crypto/noise: rs was not empty when processing token 's' during ReadMessage") + } + + ecdh.public_key_set_bytes(&self.rs, protocol.dh, rs) + msg = msg[rs_len:] + + case .ee: + dh := dh_buf[:d_len] + if status := DH(&self.e, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + + case .es: + dh := dh_buf[:d_len] + if self.initiator { + if status := DH(&self.e, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } else { + if status := DH(&self.s, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } + + case .se: + dh := dh_buf[:d_len] + if self.initiator { + if status := DH(&self.s, &self.re, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } else { + if status := DH(&self.e, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + } + + case .ss: + dh := dh_buf[:d_len] + if status := DH(&self.s, &self.rs, dh); status != .Ok { + self.status = .Handshake_Failed + return nil, status + } + symmetricstate_MixKey(&self.symmetric_state, dh) + + case .psk: + symmetricstate_MixKeyAndHash(&self.symmetric_state, self.psk[:]) + } + } + self.current_message += 1 // Advance after the current message is successful. + + payload: []byte + if len(msg) > 0 { + payload_len := len(msg) + if cipherstate_HasKey(&self.symmetric_state.cipher_state) { + if payload_len < TAG_SIZE { + self.status = .Handshake_Failed + return nil, self.status + } + payload_len -= TAG_SIZE + } + + did_alloc: bool + if dst != nil { + if len(dst) < payload_len { + self.status = .Handshake_Failed + return nil, .Out_Of_Memory + } + payload = dst[:payload_len] + } else { + err: runtime.Allocator_Error + payload, err = make([]byte, payload_len, allocator) + if err != nil { + self.status = .Handshake_Failed + return nil, .Out_Of_Memory + } + did_alloc = true + } + + if _, status := symmetricstate_DecryptAndHash(&self.symmetric_state, msg, payload); status != .Ok { + if did_alloc { + delete(payload) + } + self.status = .Handshake_Failed + return nil, self.status + } + } + + if self.current_message == len(self.message_pattern.messages) { + self.current_message = -1 + self.status = .Handshake_Complete + } + + return payload, self.status +} + +@(require_results) +protocol_from_string :: proc(self: ^Protocol, protocol_name: string) -> Status { + str := protocol_name + self^ = Protocol{} + + if len(str) > 255 { + return .Invalid_Protocol_String + } + + s, ok := strings.split_by_byte_iterator(&str, '_') + if !ok || s != "Noise" { + return .Invalid_Protocol_String + } + + if s, ok = strings.split_by_byte_iterator(&str, '_'); !ok { + return .Invalid_Protocol_String + } + pattern: Handshake_Pattern + switch s { + case "N" : pattern = .N + case "K" : pattern = .K + case "X" : pattern = .X + case "XX": pattern = .XX + case "NK": pattern = .NK + case "NN": pattern = .NN + case "KN": pattern = .KN + case "KK": pattern = .KK + case "NX": pattern = .NX + case "KX": pattern = .KX + case "XN": pattern = .XN + case "IN": pattern = .IN + case "XK": pattern = .XK + case "IK": pattern = .IK + case "IX": pattern = .IX + case "Npsk0": pattern = .Npsk0 + case "Kpsk0": pattern = .Kpsk0 + case "Xpsk1": pattern = .Xpsk1 + case "NNpsk0": pattern = .NNpsk0 + case "NNpsk2": pattern = .NNpsk2 + case "NKpsk0": pattern = .NKpsk0 + case "NKpsk2": pattern = .NKpsk2 + case "NXpsk2": pattern = .NXpsk2 + case "XNpsk3": pattern = .XNpsk3 + case "XKpsk3": pattern = .XKpsk3 + case "XXpsk3": pattern = .XXpsk3 + case "KNpsk0": pattern = .KNpsk0 + case "KNpsk2": pattern = .KNpsk2 + case "KKpsk0": pattern = .KKpsk0 + case "KKpsk2": pattern = .KKpsk2 + case "KXpsk2": pattern = .KXpsk2 + case "INpsk1": pattern = .INpsk1 + case "INpsk2": pattern = .INpsk2 + case "IKpsk1": pattern = .IKpsk1 + case "IKpsk2": pattern = .IKpsk2 + case "IXpsk2": pattern = .IXpsk2 + case: return .Invalid_Protocol_String + } + + if s, ok = strings.split_by_byte_iterator(&str, '_'); !ok { + return .Invalid_Protocol_String + } + dh: ecdh.Curve + switch s { + case "25519": dh = .X25519 + case "448": dh = .X448 + case: return .Invalid_Protocol_String + } + + if s, ok = strings.split_by_byte_iterator(&str, '_'); !ok { + return .Invalid_Protocol_String + } + cipher: aead.Algorithm + switch s { + case "AESGCM": cipher = .AES_GCM_256 + case "ChaChaPoly": cipher = .CHACHA20POLY1305 + case: return .Invalid_Protocol_String + } + + if s, ok = strings.split_by_byte_iterator(&str, '_'); !ok { + return .Invalid_Protocol_String + } + hash: hash.Algorithm + switch s { + case "SHA512": hash = .SHA512 + case "SHA256": hash = .SHA256 + case "Blake2s": hash = .BLAKE2S + case "Blake2b": hash = .BLAKE2B + case: return .Invalid_Protocol_String + } + + if len(str) != 0 { + return .Invalid_Protocol_String + } + + self.handshake_pattern = pattern + self.dh = dh + self.cipher = cipher + self.hash = hash + + return .Ok +} diff --git a/core/crypto/noise/test_crypto_noise.odin b/core/crypto/noise/test_crypto_noise.odin new file mode 100644 index 000000000..50a540f97 --- /dev/null +++ b/core/crypto/noise/test_crypto_noise.odin @@ -0,0 +1,305 @@ +package noise + +import "core:bytes" +import "core:crypto" +import "core:crypto/aead" +import "core:crypto/ecdh" +import "core:crypto/hash" +import "core:fmt" +import "core:log" +import "core:math/rand" +import "core:testing" + +@(private = "file") +DH_CURVES :: []ecdh.Curve { + .X25519, + .X448, +} +@(private = "file") +CIPHERS :: []aead.Algorithm{ + .AES_GCM_256, + .CHACHA20POLY1305, +} +@(private = "file") +HASHES :: []hash.Algorithm{ + .SHA256, + .SHA512, + .BLAKE2S, + .BLAKE2B, +} + +@(test) +test_supported_protocols :: proc(t: ^testing.T) { + if !crypto.HAS_RAND_BYTES { + log.info("rand_bytes not supported - skipping") + return + } + + protocol: Test_Protocol + for pattern in Handshake_Pattern { + if pattern == .Invalid { + continue + } + protocol.handshake_pattern = pattern + for dh in DH_CURVES { + protocol.dh = dh + for cipher in CIPHERS { + protocol.cipher = cipher + for hash in HASHES { + protocol.hash = hash + if !testing.expectf( + t, + test_noise_one_protocol(t, &protocol, context.temp_allocator), + "Failed protocol: %v", protocol, + ) { + testing.fail(t) + break + } + } + } + } + } +} + +@(private = "file") +test_noise_one_protocol :: proc(t: ^testing.T, protocol: ^Test_Protocol, allocator := context.allocator) -> bool { + protocol_name := test_protocol_string(protocol, allocator) + defer delete(protocol_name, allocator) + + log.debugf("crypto/noise: %s", protocol_name) + + is_one_way := pattern_is_one_way(protocol.handshake_pattern) + + initiator_s, responder_s: ecdh.Private_Key + ini_s, res_s: ^ecdh.Private_Key + ini_s_pub, res_s_pub: ^ecdh.Public_Key + + pre, hs := pattern_requires_initiator_s(protocol.handshake_pattern) + if pre || hs { + if !testing.expect(t, ecdh.private_key_generate(&initiator_s, protocol.dh), "failed to generate initiator s") { + return false + } + ini_s = &initiator_s + if pre { + ini_s_pub = &initiator_s._pub_key + } + } + pre, hs = pattern_requires_responder_s(protocol.handshake_pattern) + if pre || hs { + if !testing.expect(t, ecdh.private_key_generate(&responder_s, protocol.dh), "failed to generate responder s") { + return false + } + res_s = &responder_s + if pre { + res_s_pub = &responder_s._pub_key + } + } + + psk_buf: [32]byte = --- + psk: []byte + if pattern_is_psk(protocol.handshake_pattern) { + crypto.rand_bytes(psk_buf[:]) + psk = psk_buf[:] + } + + ini_hs, res_hs: Handshake_State + status := handshake_init(&ini_hs, true, nil, ini_s, res_s_pub, protocol_name, psk) + if !testing.expectf(t, status == .Ok, "failed to initialize initiator Handshake_State: %v", status) { + return false + } + status = handshake_init(&res_hs, false, nil, res_s, ini_s_pub, protocol_name, psk) + if !testing.expectf(t, status == .Ok, "failed to initialize responder Handshake_State: %v", status) { + return false + } + + ini_status, res_status: Status + ini_msg, res_msg: []byte + ini_payload, res_payload: []byte + hs_msg_buf: [MAX_STEP_MSG_SIZE]byte + for i := 0; ; i += 1{ + if ini_status == .Handshake_Complete && res_status == .Handshake_Complete { + break + } + + // Test the allocation path + res_msg, res_payload, ini_status = handshake_initiator_step(&ini_hs, ini_msg, allocator = allocator) + ini_msg = nil + + if ini_status == .Handshake_Complete && res_status == .Handshake_Complete { + break + } + + if !testing.expectf(t, len(res_payload) == 0, "step %d: unexpected responder payload: %x", i, res_payload) { + return false + } + if !testing.expectf(t, ini_status == .Handshake_Pending || ini_status == .Handshake_Complete, "step %d: initiator step failed: %v", i, ini_status) { + return false + } + + // Test the non-allocation path + ini_msg, ini_payload, res_status = handshake_responder_step(&res_hs, res_msg, nil, hs_msg_buf[:]) + delete(res_msg, allocator) + res_msg = nil + + if !testing.expectf(t, len(ini_payload) == 0, "step %d: unexpected initiator payload: %x", i, ini_payload) { + return false + } + if !testing.expectf(t, res_status == .Handshake_Pending || res_status == .Handshake_Complete, "step %d: responder step failed: %v", i, res_status) { + return false + } + } + delete(res_msg, allocator) + + hs_pub: ^ecdh.Public_Key + if ini_s != nil { + hs_pub, status = handshake_peer_identity(&res_hs) + if !testing.expect(t, status == .Ok) { + return false + } + if !testing.expectf(t, ecdh.public_key_equal(&ini_s._pub_key, hs_pub), "responder has incorrect initiator identity") { + return false + } + } + if res_s != nil { + hs_pub, status = handshake_peer_identity(&ini_hs) + if !testing.expect(t, status == .Ok) { + return false + } + if !testing.expectf(t, ecdh.public_key_equal(&res_s._pub_key, hs_pub), "initiator has incorrect responder identity") { + return false + } + } + + h1, h2: []byte + h1, status = handshake_hash(&ini_hs) + if !testing.expect(t, status == .Ok) { + return false + } + h2, status = handshake_hash(&res_hs) + if !testing.expect(t, status == .Ok) { + return false + } + if !testing.expectf(t, bytes.equal(h1, h2), "handshake hash mismatch: %x != %x", h1, h2) { + return false + } + + ini_cs, res_cs: Cipher_States + if !testing.expectf(t, .Ok == handshake_split(&ini_hs, &ini_cs), "failed to split initiator: %v") { + return false + } + if !testing.expectf(t, .Ok == handshake_split(&res_hs, &res_cs), "failed to split responder: %v") { + return false + } + + handshake_reset(&ini_hs) + handshake_reset(&res_hs) + + if !testing.expect(t, test_messages(t, &ini_cs, &res_cs, is_one_way, allocator), "message tests failed") { + return false + } + + cipherstates_reset(&ini_cs) + cipherstates_reset(&res_cs) + + return true +} + +@(private = "file") +test_messages :: proc(t: ^testing.T, ini_cs, res_cs: ^Cipher_States, is_one_way: bool, allocator := context.allocator) -> bool { + ad_buf: [256]byte = --- + payload_buf: [MAX_PACKET_SIZE-TAG_SIZE]byte = --- + + for i in 0..<10 { + ad := ad_buf[:rand.int_max(len(ad_buf))] + payload := payload_buf[:rand.int_max(len(payload_buf))] + + _ = rand.read(payload) + _ = rand.read(ad) + + // Initiator -> Responder (allocate buffers) + tx_msg, status := seal_message(ini_cs, ad, payload, allocator = allocator) + defer delete(tx_msg, allocator) + if !testing.expectf(t, status == .Ok, "i->r %d: seal failed: %v", i, status) { + return false + } + + rx_dst: []byte + rx_dst, status = open_message(res_cs, ad, tx_msg, allocator = allocator) + defer delete(rx_dst, allocator) + if !testing.expectf(t, status == .Ok, "i->r %d: open failed: %v", i, status) { + return false + } + + if !testing.expectf(t, bytes.equal(rx_dst, payload), "i->r %d: payload mismatch") { + return false + } + + if i == 5 { + status = cipherstates_rekey(ini_cs, true) + if !testing.expectf(t, status == .Ok, "i %d: rekey failed: %v", i, status) { + return false + } + status = cipherstates_rekey(res_cs, false) + if !testing.expectf(t, status == .Ok, "r %d: rekey failed: %v", i, status) { + return false + } + } + + if is_one_way { + continue + } + + // Responder -> Initiator (reuse allocated buffers) + tx_msg, status = seal_message(res_cs, ad, payload, tx_msg) + if !testing.expectf(t, status == .Ok, "r->i %d: seal failed: %v", i, status) { + return false + } + + _, status = open_message(ini_cs, ad, tx_msg, rx_dst) + if !testing.expectf(t, status == .Ok, "r->i %d: open failed: %v", i, status) { + return false + } + + if !testing.expectf(t, bytes.equal(rx_dst, payload), "r-i %d: payload mismatch") { + return false + } + } + + return true +} + +@(private = "file") +Test_Protocol :: struct { + handshake_pattern: Handshake_Pattern, + dh: ecdh.Curve, + cipher: aead.Algorithm, + hash: hash.Algorithm, +} + +@(private = "file") +test_protocol_string :: proc(protocol: ^Test_Protocol, allocator := context.allocator) -> string { + dh: string + #partial switch protocol.dh { + case .X25519: dh = "25519" + case .X448: dh = "448" + case: panic("crypto/noise: unsupported DH") + } + + cipher: string + #partial switch protocol.cipher { + case .AES_GCM_256: cipher = "AESGCM" + case .CHACHA20POLY1305: cipher = "ChaChaPoly" + case: panic("crypto/noise: unsupported cipher") + } + + hash: string + #partial switch protocol.hash { + case .SHA256: hash = "SHA256" + case .SHA512: hash = "SHA512" + case .BLAKE2S: hash = "Blake2s" + case .BLAKE2B: hash = "Blake2b" + case: panic("crypto/noise: unsupported hash") + } + + return fmt.aprintf("Noise_%v_%v_%v_%v", protocol.handshake_pattern, dh, cipher, hash, allocator = allocator) +}