mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-28 17:04:41 +00:00
asyncnet ssl overhaul (#24896)
Fixes #24895 - Remove all bio handling - Remove all `sendPendingSslData` which only seems to make things work by chance - Wrap the client socket on `acceptAddr` (std/net does this) - Do the SSL handshake on accept (std/net does this) The only concern is if addWrite/addRead works well on Windows.
This commit is contained in:
committed by
GitHub
parent
d7b1f0a99a
commit
8518cf079f
@@ -126,8 +126,6 @@ type
|
||||
when defineSsl:
|
||||
sslHandle: SslPtr
|
||||
sslContext: SslContext
|
||||
bioIn: BIO
|
||||
bioOut: BIO
|
||||
sslNoShutdown: bool
|
||||
domain: Domain
|
||||
sockType: SockType
|
||||
@@ -210,7 +208,7 @@ when defineSsl:
|
||||
proc raiseSslHandleError =
|
||||
raiseSSLError("The SSL Handle is closed/unset")
|
||||
|
||||
proc getSslError(socket: AsyncSocket, err: cint): cint =
|
||||
proc getSslError(socket: AsyncSocket, flags: set[SocketFlag], err: cint): cint =
|
||||
assert socket.isSsl
|
||||
assert err < 0
|
||||
var ret = SSL_get_error(socket.sslHandle, err.cint)
|
||||
@@ -223,47 +221,49 @@ when defineSsl:
|
||||
return ret
|
||||
of SSL_ERROR_WANT_X509_LOOKUP:
|
||||
raiseSSLError("Function for x509 lookup has been called.")
|
||||
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
|
||||
of SSL_ERROR_SYSCALL:
|
||||
socket.sslNoShutdown = true
|
||||
let osErr = osLastError()
|
||||
if not flags.isDisconnectionError(osErr):
|
||||
var errStr = "IO error has occurred"
|
||||
let sslErr = ERR_peek_last_error()
|
||||
if sslErr == 0 and err == 0:
|
||||
errStr.add ' '
|
||||
errStr.add "because an EOF was observed that violates the protocol"
|
||||
elif sslErr == 0 and err == -1:
|
||||
errStr.add ' '
|
||||
errStr.add "in the BIO layer"
|
||||
else:
|
||||
let errStr = $ERR_error_string(sslErr, nil)
|
||||
raiseSSLError(errStr & ": " & errStr)
|
||||
raiseOSError(osErr, errStr)
|
||||
else:
|
||||
return ret
|
||||
of SSL_ERROR_SSL:
|
||||
socket.sslNoShutdown = true
|
||||
raiseSSLError()
|
||||
else: raiseSSLError("Unknown Error")
|
||||
|
||||
proc sendPendingSslData(socket: AsyncSocket,
|
||||
flags: set[SocketFlag]) {.async.} =
|
||||
if socket.sslHandle == nil:
|
||||
raiseSslHandleError()
|
||||
let len = bioCtrlPending(socket.bioOut)
|
||||
if len > 0:
|
||||
var data = newString(len)
|
||||
let read = bioRead(socket.bioOut, cast[cstring](addr data[0]), len)
|
||||
assert read != 0
|
||||
if read < 0:
|
||||
raiseSSLError()
|
||||
data.setLen(read)
|
||||
await socket.fd.AsyncFD.send(data, flags)
|
||||
|
||||
proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
|
||||
sslError: cint): owned(Future[bool]) {.async.} =
|
||||
proc handleSslFailure(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint): Future[bool] =
|
||||
## Returns `true` if `socket` is still connected, otherwise `false`.
|
||||
result = true
|
||||
let retFut = newFuture[bool]("asyncnet.handleSslFailure")
|
||||
case sslError
|
||||
of SSL_ERROR_WANT_WRITE:
|
||||
await sendPendingSslData(socket, flags)
|
||||
of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
|
||||
addWrite(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
|
||||
retFut.complete(true)
|
||||
return true
|
||||
)
|
||||
of SSL_ERROR_WANT_READ:
|
||||
var data = await recv(socket.fd.AsyncFD, BufferSize, flags)
|
||||
if socket.sslHandle == nil:
|
||||
raiseSslHandleError()
|
||||
let length = len(data)
|
||||
if length > 0:
|
||||
let ret = bioWrite(socket.bioIn, cast[cstring](addr data[0]), length.cint)
|
||||
if ret < 0:
|
||||
raiseSSLError()
|
||||
elif length == 0:
|
||||
# connection not properly closed by remote side or connection dropped
|
||||
SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN)
|
||||
result = false
|
||||
addRead(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
|
||||
retFut.complete(true)
|
||||
return true
|
||||
)
|
||||
of SSL_ERROR_SYSCALL:
|
||||
assert flags.isDisconnectionError(osLastError())
|
||||
retFut.complete(false)
|
||||
else:
|
||||
raiseSSLError("Cannot appease SSL.")
|
||||
raiseSSLError("Cannot handle SSL failure.")
|
||||
return retFut
|
||||
|
||||
template sslLoop(socket: AsyncSocket, flags: set[SocketFlag],
|
||||
op: untyped) =
|
||||
@@ -274,20 +274,12 @@ when defineSsl:
|
||||
ErrClearError()
|
||||
# Call the desired operation.
|
||||
opResult = op
|
||||
let err =
|
||||
if opResult < 0:
|
||||
getSslError(socket, opResult.cint)
|
||||
else:
|
||||
SSL_ERROR_NONE
|
||||
# Send any remaining pending SSL data.
|
||||
await sendPendingSslData(socket, flags)
|
||||
|
||||
# If the operation failed, try to see if SSL has some data to read
|
||||
# or write.
|
||||
if opResult < 0:
|
||||
let fut = appeaseSsl(socket, flags, err.cint)
|
||||
yield fut
|
||||
if not fut.read():
|
||||
let err = getSslError(socket, flags, opResult.cint)
|
||||
let connected = await handleSslFailure(socket, flags, err.cint)
|
||||
if not connected:
|
||||
# Socket disconnected.
|
||||
if SocketFlag.SafeDisconn in flags:
|
||||
opResult = 0.cint
|
||||
@@ -323,8 +315,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
|
||||
discard SSL_set_tlsext_host_name(socket.sslHandle, address)
|
||||
|
||||
let flags = {SocketFlag.SafeDisconn}
|
||||
sslSetConnectState(socket.sslHandle)
|
||||
sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
|
||||
sslLoop(socket, flags, SSL_connect(socket.sslHandle))
|
||||
|
||||
template readInto(buf: pointer, size: int, socket: AsyncSocket,
|
||||
flags: set[SocketFlag]): int =
|
||||
@@ -461,7 +452,6 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
|
||||
when defineSsl:
|
||||
sslLoop(socket, flags,
|
||||
sslWrite(socket.sslHandle, cast[cstring](buf), size.cint))
|
||||
await sendPendingSslData(socket, flags)
|
||||
else:
|
||||
await send(socket.fd.AsyncFD, buf, size, flags)
|
||||
|
||||
@@ -475,52 +465,9 @@ proc send*(socket: AsyncSocket, data: string,
|
||||
var copy = data
|
||||
sslLoop(socket, flags,
|
||||
sslWrite(socket.sslHandle, cast[cstring](addr copy[0]), copy.len.cint))
|
||||
await sendPendingSslData(socket, flags)
|
||||
else:
|
||||
await send(socket.fd.AsyncFD, data, flags)
|
||||
|
||||
proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
|
||||
inheritable = defined(nimInheritHandles)):
|
||||
owned(Future[tuple[address: string, client: AsyncSocket]]) =
|
||||
## Accepts a new connection. Returns a future containing the client socket
|
||||
## corresponding to that connection and the remote address of the client.
|
||||
##
|
||||
## If `inheritable` is false (the default), the resulting client socket will
|
||||
## not be inheritable by child processes.
|
||||
##
|
||||
## The future will complete when the connection is successfully accepted.
|
||||
var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
|
||||
var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
|
||||
fut.callback =
|
||||
proc (future: Future[tuple[address: string, client: AsyncFD]]) =
|
||||
assert future.finished
|
||||
if future.failed:
|
||||
retFuture.fail(future.readError)
|
||||
else:
|
||||
let resultTup = (future.read.address,
|
||||
newAsyncSocket(future.read.client, socket.domain,
|
||||
socket.sockType, socket.protocol, socket.isBuffered, inheritable))
|
||||
retFuture.complete(resultTup)
|
||||
return retFuture
|
||||
|
||||
proc accept*(socket: AsyncSocket,
|
||||
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
|
||||
## Accepts a new connection. Returns a future containing the client socket
|
||||
## corresponding to that connection.
|
||||
## If `inheritable` is false (the default), the resulting client socket will
|
||||
## not be inheritable by child processes.
|
||||
## The future will complete when the connection is successfully accepted.
|
||||
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
|
||||
var fut = acceptAddr(socket, flags)
|
||||
fut.callback =
|
||||
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
|
||||
assert future.finished
|
||||
if future.failed:
|
||||
retFut.fail(future.readError)
|
||||
else:
|
||||
retFut.complete(future.read.client)
|
||||
return retFut
|
||||
|
||||
proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
|
||||
flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} =
|
||||
## Reads a line of data from `socket` into `resString`.
|
||||
@@ -776,9 +723,8 @@ when defineSsl:
|
||||
if socket.sslHandle == nil:
|
||||
raiseSSLError()
|
||||
|
||||
socket.bioIn = bioNew(bioSMem())
|
||||
socket.bioOut = bioNew(bioSMem())
|
||||
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
|
||||
if SSL_set_fd(socket.sslHandle, socket.fd) != 1:
|
||||
raiseSSLError()
|
||||
|
||||
socket.sslNoShutdown = true
|
||||
|
||||
@@ -795,6 +741,8 @@ when defineSsl:
|
||||
##
|
||||
## **Disclaimer**: This code is not well tested, may be very unsafe and
|
||||
## prone to security vulnerabilities.
|
||||
if socket.isSsl:
|
||||
return
|
||||
wrapSocket(ctx, socket)
|
||||
|
||||
case handshake
|
||||
@@ -818,6 +766,48 @@ when defineSsl:
|
||||
else:
|
||||
result = getPeerCertificates(socket.sslHandle)
|
||||
|
||||
proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
|
||||
inheritable = defined(nimInheritHandles)):
|
||||
owned(Future[tuple[address: string, client: AsyncSocket]]) {.async.} =
|
||||
## Accepts a new connection. Returns a future containing the client socket
|
||||
## corresponding to that connection and the remote address of the client.
|
||||
##
|
||||
## If `inheritable` is false (the default), the resulting client socket will
|
||||
## not be inheritable by child processes.
|
||||
##
|
||||
## The future will complete when the connection is successfully accepted.
|
||||
let (address, fd) = await acceptAddr(socket.fd.AsyncFD, flags, inheritable)
|
||||
let client = newAsyncSocket(fd, socket.domain, socket.sockType,
|
||||
socket.protocol, socket.isBuffered, inheritable)
|
||||
result = (address, client)
|
||||
if socket.isSsl:
|
||||
when defineSsl:
|
||||
if socket.sslContext == nil:
|
||||
raiseSSLError("The SSL Context is closed/unset")
|
||||
wrapSocket(socket.sslContext, result.client)
|
||||
if result.client.sslHandle == nil:
|
||||
raiseSslHandleError()
|
||||
let flags = {SocketFlag.SafeDisconn}
|
||||
sslLoop(result.client, flags, SSL_accept(result.client.sslHandle))
|
||||
|
||||
proc accept*(socket: AsyncSocket,
|
||||
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
|
||||
## Accepts a new connection. Returns a future containing the client socket
|
||||
## corresponding to that connection.
|
||||
## If `inheritable` is false (the default), the resulting client socket will
|
||||
## not be inheritable by child processes.
|
||||
## The future will complete when the connection is successfully accepted.
|
||||
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
|
||||
var fut = acceptAddr(socket, flags)
|
||||
fut.callback =
|
||||
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
|
||||
assert future.finished
|
||||
if future.failed:
|
||||
retFut.fail(future.readError)
|
||||
else:
|
||||
retFut.complete(future.read.client)
|
||||
return retFut
|
||||
|
||||
proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
|
||||
tags: [ReadIOEffect].} =
|
||||
## Retrieves option `opt` as a boolean value.
|
||||
|
||||
79
tests/async/t24895.nim
Normal file
79
tests/async/t24895.nim
Normal file
@@ -0,0 +1,79 @@
|
||||
discard """
|
||||
cmd: "nim $target --hints:on --define:ssl $options $file"
|
||||
"""
|
||||
|
||||
{.define: ssl.}
|
||||
|
||||
import std/[asyncdispatch, asyncnet, net, openssl]
|
||||
|
||||
var port0: Port
|
||||
var checked = 0
|
||||
|
||||
proc server {.async.} =
|
||||
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
|
||||
doAssert sock != nil
|
||||
defer: sock.close()
|
||||
let sslCtx = newContext(
|
||||
protSSLv23,
|
||||
verifyMode = CVerifyNone,
|
||||
certFile = "tests/testdata/mycert.pem",
|
||||
keyFile = "tests/testdata/mycert.pem"
|
||||
)
|
||||
doAssert sslCtx != nil
|
||||
defer: sslCtx.destroyContext()
|
||||
wrapSocket(sslCtx, sock)
|
||||
#sock.bindAddr(Port 8181)
|
||||
sock.bindAddr()
|
||||
port0 = getLocalAddr(sock)[1]
|
||||
sock.listen()
|
||||
echo "accept"
|
||||
let clientSocket = await sock.accept()
|
||||
defer: clientSocket.close()
|
||||
wrapConnectedSocket(
|
||||
sslCtx, clientSocket, handshakeAsServer, "localhost"
|
||||
)
|
||||
let sdata = "x" & newString(41)
|
||||
let sfut = clientSocket.send(sdata)
|
||||
let rdata = newString(42)
|
||||
let rfut = clientSocket.recvInto(addr rdata[0], rdata.len)
|
||||
echo "send"
|
||||
await sfut
|
||||
echo "recv"
|
||||
let rLen = await rfut # it hang here until the client closes the connection or sends more data
|
||||
doAssert rLen == 42, $rLen
|
||||
doAssert rdata[0] == 'x', $rdata[0]
|
||||
echo "ok"
|
||||
inc checked
|
||||
|
||||
proc client {.async.} =
|
||||
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
|
||||
doAssert sock != nil
|
||||
defer: sock.close()
|
||||
let sslCtx = newContext(
|
||||
protSSLv23,
|
||||
verifyMode = CVerifyNone
|
||||
)
|
||||
doAssert sslCtx != nil
|
||||
defer: sslCtx.destroyContext()
|
||||
wrapSocket(sslCtx, sock)
|
||||
#await sock.connect("127.0.0.1", Port 8181)
|
||||
await sock.connect("localhost", port0)
|
||||
let sdata = "x" & newString(41)
|
||||
echo "send"
|
||||
await sock.send(sdata)
|
||||
let rdata = newString(42)
|
||||
echo "recv"
|
||||
let rLen = await sock.recvInto(addr rdata[0], rdata.len)
|
||||
doAssert rLen == 42, $rLen
|
||||
doAssert rdata[0] == 'x', $rdata[0]
|
||||
#await sleepAsync(10_000)
|
||||
#await sock.send("x")
|
||||
echo "ok"
|
||||
inc checked
|
||||
|
||||
discard getGlobalDispatcher()
|
||||
let serverFut = server()
|
||||
waitFor client()
|
||||
waitFor serverFut
|
||||
doAssert checked == 2
|
||||
doAssert not hasPendingOperations()
|
||||
Reference in New Issue
Block a user