mirror of
https://github.com/nim-lang/Nim.git
synced 2026-01-05 20:47:53 +00:00
Merge pull request #5069 from yglukhov/ssl-init-fix
Fixed dynlink with OpenSSL >1.1.0. Added loadLibPattern.
This commit is contained in:
@@ -17,6 +17,7 @@ import
|
||||
lowerings, semparallel
|
||||
|
||||
from modulegraphs import ModuleGraph
|
||||
from dynlib import libCandidates
|
||||
|
||||
import strutils except `%` # collides with ropes.`%`
|
||||
|
||||
|
||||
@@ -226,6 +226,8 @@ proc testCompileOptionArg*(switch, arg: string, info: TLineInfo): bool =
|
||||
of "staticlib": result = contains(gGlobalOptions, optGenStaticLib) and
|
||||
not contains(gGlobalOptions, optGenGuiApp)
|
||||
else: localError(info, errGuiConsoleOrLibExpectedButXFound, arg)
|
||||
of "dynliboverride":
|
||||
result = isDynlibOverride(arg)
|
||||
else: invalidCmdLineOption(passCmd1, switch, info)
|
||||
|
||||
proc testCompileOption*(switch: string, info: TLineInfo): bool =
|
||||
|
||||
@@ -372,17 +372,6 @@ proc findModule*(modulename, currentModule: string): string =
|
||||
result = findFile(m)
|
||||
patchModule()
|
||||
|
||||
proc libCandidates*(s: string, dest: var seq[string]) =
|
||||
var le = strutils.find(s, '(')
|
||||
var ri = strutils.find(s, ')', le+1)
|
||||
if le >= 0 and ri > le:
|
||||
var prefix = substr(s, 0, le - 1)
|
||||
var suffix = substr(s, ri + 1)
|
||||
for middle in split(substr(s, le + 1, ri - 1), '|'):
|
||||
libCandidates(prefix & middle & suffix, dest)
|
||||
else:
|
||||
add(dest, s)
|
||||
|
||||
proc canonDynlibName(s: string): string =
|
||||
let start = if s.startsWith("lib"): 3 else: 0
|
||||
let ende = strutils.find(s, {'(', ')', '.'})
|
||||
|
||||
@@ -11,20 +11,22 @@
|
||||
## libraries. On POSIX this uses the ``dlsym`` mechanism, on
|
||||
## Windows ``LoadLibrary``.
|
||||
|
||||
import strutils
|
||||
|
||||
type
|
||||
LibHandle* = pointer ## a handle to a dynamically loaded library
|
||||
|
||||
{.deprecated: [TLibHandle: LibHandle].}
|
||||
|
||||
proc loadLib*(path: string, global_symbols=false): LibHandle
|
||||
proc loadLib*(path: string, global_symbols=false): LibHandle {.gcsafe.}
|
||||
## loads a library from `path`. Returns nil if the library could not
|
||||
## be loaded.
|
||||
|
||||
proc loadLib*(): LibHandle
|
||||
proc loadLib*(): LibHandle {.gcsafe.}
|
||||
## gets the handle from the current executable. Returns nil if the
|
||||
## library could not be loaded.
|
||||
|
||||
proc unloadLib*(lib: LibHandle)
|
||||
proc unloadLib*(lib: LibHandle) {.gcsafe.}
|
||||
## unloads the library `lib`
|
||||
|
||||
proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
|
||||
@@ -34,7 +36,7 @@ proc raiseInvalidLibrary*(name: cstring) {.noinline, noreturn.} =
|
||||
e.msg = "could not find symbol: " & $name
|
||||
raise e
|
||||
|
||||
proc symAddr*(lib: LibHandle, name: cstring): pointer
|
||||
proc symAddr*(lib: LibHandle, name: cstring): pointer {.gcsafe.}
|
||||
## retrieves the address of a procedure/variable from `lib`. Returns nil
|
||||
## if the symbol could not be found.
|
||||
|
||||
@@ -44,6 +46,28 @@ proc checkedSymAddr*(lib: LibHandle, name: cstring): pointer =
|
||||
result = symAddr(lib, name)
|
||||
if result == nil: raiseInvalidLibrary(name)
|
||||
|
||||
proc libCandidates*(s: string, dest: var seq[string]) =
|
||||
## given a library name pattern `s` write possible library names to `dest`.
|
||||
var le = strutils.find(s, '(')
|
||||
var ri = strutils.find(s, ')', le+1)
|
||||
if le >= 0 and ri > le:
|
||||
var prefix = substr(s, 0, le - 1)
|
||||
var suffix = substr(s, ri + 1)
|
||||
for middle in split(substr(s, le + 1, ri - 1), '|'):
|
||||
libCandidates(prefix & middle & suffix, dest)
|
||||
else:
|
||||
add(dest, s)
|
||||
|
||||
proc loadLibPattern*(pattern: string, global_symbols=false): LibHandle =
|
||||
## loads a library with name matching `pattern`, similar to what `dlimport`
|
||||
## pragma does. Returns nil if the library could not be loaded.
|
||||
## Warning: this proc uses the GC and so cannot be used to load the GC.
|
||||
var candidates = newSeq[string]()
|
||||
libCandidates(pattern, candidates)
|
||||
for c in candidates:
|
||||
result = loadLib(c, global_symbols)
|
||||
if not result.isNil: break
|
||||
|
||||
when defined(posix):
|
||||
#
|
||||
# =========================================================================
|
||||
|
||||
@@ -90,8 +90,8 @@ when defineSsl:
|
||||
|
||||
SslContext* = ref object
|
||||
context*: SslCtx
|
||||
extraInternalIndex: int
|
||||
referencedData: HashSet[int]
|
||||
extraInternal: SslContextExtraInternal
|
||||
|
||||
SslAcceptResult* = enum
|
||||
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
|
||||
@@ -103,6 +103,10 @@ when defineSsl:
|
||||
|
||||
SslServerGetPskFunc* = proc(identity: string): string
|
||||
|
||||
SslContextExtraInternal = ref object of RootRef
|
||||
serverGetPskFunc: SslServerGetPskFunc
|
||||
clientGetPskFunc: SslClientGetPskFunc
|
||||
|
||||
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
|
||||
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
|
||||
TSSLAcceptResult: SSLAcceptResult].}
|
||||
@@ -240,11 +244,6 @@ when defineSsl:
|
||||
ErrLoadBioStrings()
|
||||
OpenSSL_add_all_algorithms()
|
||||
|
||||
type
|
||||
SslContextExtraInternal = ref object of RootRef
|
||||
serverGetPskFunc: SslServerGetPskFunc
|
||||
clientGetPskFunc: SslClientGetPskFunc
|
||||
|
||||
proc raiseSSLError*(s = "") =
|
||||
## Raises a new SSL error.
|
||||
if s != "":
|
||||
@@ -257,12 +256,6 @@ when defineSsl:
|
||||
var errStr = ErrErrorString(err, nil)
|
||||
raise newException(SSLError, $errStr)
|
||||
|
||||
proc getExtraDataIndex*(ctx: SSLContext): int =
|
||||
## Retrieves unique index for storing extra data in SSLContext.
|
||||
result = SSL_CTX_get_ex_new_index(0, nil, nil, nil, nil).int
|
||||
if result < 0:
|
||||
raiseSSLError()
|
||||
|
||||
proc getExtraData*(ctx: SSLContext, index: int): RootRef =
|
||||
## Retrieves arbitrary data stored inside SSLContext.
|
||||
if index notin ctx.referencedData:
|
||||
@@ -347,15 +340,11 @@ when defineSsl:
|
||||
discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
|
||||
newCTX.loadCertificates(certFile, keyFile)
|
||||
|
||||
result = SSLContext(context: newCTX, extraInternalIndex: 0,
|
||||
referencedData: initSet[int]())
|
||||
result.extraInternalIndex = getExtraDataIndex(result)
|
||||
|
||||
let extraInternal = new(SslContextExtraInternal)
|
||||
result.setExtraData(result.extraInternalIndex, extraInternal)
|
||||
result = SSLContext(context: newCTX, referencedData: initSet[int](),
|
||||
extraInternal: new(SslContextExtraInternal))
|
||||
|
||||
proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
|
||||
return SslContextExtraInternal(ctx.getExtraData(ctx.extraInternalIndex))
|
||||
return ctx.extraInternal
|
||||
|
||||
proc destroyContext*(ctx: SSLContext) =
|
||||
## Free memory referenced by SSLContext.
|
||||
@@ -379,7 +368,7 @@ when defineSsl:
|
||||
|
||||
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring; max_identity_len: cuint; psk: ptr cuchar;
|
||||
max_psk_len: cuint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
|
||||
let hintString = if hint == nil: nil else: $hint
|
||||
let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
|
||||
if psk.len.cuint > max_psk_len:
|
||||
@@ -398,8 +387,6 @@ when defineSsl:
|
||||
##
|
||||
## Only used in PSK ciphersuites.
|
||||
ctx.getExtraInternal().clientGetPskFunc = fun
|
||||
assert ctx.extraInternalIndex == 0,
|
||||
"The pskClientCallback assumes the extraInternalIndex is 0"
|
||||
ctx.context.SSL_CTX_set_psk_client_callback(
|
||||
if fun == nil: nil else: pskClientCallback)
|
||||
|
||||
@@ -407,7 +394,7 @@ when defineSsl:
|
||||
return ctx.getExtraInternal().serverGetPskFunc
|
||||
|
||||
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar; max_psk_len: cint): cuint {.cdecl.} =
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX, extraInternalIndex: 0)
|
||||
let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
|
||||
let pskString = (ctx.serverGetPskFunc)($identity)
|
||||
if psk.len.cint > max_psk_len:
|
||||
return 0
|
||||
|
||||
@@ -37,6 +37,8 @@ else:
|
||||
DLLUtilName = "libcrypto.so" & versions
|
||||
from posix import SocketHandle
|
||||
|
||||
import dynlib
|
||||
|
||||
type
|
||||
SslStruct {.final, pure.} = object
|
||||
SslPtr* = ptr SslStruct
|
||||
@@ -185,16 +187,74 @@ const
|
||||
BIO_C_DO_STATE_MACHINE = 101
|
||||
BIO_C_GET_SSL = 110
|
||||
|
||||
proc SSL_library_init*(): cInt{.cdecl, dynlib: DLLSSLName, importc, discardable.}
|
||||
proc SSL_load_error_strings*(){.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}
|
||||
|
||||
proc SSLv23_client_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv23_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv2_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv3_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc TLSv1_method*(): PSSL_METHOD{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
when compileOption("dynlibOverride", "ssl"):
|
||||
proc SSL_library_init*(): cint {.cdecl, dynlib: DLLSSLName, importc, discardable.}
|
||||
proc SSL_load_error_strings*() {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv23_client_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
proc SSLv23_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv2_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSLv3_method*(): PSSL_METHOD {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
template OpenSSL_add_all_algorithms*() = discard
|
||||
else:
|
||||
# Here we're trying to stay compatible with openssl 1.0.* and 1.1.*. Some
|
||||
# symbols are loaded dynamically and we don't use them if not found.
|
||||
proc thisModule(): LibHandle {.inline.} =
|
||||
var thisMod {.global.}: LibHandle
|
||||
if thisMod.isNil: thisMod = loadLib()
|
||||
result = thisMod
|
||||
|
||||
proc sslModule(): LibHandle {.inline.} =
|
||||
var sslMod {.global.}: LibHandle
|
||||
if sslMod.isNil: sslMod = loadLibPattern(DLLSSLName)
|
||||
result = sslMod
|
||||
|
||||
proc sslSym(name: string): pointer =
|
||||
var dl = thisModule()
|
||||
if not dl.isNil:
|
||||
result = symAddr(dl, name)
|
||||
if result.isNil:
|
||||
dl = sslModule()
|
||||
if not dl.isNil:
|
||||
result = symAddr(dl, name)
|
||||
|
||||
proc SSL_library_init*(): cint {.discardable.} =
|
||||
let theProc = cast[proc(): cint {.cdecl.}](sslSym("SSL_library_init"))
|
||||
if not theProc.isNil: result = theProc()
|
||||
|
||||
proc SSL_load_error_strings*() =
|
||||
let theProc = cast[proc() {.cdecl.}](sslSym("SSL_load_error_strings"))
|
||||
if not theProc.isNil: theProc()
|
||||
|
||||
proc SSLv23_client_method*(): PSSL_METHOD =
|
||||
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_client_method"))
|
||||
if not theProc.isNil: result = theProc()
|
||||
else: result = TLSv1_method()
|
||||
|
||||
proc SSLv23_method*(): PSSL_METHOD =
|
||||
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv23_method"))
|
||||
if not theProc.isNil: result = theProc()
|
||||
else: result = TLSv1_method()
|
||||
|
||||
proc SSLv2_method*(): PSSL_METHOD =
|
||||
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv2_method"))
|
||||
if not theProc.isNil: result = theProc()
|
||||
else: result = TLSv1_method()
|
||||
|
||||
proc SSLv3_method*(): PSSL_METHOD =
|
||||
let theProc = cast[proc(): PSSL_METHOD {.cdecl, gcsafe.}](sslSym("SSLv3_method"))
|
||||
if not theProc.isNil: result = theProc()
|
||||
else: result = TLSv1_method()
|
||||
|
||||
proc OpenSSL_add_all_algorithms*() =
|
||||
let theProc = cast[proc() {.cdecl.}](sslSym("OPENSSL_add_all_algorithms_conf"))
|
||||
if not theProc.isNil: theProc()
|
||||
|
||||
proc ERR_load_BIO_strings*(){.cdecl, dynlib: DLLUtilName, importc.}
|
||||
|
||||
proc SSL_new*(context: SslCtx): SslPtr{.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_free*(ssl: SslPtr){.cdecl, dynlib: DLLSSLName, importc.}
|
||||
proc SSL_get_SSL_CTX*(ssl: SslPtr): SslCtx {.cdecl, dynlib: DLLSSLName, importc.}
|
||||
@@ -261,11 +321,6 @@ proc ERR_error_string*(e: cInt, buf: cstring): cstring{.cdecl,
|
||||
proc ERR_get_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}
|
||||
proc ERR_peek_last_error*(): cInt{.cdecl, dynlib: DLLUtilName, importc.}
|
||||
|
||||
when defined(android):
|
||||
template OpenSSL_add_all_algorithms*() = discard
|
||||
else:
|
||||
proc OpenSSL_add_all_algorithms*(){.cdecl, dynlib: DLLUtilName, importc: "OPENSSL_add_all_algorithms_conf".}
|
||||
|
||||
proc OPENSSL_config*(configName: cstring){.cdecl, dynlib: DLLSSLName, importc.}
|
||||
|
||||
when not useWinVersion and not defined(macosx) and not defined(android):
|
||||
|
||||
Reference in New Issue
Block a user