net.nim: support for TLS-PSK ciphersuites

This commit is contained in:
Michał Zieliński
2015-10-24 08:53:18 +02:00
parent 3ebf27ddd2
commit ba61a8d00a
4 changed files with 144 additions and 13 deletions

View File

@@ -0,0 +1,15 @@
# Create connection encrypted using preshared key (TLS-PSK).
import net
static: assert defined(ssl)
let sock = newSocket()
sock.connect("localhost", Port(8800))
proc clientFunc(identityHint: string): tuple[identity: string, psk: string] =
echo "identity hint ", identityHint.repr
return ("foo", "psk-of-foo")
let context = newContext(cipherList="PSK-AES256-CBC-SHA")
context.clientGetPskFunc = clientFunc
context.wrapConnectedSocket(sock, handshakeAsClient)

View File

@@ -0,0 +1,20 @@
# Accept connection encrypted using preshared key (TLS-PSK).
import net
static: assert defined(ssl)
let sock = newSocket()
sock.bindAddr(Port(8800))
sock.listen()
let context = newContext(cipherList="PSK-AES256-CBC-SHA")
context.pskIdentityHint = "hello"
context.serverGetPskFunc = proc(identity: string): string = "psk-of-" & identity
while true:
var client = new(Socket)
sock.accept(client)
sock.setSockOpt(OptReuseAddr, true)
echo "accepted connection"
context.wrapConnectedSocket(client, handshakeAsServer)
echo "got connection with identity ", client.getPskIdentity()

View File

@@ -38,6 +38,10 @@ when defined(ssl):
SslHandshakeType* = enum
handshakeAsClient, handshakeAsServer
SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]
SslServerGetPskFunc* = proc(identity: string): string
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
TSSLAcceptResult: SSLAcceptResult].}
@@ -168,6 +172,10 @@ when defined(ssl):
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()
type SslContextExtraInternal = ref object
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
proc raiseSSLError*(s = "") =
## Raises a new SSL error.
if s != "":
@@ -180,6 +188,22 @@ when defined(ssl):
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)
proc getSslContextExtraDataIndex*(): cint =
## Retrieves unique index for storing extra data in SSLContext.
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
raiseSSLError()
proc getExtraData*(ctx: SSLContext, index: cint): pointer =
## Retrieves arbitrary data stored inside SSLContext.
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
let extraInternalIndex = getSslContextExtraDataIndex()
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
if certFile != "" and not existsFile(certFile):
@@ -202,7 +226,7 @@ when defined(ssl):
raiseSSLError("Verification of private key file failed.")
proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
certFile = "", keyFile = ""): SSLContext =
certFile = "", keyFile = "", cipherList = "ALL"): SSLContext =
## Creates an SSL context.
##
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
@@ -229,7 +253,7 @@ when defined(ssl):
of protTLSv1:
newCTX = SSL_CTX_new(TLSv1_method())
if newCTX.SSLCTXSetCipherList("ALL") != 1:
if newCTX.SSLCTXSetCipherList(cipherList) != 1:
raiseSSLError()
case verifyMode
of CVerifyPeer:
@@ -241,21 +265,73 @@ when defined(ssl):
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)
return SSLContext(newCTX)
proc getSslContextExtraDataIndex*(): cint =
## Retrieves unique index for storing extra data in SSLContext.
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
result = SSLContext(newCTX)
# this is never freed, but SSLContext can't be freed anyway yet
let extraInternal = new(SslContextExtraInternal)
GC_ref(extraInternal)
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
## Sets the identity hint passed to server.
##
## Only used in PSK ciphersuites.
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
raiseSSLError()
proc getExtraData*(ctx: SSLContext, index: cint): pointer =
## Retrieves arbitrary data stored inside SSLContext.
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
return ctx.getExtraInternal().clientGetPskFunc
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
return 0
if identityString.len.cuint >= max_identity_len:
return 0
copyMem(identity, identityString.cstring, pskString.len + 1) # with the last zero byte
copyMem(psk, pskString.cstring, pskString.len)
return pskString.len.cuint
proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) =
## Sets function that returns the client identity and the PSK based on identity
## hint from the server.
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0
copyMem(psk, pskString.cstring, pskString.len)
return pskString.len.cuint
proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) =
## Sets function that returns PSK based on the client identity.
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().serverGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
else: pskServerCallback)
proc getPskIdentity*(socket: Socket): string =
## Gets the PSK identity provided by the client.
assert socket.isSSL
return $(socket.sslHandle.SSL_get_psk_identity)
proc wrapSocket*(ctx: SSLContext, socket: Socket) =
## Wraps a socket in an SSL context. This function effectively turns

View File

@@ -197,6 +197,7 @@ proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_CTX_new*(meth: PSSL_METHOD): SslCtx{.cdecl,
dynlib: DLLSSLName, importc.}
proc SSL_CTX_load_verify_locations*(ctx: SslCtx, CAfile: cstring,
@@ -318,6 +319,25 @@ proc SSL_CTX_set_tlsext_servername_arg*(ctx: SslCtx, arg: pointer): int =
## Set the pointer to be used in the callback registered to ``SSL_CTX_set_tlsext_servername_callback``.
result = SSL_CTX_ctrl(ctx, SSL_CTRL_SET_TLSEXT_SERVERNAME_ARG, 0, arg)
type
PskClientCallback* = proc (ssl: SslPtr;
hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.}
PskServerCallback* = proc (ssl: SslPtr;
identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.}
proc SSL_CTX_set_psk_client_callback*(ctx: SslCtx; callback: PskClientCallback) {.cdecl, dynlib: DLLSSLName, importc.}
## Set callback called when OpenSSL needs PSK (for client).
proc SSL_CTX_set_psk_server_callback*(ctx: SslCtx; callback: PskServerCallback) {.cdecl, dynlib: DLLSSLName, importc.}
## Set callback called when OpenSSL needs PSK (for server).
proc SSL_CTX_use_psk_identity_hint*(ctx: SslCtx; hint: cstring): cint {.cdecl, dynlib: DLLSSLName, importc.}
## Set PSK identity hint to use.
proc SSL_get_psk_identity*(ssl: SslPtr): cstring {.cdecl, dynlib: DLLSSLName, importc.}
## Get PSK identity.
proc bioNew*(b: PBIO_METHOD): BIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
proc bioFreeAll*(b: BIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}