mirror of
https://github.com/nim-lang/Nim.git
synced 2026-04-19 14:00:35 +00:00
Modified #3472 to make its API more idiomatic.
This commit is contained in:
@@ -1,14 +1,26 @@
|
||||
# Stores extra data inside the SSL context.
|
||||
import net
|
||||
|
||||
let ctx = newContext()
|
||||
|
||||
# Our unique index for storing foos
|
||||
let fooIndex = getSslContextExtraDataIndex()
|
||||
let fooIndex = ctx.getExtraDataIndex()
|
||||
# And another unique index for storing foos
|
||||
let barIndex = getSslContextExtraDataIndex()
|
||||
let barIndex = ctx.getExtraDataIndex()
|
||||
echo "got indexes ", fooIndex, " ", barIndex
|
||||
|
||||
let ctx = newContext()
|
||||
assert ctx.getExtraData(fooIndex) == nil
|
||||
let foo: int = 5
|
||||
ctx.setExtraData(fooIndex, cast[pointer](foo))
|
||||
assert cast[int](ctx.getExtraData(fooIndex)) == foo
|
||||
try:
|
||||
discard ctx.getExtraData(fooIndex)
|
||||
assert false
|
||||
except IndexError:
|
||||
echo("Success")
|
||||
|
||||
type
|
||||
FooRef = ref object of RootRef
|
||||
foo: int
|
||||
|
||||
let foo = FooRef(foo: 5)
|
||||
ctx.setExtraData(fooIndex, foo)
|
||||
doAssert ctx.getExtraData(fooIndex).FooRef == foo
|
||||
|
||||
ctx.destroyContext()
|
||||
|
||||
@@ -66,7 +66,7 @@
|
||||
##
|
||||
|
||||
{.deadCodeElim: on.}
|
||||
import nativesockets, os, strutils, parseutils, times
|
||||
import nativesockets, os, strutils, parseutils, times, sets
|
||||
export Port, `$`, `==`
|
||||
export Domain, SockType, Protocol
|
||||
|
||||
@@ -88,7 +88,10 @@ when defineSsl:
|
||||
SslProtVersion* = enum
|
||||
protSSLv2, protSSLv3, protTLSv1, protSSLv23
|
||||
|
||||
SslContext* = distinct SslCtx
|
||||
SslContext* = ref object
|
||||
context: SslCtx
|
||||
extraInternalIndex: int
|
||||
referencedData: HashSet[int]
|
||||
|
||||
SslAcceptResult* = enum
|
||||
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
|
||||
@@ -229,9 +232,10 @@ when defineSsl:
|
||||
ErrLoadBioStrings()
|
||||
OpenSSL_add_all_algorithms()
|
||||
|
||||
type SslContextExtraInternal = ref object
|
||||
serverGetPskFunc: SslServerGetPskFunc
|
||||
clientGetPskFunc: SslClientGetPskFunc
|
||||
type
|
||||
SslContextExtraInternal = ref object of RootRef
|
||||
serverGetPskFunc: SslServerGetPskFunc
|
||||
clientGetPskFunc: SslClientGetPskFunc
|
||||
|
||||
proc raiseSSLError*(s = "") =
|
||||
## Raises a new SSL error.
|
||||
@@ -245,21 +249,33 @@ when defineSsl:
|
||||
var errStr = ErrErrorString(err, nil)
|
||||
raise newException(SSLError, $errStr)
|
||||
|
||||
proc getSslContextExtraDataIndex*(): cint =
|
||||
proc getExtraDataIndex*(ctx: SSLContext): int =
|
||||
## Retrieves unique index for storing extra data in SSLContext.
|
||||
return SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil)
|
||||
|
||||
proc setExtraData*(ctx: SSLContext, index: cint, data: pointer) =
|
||||
## Stores arbitrary data inside SSLContext. The unique `index`
|
||||
## should be retrieved using getSslContextExtraDataIndex.
|
||||
if SslCtx(ctx).SSL_CTX_set_ex_data(index, data) == -1:
|
||||
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
|
||||
if result < 0:
|
||||
raiseSSLError()
|
||||
|
||||
proc getExtraData*(ctx: SSLContext, index: cint): pointer =
|
||||
proc getExtraData*(ctx: SSLContext, index: int): RootRef =
|
||||
## Retrieves arbitrary data stored inside SSLContext.
|
||||
return SslCtx(ctx).SSL_CTX_get_ex_data(index)
|
||||
if index notin ctx.referencedData:
|
||||
raise newException(IndexError, "No data with that index.")
|
||||
let res = ctx.context.SSL_CTX_get_ex_data(index.cint)
|
||||
if cast[int](res) == 0:
|
||||
raiseSSLError()
|
||||
return cast[RootRef](res)
|
||||
|
||||
let extraInternalIndex = getSslContextExtraDataIndex()
|
||||
proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
|
||||
## Stores arbitrary data inside SSLContext. The unique `index`
|
||||
## should be retrieved using getSslContextExtraDataIndex.
|
||||
if index in ctx.referencedData:
|
||||
GC_unref(getExtraData(ctx, index))
|
||||
|
||||
if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) == -1:
|
||||
raiseSSLError()
|
||||
|
||||
if index notin ctx.referencedData:
|
||||
ctx.referencedData.incl(index)
|
||||
GC_ref(data)
|
||||
|
||||
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
|
||||
proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
|
||||
@@ -323,26 +339,33 @@ when defineSsl:
|
||||
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
|
||||
newCTX.loadCertificates(certFile, keyFile)
|
||||
|
||||
result = SSLContext(newCTX)
|
||||
result = SSLContext(context: newCTX, extraInternalIndex: 0,
|
||||
referencedData: initSet[int]())
|
||||
result.extraInternalIndex = getExtraDataIndex(result)
|
||||
# The PSK callback functions assume the internal index is 0.
|
||||
assert result.extraInternalIndex == 0
|
||||
|
||||
let extraInternal = new(SslContextExtraInternal)
|
||||
GC_ref(extraInternal)
|
||||
result.setExtraData(extraInternalIndex, cast[pointer](extraInternal))
|
||||
result.setExtraData(result.extraInternalIndex, extraInternal)
|
||||
|
||||
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
|
||||
return cast[SslContextExtraInternal](ctx.getExtraData(extraInternalIndex))
|
||||
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
|
||||
|
||||
proc destroyContext*(ctx: SSLContext) =
|
||||
## Free memory referenced by SSLContext.
|
||||
let extraInternal = ctx.getExtraInternal()
|
||||
if extraInternal != nil:
|
||||
GC_unref(extraInternal)
|
||||
SSLCTX(ctx).SSL_CTX_free()
|
||||
|
||||
# We assume here that OpenSSL's internal indexes increase by 1 each time.
|
||||
# That means we can assume that the next internal index is the length of
|
||||
# extra data indexes.
|
||||
for i in ctx.referencedData:
|
||||
GC_unref(getExtraData(ctx, i).RootRef)
|
||||
ctx.context.SSL_CTX_free()
|
||||
|
||||
proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
|
||||
## Sets the identity hint passed to server.
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
if SSLCTX(ctx).SSL_CTX_use_psk_identity_hint(hint) <= 0:
|
||||
if ctx.context.SSL_CTX_use_psk_identity_hint(hint) <= 0:
|
||||
raiseSSLError()
|
||||
|
||||
proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
|
||||
@@ -350,7 +373,7 @@ when defineSsl:
|
||||
|
||||
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
|
||||
max_psk_len: cuint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
|
||||
let hintString = if hint == nil: nil else: $hint
|
||||
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
|
||||
if psk.len.cuint > max_psk_len:
|
||||
@@ -369,13 +392,14 @@ when defineSsl:
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
ctx.getExtraInternal().clientGetPskFunc = fun
|
||||
SslCtx(ctx).SSL_CTX_set_psk_client_callback(if fun == nil: nil else: pskClientCallback)
|
||||
ctx.context.SSL_CTX_set_psk_client_callback(
|
||||
if fun == nil: nil else: pskClientCallback)
|
||||
|
||||
proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
|
||||
return ctx.getExtraInternal().serverGetPskFunc
|
||||
|
||||
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(ssl.SSL_get_SSL_CTX)
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
|
||||
let pskString = (ctx.serverGetPskFunc)($identity)
|
||||
if psk.len.cint > max_psk_len:
|
||||
return 0
|
||||
@@ -388,7 +412,7 @@ when defineSsl:
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
ctx.getExtraInternal().serverGetPskFunc = fun
|
||||
SslCtx(ctx).SSL_CTX_set_psk_server_callback(if fun == nil: nil
|
||||
ctx.context.SSL_CTX_set_psk_server_callback(if fun == nil: nil
|
||||
else: pskServerCallback)
|
||||
|
||||
proc getPskIdentity*(socket: Socket): string =
|
||||
@@ -409,7 +433,7 @@ when defineSsl:
|
||||
assert (not socket.isSSL)
|
||||
socket.isSSL = true
|
||||
socket.sslContext = ctx
|
||||
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
|
||||
socket.sslHandle = SSLNew(socket.sslContext.context)
|
||||
socket.sslNoHandshake = false
|
||||
socket.sslHasPeekChar = false
|
||||
if socket.sslHandle == nil:
|
||||
|
||||
Reference in New Issue
Block a user