Merge pull request #5069 from yglukhov/ssl-init-fix

Fixed dynlink with OpenSSL >1.1.0. Added loadLibPattern.
This commit is contained in:
Dominik Picheta
2016-11-30 18:44:20 +01:00
committed by GitHub
6 changed files with 109 additions and 51 deletions

View File

@@ -17,6 +17,7 @@ import
lowerings, semparallel
from modulegraphs import ModuleGraph
from dynlib import libCandidates
import strutils except `%` # collides with ropes.`%`

View File

@@ -226,6 +226,8 @@ proc testCompileOptionArg*(switch, arg: string, info: TLineInfo): bool =
of "staticlib": result = contains(gGlobalOptions, optGenStaticLib) and
not contains(gGlobalOptions, optGenGuiApp)
else: localError(info, errGuiConsoleOrLibExpectedButXFound, arg)
of "dynliboverride":
result = isDynlibOverride(arg)
else: invalidCmdLineOption(passCmd1, switch, info)
proc testCompileOption*(switch: string, info: TLineInfo): bool =

View File

@@ -372,17 +372,6 @@ proc findModule*(modulename, currentModule: string): string =
result = findFile(m)
patchModule()
proc libCandidates*(s: string, dest: var seq[string]) =
var le = strutils.find(s, '(')
var ri = strutils.find(s, ')', le+1)
if le >= 0 and ri > le:
var prefix = substr(s, 0, le - 1)
var suffix = substr(s, ri + 1)
for middle in split(substr(s, le + 1, ri - 1), '|'):
libCandidates(prefix & middle & suffix, dest)
else:
add(dest, s)
proc canonDynlibName(s: string): string =
let start = if s.startsWith("lib"): 3 else: 0
let ende = strutils.find(s, {'(', ')', '.'})

View File

@@ -11,20 +11,22 @@
## libraries. On POSIX this uses the ``dlsym`` mechanism, on
## Windows ``LoadLibrary``.
import strutils
type
LibHandle* = pointer ## a handle to a dynamically loaded library
{.deprecated: [TLibHandle: LibHandle].}
proc loadLib*(path: string, global_symbols=false): LibHandle
proc loadLib*(path: string, global_symbols=false): LibHandle {.gcsafe.}
## loads a library from `path`. Returns nil if the library could not
## be loaded.
proc loadLib*(): LibHandle
proc loadLib*(): LibHandle {.gcsafe.}
## gets the handle from the current executable. Returns nil if the
## library could not be loaded.
proc unloadLib*(lib: LibHandle)
proc unloadLib*(lib: LibHandle) {.gcsafe.}
## unloads the library `lib`
proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
@@ -34,7 +36,7 @@ proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
e.msg = "could not find symbol: " & $name
raise e
proc symAddr*(lib: LibHandle, name: cstring): pointer
proc symAddr*(lib: LibHandle, name: cstring): pointer {.gcsafe.}
## retrieves the address of a procedure/variable from `lib`. Returns nil
## if the symbol could not be found.
@@ -44,6 +46,28 @@ proc checkedSymAddr*(lib: LibHandle, name: cstring): pointer =
result = symAddr(lib, name)
if result == nil: raiseInvalidLibrary(name)
proc libCandidates*(s: string, dest: var seq[string]) =
## given a library name pattern `s` write possible library names to `dest`.
var le = strutils.find(s, '(')
var ri = strutils.find(s, ')', le+1)
if le >= 0 and ri > le:
var prefix = substr(s, 0, le - 1)
var suffix = substr(s, ri + 1)
for middle in split(substr(s, le + 1, ri - 1), '|'):
libCandidates(prefix & middle & suffix, dest)
else:
add(dest, s)
proc loadLibPattern*(pattern: string, global_symbols=false): LibHandle =
## loads a library with name matching `pattern`, similar to what `dlimport`
## pragma does. Returns nil if the library could not be loaded.
## Warning: this proc uses the GC and so cannot be used to load the GC.
var candidates = newSeq[string]()
libCandidates(pattern, candidates)
for c in candidates:
result = loadLib(c, global_symbols)
if not result.isNil: break
when defined(posix):
#
# =========================================================================

View File

@@ -90,8 +90,8 @@ when defineSsl:
SslContext* = ref object
context*: SslCtx
extraInternalIndex: int
referencedData: HashSet[int]
extraInternal: SslContextExtraInternal
SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
@@ -103,6 +103,10 @@ when defineSsl:
SslServerGetPskFunc* = proc(identity: string): string
SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
TSSLAcceptResult: SSLAcceptResult].}
@@ -240,11 +244,6 @@ when defineSsl:
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()
type
SslContextExtraInternal = ref object of RootRef
serverGetPskFunc: SslServerGetPskFunc
clientGetPskFunc: SslClientGetPskFunc
proc raiseSSLError*(s = "") =
## Raises a new SSL error.
if s != "":
@@ -257,12 +256,6 @@ when defineSsl:
var errStr = ErrErrorString(err, nil)
raise newException(SSLError, $errStr)
proc getExtraDataIndex*(ctx: SSLContext): int =
## Retrieves unique index for storing extra data in SSLContext.
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
if result < 0:
raiseSSLError()
proc getExtraData*(ctx: SSLContext, index: int): RootRef =
## Retrieves arbitrary data stored inside SSLContext.
if index notin ctx.referencedData:
@@ -347,15 +340,11 @@ when defineSsl:
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
newCTX.loadCertificates(certFile, keyFile)
result = SSLContext(context: newCTX, extraInternalIndex: 0,
referencedData: initSet[int]())
result.extraInternalIndex = getExtraDataIndex(result)
let extraInternal = new(SslContextExtraInternal)
result.setExtraData(result.extraInternalIndex, extraInternal)
result = SSLContext(context: newCTX, referencedData: initSet[int](),
extraInternal: new(SslContextExtraInternal))
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
return ctx.extraInternal
proc destroyContext*(ctx: SSLContext) =
## Free memory referenced by SSLContext.
@@ -379,7 +368,7 @@ when defineSsl:
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
max_psk_len: cuint): cuint {.cdecl.} =
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
let hintString = if hint == nil: nil else: $hint
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
if psk.len.cuint > max_psk_len:
@@ -398,8 +387,6 @@ when defineSsl:
##
## Only used in PSK ciphersuites.
ctx.getExtraInternal().clientGetPskFunc = fun
assert ctx.extraInternalIndex == 0,
"The pskClientCallback assumes the extraInternalIndex is 0"
ctx.context.SSL_CTX_set_psk_client_callback(
if fun == nil: nil else: pskClientCallback)
@@ -407,7 +394,7 @@ when defineSsl:
return ctx.getExtraInternal().serverGetPskFunc
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
let pskString = (ctx.serverGetPskFunc)($identity)
if psk.len.cint > max_psk_len:
return 0

View File

@@ -37,6 +37,8 @@ else:
DLLUtilName = "libcrypto.so" & versions
from posix import SocketHandle
import dynlib
type
SslStruct {.final, pure.} = object
SslPtr* = ptr SslStruct
@@ -185,16 +187,74 @@ const
BIO_C_DO_STATE_MACHINE = 101
BIO_C_GET_SSL = 110
proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc, discardable.}
proc SSL_load_error_strings*(){.cdecl, dynlib: DLLSSLName, importc.}
proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}
proc SSLv23_client_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv23_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv2_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv3_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
when compileOption("dynlibOverride", "ssl"):
proc SSL_library_init*(): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.}
proc SSL_load_error_strings*() {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv23_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv23_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv2_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
proc SSLv3_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
template OpenSSL_add_all_algorithms*() = discard
else:
# Here we're trying to stay compatible with openssl 1.0.* and 1.1.*. Some
# symbols are loaded dynamically and we don't use them if not found.
proc thisModule(): LibHandle {.inline.} =
var thisMod {.global.}: LibHandle
if thisMod.isNil: thisMod = loadLib()
result = thisMod
proc sslModule(): LibHandle {.inline.} =
var sslMod {.global.}: LibHandle
if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName)
result = sslMod
proc sslSym(name: string): pointer =
var dl = thisModule()
if not dl.isNil:
result = symAddr(dl, name)
if result.isNil:
dl = sslModule()
if not dl.isNil:
result = symAddr(dl, name)
proc SSL_library_init*(): cint {.discardable.} =
let theProc = cast[proc(): cint {.cdecl.}](sslSym("SSL_library_init"))
if not theProc.isNil: result = theProc()
proc SSL_load_error_strings*() =
let theProc = cast[proc() {.cdecl.}](sslSym("SSL_load_error_strings"))
if not theProc.isNil: theProc()
proc SSLv23_client_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_client_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()
proc SSLv23_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()
proc SSLv2_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv2_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()
proc SSLv3_method*(): PSSL_METHOD =
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv3_method"))
if not theProc.isNil: result = theProc()
else: result = TLSv1_method()
proc OpenSSL_add_all_algorithms*() =
let theProc = cast[proc() {.cdecl.}](sslSym("OPENSSL_add_all_algorithms_conf"))
if not theProc.isNil: theProc()
proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}
proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
@@ -261,11 +321,6 @@ proc ERR_error_string*(e: cInt, buf: cstring): cstring{.cdecl,
proc ERR_get_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}
proc ERR_peek_last_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}
when defined(android):
template OpenSSL_add_all_algorithms*() = discard
else:
proc OpenSSL_add_all_algorithms*(){.cdecl, dynlib: DLLUtilName, importc: "OPENSSL_add_all_algorithms_conf".}
proc OPENSSL_config*(configName: cstring){.cdecl, dynlib: DLLSSLName, importc.}
when not useWinVersion and not defined(macosx) and not defined(android):