mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-28 17:04:41 +00:00
Merge branch 'tls-psk' of https://github.com/zielmicha/nim into zielmicha-tls-psk
This commit is contained in:
14
examples/ssl/extradata.nim
Normal file
14
examples/ssl/extradata.nim
Normal file
@@ -0,0 +1,14 @@
|
||||
# Stores extra data inside the SSL context.
|
||||
import net
|
||||
|
||||
# Our unique index for storing foos
|
||||
let fooIndex = getSslContextExtraDataIndex()
|
||||
# And another unique index for storing foos
|
||||
let barIndex = getSslContextExtraDataIndex()
|
||||
echo "got indexes ", fooIndex, " ", barIndex
|
||||
|
||||
let ctx = newContext()
|
||||
assert ctx.getExtraData(fooIndex) == nil
|
||||
let foo: int = 5
|
||||
ctx.setExtraData(fooIndex, cast[pointer](foo))
|
||||
assert cast[int](ctx.getExtraData(fooIndex)) == foo
|
||||
16
examples/ssl/pskclient.nim
Normal file
16
examples/ssl/pskclient.nim
Normal file
@@ -0,0 +1,16 @@
|
||||
# Create connection encrypted using preshared key (TLS-PSK).
|
||||
import net
|
||||
|
||||
static: assert defined(ssl)
|
||||
|
||||
let sock = newSocket()
|
||||
sock.connect("localhost", Port(8800))
|
||||
|
||||
proc clientFunc(identityHint: string): tuple[identity: string, psk: string] =
|
||||
echo "identity hint ", identityHint.repr
|
||||
return ("foo", "psk-of-foo")
|
||||
|
||||
let context = newContext(cipherList="PSK-AES256-CBC-SHA")
|
||||
context.clientGetPskFunc = clientFunc
|
||||
context.wrapConnectedSocket(sock, handshakeAsClient)
|
||||
context.destroyContext()
|
||||
20
examples/ssl/pskserver.nim
Normal file
20
examples/ssl/pskserver.nim
Normal file
@@ -0,0 +1,20 @@
|
||||
# Accept connection encrypted using preshared key (TLS-PSK).
|
||||
import net
|
||||
|
||||
static: assert defined(ssl)
|
||||
|
||||
let sock = newSocket()
|
||||
sock.bindAddr(Port(8800))
|
||||
sock.listen()
|
||||
|
||||
let context = newContext(cipherList="PSK-AES256-CBC-SHA")
|
||||
context.pskIdentityHint = "hello"
|
||||
context.serverGetPskFunc = proc(identity: string): string = "psk-of-" & identity
|
||||
|
||||
while true:
|
||||
var client = new(Socket)
|
||||
sock.accept(client)
|
||||
sock.setSockOpt(OptReuseAddr, true)
|
||||
echo "accepted connection"
|
||||
context.wrapConnectedSocket(client, handshakeAsServer)
|
||||
echo "got connection with identity ", client.getPskIdentity()
|
||||
102
lib/pure/net.nim
102
lib/pure/net.nim
@@ -96,6 +96,10 @@ when defineSsl:
|
||||
SslHandshakeType* = enum
|
||||
handshakeAsClient, handshakeAsServer
|
||||
|
||||
SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]
|
||||
|
||||
SslServerGetPskFunc* = proc(identity: string): string
|
||||
|
||||
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
|
||||
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
|
||||
TSSLAcceptResult: SSLAcceptResult].}
|
||||
@@ -225,6 +229,10 @@ when defineSsl:
|
||||
ErrLoadBioStrings()
|
||||
OpenSSL_add_all_algorithms()
|
||||
|
||||
type SslContextExtraInternal = ref object
|
||||
serverGetPskFunc: SslServerGetPskFunc
|
||||
clientGetPskFunc: SslClientGetPskFunc
|
||||
|
||||
proc raiseSSLError*(s = "") =
|
||||
## Raises a new SSL error.
|
||||
if s != "":
|
||||
@@ -237,6 +245,22 @@ when defineSsl:
|
||||
var errStr = ErrErrorString(err, nil)
|
||||
raise newException(SSLError, $errStr)
|
||||
|
||||
proc getSslContextExtraDataIndex*(): cint =
|
||||
## Retrieves unique index for storing extra data in SSLContext.
|
||||
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
|
||||
|
||||
proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
|
||||
## Stores arbitrary data inside SSLContext. The unique `index`
|
||||
## should be retrieved using getSslContextExtraDataIndex.
|
||||
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
|
||||
raiseSSLError()
|
||||
|
||||
proc getExtraData*(ctx: SSLContext, index: cint): pointer =
|
||||
## Retrieves arbitrary data stored inside SSLContext.
|
||||
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
|
||||
|
||||
let extraInternalIndex = getSslContextExtraDataIndex()
|
||||
|
||||
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
|
||||
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
|
||||
if certFile != "" and not existsFile(certFile):
|
||||
@@ -259,7 +283,7 @@ when defineSsl:
|
||||
raiseSSLError("Verification of private key file failed.")
|
||||
|
||||
proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
|
||||
certFile = "", keyFile = ""): SSLContext =
|
||||
certFile = "", keyFile = "", cipherList = "ALL"): SSLContext =
|
||||
## Creates an SSL context.
|
||||
##
|
||||
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
|
||||
@@ -286,7 +310,7 @@ when defineSsl:
|
||||
of protTLSv1:
|
||||
newCTX = SSL_CTX_new(TLSv1_method())
|
||||
|
||||
if newCTX.SSLCTXSetCipherList("ALL") != 1:
|
||||
if newCTX.SSLCTXSetCipherList(cipherList) != 1:
|
||||
raiseSSLError()
|
||||
case verifyMode
|
||||
of CVerifyPeer:
|
||||
@@ -298,7 +322,79 @@ when defineSsl:
|
||||
|
||||
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
|
||||
newCTX.loadCertificates(certFile, keyFile)
|
||||
return SSLContext(newCTX)
|
||||
|
||||
result = SSLContext(newCTX)
|
||||
let extraInternal = new(SslContextExtraInternal)
|
||||
GC_ref(extraInternal)
|
||||
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
|
||||
|
||||
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
|
||||
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
|
||||
|
||||
proc destroyContext*(ctx: SSLContext) =
|
||||
## Free memory referenced by SSLContext.
|
||||
let extraInternal = ctx.getExtraInternal()
|
||||
if extraInternal != nil:
|
||||
GC_unref(extraInternal)
|
||||
SSLCTX(ctx).SSL_CTX_free()
|
||||
|
||||
proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
|
||||
## Sets the identity hint passed to server.
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
|
||||
raiseSSLError()
|
||||
|
||||
proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
|
||||
return ctx.getExtraInternal().clientGetPskFunc
|
||||
|
||||
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
|
||||
max_psk_len: cuint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
|
||||
let hintString = if hint == nil: nil else: $hint
|
||||
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
|
||||
if psk.len.cuint > max_psk_len:
|
||||
return 0
|
||||
if identityString.len.cuint >= max_identity_len:
|
||||
return 0
|
||||
|
||||
copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte
|
||||
copyMem(psk, pskString.cstring, pskString.len)
|
||||
|
||||
return pskString.len.cuint
|
||||
|
||||
proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) =
|
||||
## Sets function that returns the client identity and the PSK based on identity
|
||||
## hint from the server.
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
ctx.getExtraInternal().clientGetPskFunc = fun
|
||||
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
|
||||
|
||||
proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
|
||||
return ctx.getExtraInternal().serverGetPskFunc
|
||||
|
||||
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
|
||||
let pskString = (ctx.serverGetPskFunc)($identity)
|
||||
if psk.len.cint > max_psk_len:
|
||||
return 0
|
||||
copyMem(psk, pskString.cstring, pskString.len)
|
||||
|
||||
return pskString.len.cuint
|
||||
|
||||
proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) =
|
||||
## Sets function that returns PSK based on the client identity.
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
ctx.getExtraInternal().serverGetPskFunc = fun
|
||||
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
|
||||
else: pskServerCallback)
|
||||
|
||||
proc getPskIdentity*(socket: Socket): string =
|
||||
## Gets the PSK identity provided by the client.
|
||||
assert socket.isSSL
|
||||
return $(socket.sslHandle.SSL_get_psk_identity)
|
||||
|
||||
proc wrapSocket*(ctx: SSLContext, socket: Socket) =
|
||||
## Wraps a socket in an SSL context. This function effectively turns
|
||||
|
||||
@@ -197,6 +197,7 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_new*(meth: PSSL_METHOD): SslCtx{.cdecl,
|
||||
dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_load_verify_locations*(ctx: SslCtx, CAfile: cstring,
|
||||
@@ -216,6 +217,10 @@ proc SSL_CTX_use_PrivateKey_file*(ctx: SslCtx,
|
||||
proc SSL_CTX_check_private_key*(ctx: SslCtx): cInt{.cdecl, dynlib: DLLSSLName,
|
||||
importc.}
|
||||
|
||||
proc SSL_CTX_get_ex_new_index*(argl: clong, argp: pointer, new_func: pointer, dup_func: pointer, free_func: pointer): cint {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_set_ex_data*(ssl: SslCtx, idx: cint, arg: pointer): cint {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_CTX_get_ex_data*(ssl: SslCtx, idx: cint): pointer {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
proc SSL_set_fd*(ssl: SslPtr, fd: SocketHandle): cint{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
proc SSL_shutdown*(ssl: SslPtr): cInt{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
@@ -314,6 +319,25 @@ proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int =
|
||||
## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``.
|
||||
result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg)
|
||||
|
||||
type
|
||||
PskClientCallback* = proc (ssl: SslPtr;
|
||||
hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
|
||||
max_psk_len: cuint): cuint {.cdecl.}
|
||||
|
||||
PskServerCallback* = proc (ssl: SslPtr;
|
||||
identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.}
|
||||
|
||||
proc SSL_CTX_set_psk_client_callback*(ctx: SslCtx; callback: PskClientCallback) {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
## Set callback called when OpenSSL needs PSK (for client).
|
||||
|
||||
proc SSL_CTX_set_psk_server_callback*(ctx: SslCtx; callback: PskServerCallback) {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
## Set callback called when OpenSSL needs PSK (for server).
|
||||
|
||||
proc SSL_CTX_use_psk_identity_hint*(ctx: SslCtx; hint: cstring): cint {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
## Set PSK identity hint to use.
|
||||
|
||||
proc SSL_get_psk_identity*(ssl: SslPtr): cstring {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
## Get PSK identity.
|
||||
|
||||
proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
|
||||
proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}
|
||||
|
||||
Reference in New Issue
Block a user