Modified #3472 to make its API more idiomatic.

This commit is contained in:
Dominik Picheta
2016-06-03 13:22:18 +01:00
parent c1706463dc
commit 5390c25b60
2 changed files with 72 additions and 36 deletions

View File

@@ -1,14 +1,26 @@
# Stores extra data inside the SSL context.
import net
let ctx = newContext()
# Our unique index for storing foos
let fooIndex = getSslContextExtraDataIndex()
let fooIndex = ctx.getExtraDataIndex()
# And another unique index for storing foos
let barIndex = getSslContextExtraDataIndex()
let barIndex = ctx.getExtraDataIndex()
echo "got indexes ", fooIndex, " ", barIndex
let ctx = newContext()
assert ctx.getExtraData(fooIndex) == nil
let foo: int = 5
ctx.setExtraData(fooIndex, cast[pointer](foo))
assert cast[int](ctx.getExtraData(fooIndex)) == foo
try:
discard ctx.getExtraData(fooIndex)
assert false
except IndexError:
echo("Success")
type
FooRef = ref object of RootRef
foo: int
let foo = FooRef(foo: 5)
ctx.setExtraData(fooIndex, foo)
doAssert ctx.getExtraData(fooIndex).FooRef == foo
ctx.destroyContext()

View File

@@ -66,7 +66,7 @@
##
{.deadCodeElim: on.}
import nativesockets, os, strutils, parseutils, times
import nativesockets, os, strutils, parseutils, times, sets
export Port, `$`, `==`
export Domain, SockType, Protocol
@@ -88,7 +88,10 @@ when defineSsl:
SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23
SslContext* = distinct SslCtx
SslContext* = ref object
context: SslCtx
extraInternalIndex: int
referencedData: HashSet[int]
SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
@@ -229,9 +232,10 @@ when defineSsl:
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()
type SslContextExtraInternal = ref object
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
type
SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
proc raiseSSLError*(s = "") =
## Raises a new SSL error.
@@ -245,21 +249,33 @@ when defineSsl:
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)
proc getSslContextExtraDataIndex*(): cint =
proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext.
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
if result < 0:
raiseSSLError()
proc getExtraData*(ctx: SSLContext, index: cint): pointer =
proc getExtraData*(ctx: SSLContext, index: int): RootRef =
## Retrieves arbitrary data stored inside SSLContext.
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
if index notin ctx.referencedData:
raise newException(IndexError, "No data with that index.")
let res = ctx.context.SSL_CTX_get_ex_data(index.cint)
if cast[int](res) == 0:
raiseSSLError()
return cast[RootRef](res)
let extraInternalIndex = getSslContextExtraDataIndex()
proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
## Stores arbitrary data inside SSLContext. The unique `index`
## should be retrieved using getSslContextExtraDataIndex.
if index in ctx.referencedData:
GC_unref(getExtraData(ctx, index))
if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1:
raiseSSLError()
if index notin ctx.referencedData:
ctx.referencedData.incl(index)
GC_ref(data)
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
@@ -323,26 +339,33 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)
result = SSLContext(newCTX)
result = SSLContext(context: newCTX, extraInternalIndex: 0,
referencedData: initSet[int]())
result.extraInternalIndex = getExtraDataIndex(result)
# The PSK callback functions assume the internal index is 0.
assert result.extraInternalIndex == 0
let extraInternal = new(SslContextExtraInternal)
GC_ref(extraInternal)
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
result.setExtraData(result.extraInternalIndex, extraInternal)
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext.
let extraInternal = ctx.getExtraInternal()
if extraInternal != nil:
GC_unref(extraInternal)
SSLCTX(ctx).SSL_CTX_free()
# We assume here that OpenSSL's internal indexes increase by 1 each time.
# That means we can assume that the next internal index is the length of
# extra data indexes.
for i in ctx.referencedData:
GC_unref(getExtraData(ctx, i).RootRef)
ctx.context.SSL_CTX_free()
proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
## Sets the identity hint passed to server.
##
## Only used in PSK ciphersuites.
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0:
raiseSSLError()
proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
@@ -350,7 +373,7 @@ when defineSsl:
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
@@ -369,13 +392,14 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
ctx.context.SSL_CTX_set_psk_client_callback(
if fun == nil: nil else: pskClientCallback)
proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
return ctx.getExtraInternal().serverGetPskFunc
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0
@@ -388,7 +412,7 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().serverGetPskFunc = fun
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil
else: pskServerCallback)
proc getPskIdentity*(socket: Socket): string =
@@ -409,7 +433,7 @@ when defineSsl:
assert (not socket.isSSL)
socket.isSSL = true
socket.sslContext = ctx
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
socket.sslHandle = SSLNew(socket.sslContext.context)
socket.sslNoHandshake = false
socket.sslHasPeekChar = false
if socket.sslHandle == nil: