diff --git a/lib/pure/asyncnet.nim b/lib/pure/asyncnet.nim index fb37afa427..76bacb162e 100644 --- a/lib/pure/asyncnet.nim +++ b/lib/pure/asyncnet.nim @@ -126,8 +126,6 @@ type when defineSsl: sslHandle: SslPtr sslContext: SslContext - bioIn: BIO - bioOut: BIO sslNoShutdown: bool domain: Domain sockType: SockType @@ -210,7 +208,7 @@ when defineSsl: proc raiseSslHandleError = raiseSSLError("The SSL Handle is closed/unset") - proc getSslError(socket: AsyncSocket, err: cint): cint = + proc getSslError(socket: AsyncSocket, flags: set[SocketFlag], err: cint): cint = assert socket.isSsl assert err < 0 var ret = SSL_get_error(socket.sslHandle, err.cint) @@ -223,47 +221,49 @@ when defineSsl: return ret of SSL_ERROR_WANT_X509_LOOKUP: raiseSSLError("Function for x509 lookup has been called.") - of SSL_ERROR_SYSCALL, SSL_ERROR_SSL: + of SSL_ERROR_SYSCALL: + socket.sslNoShutdown = true + let osErr = osLastError() + if not flags.isDisconnectionError(osErr): + var errStr = "IO error has occurred" + let sslErr = ERR_peek_last_error() + if sslErr == 0 and err == 0: + errStr.add ' ' + errStr.add "because an EOF was observed that violates the protocol" + elif sslErr == 0 and err == -1: + errStr.add ' ' + errStr.add "in the BIO layer" + else: + let errStr = $ERR_error_string(sslErr, nil) + raiseSSLError(errStr & ": " & errStr) + raiseOSError(osErr, errStr) + else: + return ret + of SSL_ERROR_SSL: socket.sslNoShutdown = true raiseSSLError() else: raiseSSLError("Unknown Error") - proc sendPendingSslData(socket: AsyncSocket, - flags: set[SocketFlag]) {.async.} = - if socket.sslHandle == nil: - raiseSslHandleError() - let len = bioCtrlPending(socket.bioOut) - if len > 0: - var data = newString(len) - let read = bioRead(socket.bioOut, cast[cstring](addr data[0]), len) - assert read != 0 - if read < 0: - raiseSSLError() - data.setLen(read) - await socket.fd.AsyncFD.send(data, flags) - - proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag], - sslError: cint): owned(Future[bool]) {.async.} = + proc handleSslFailure(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint): Future[bool] = ## Returns `true` if `socket` is still connected, otherwise `false`. - result = true + let retFut = newFuture[bool]("asyncnet.handleSslFailure") case sslError - of SSL_ERROR_WANT_WRITE: - await sendPendingSslData(socket, flags) + of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT: + addWrite(socket.fd.AsyncFD, proc (sock: AsyncFD): bool = + retFut.complete(true) + return true + ) of SSL_ERROR_WANT_READ: - var data = await recv(socket.fd.AsyncFD, BufferSize, flags) - if socket.sslHandle == nil: - raiseSslHandleError() - let length = len(data) - if length > 0: - let ret = bioWrite(socket.bioIn, cast[cstring](addr data[0]), length.cint) - if ret < 0: - raiseSSLError() - elif length == 0: - # connection not properly closed by remote side or connection dropped - SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN) - result = false + addRead(socket.fd.AsyncFD, proc (sock: AsyncFD): bool = + retFut.complete(true) + return true + ) + of SSL_ERROR_SYSCALL: + assert flags.isDisconnectionError(osLastError()) + retFut.complete(false) else: - raiseSSLError("Cannot appease SSL.") + raiseSSLError("Cannot handle SSL failure.") + return retFut template sslLoop(socket: AsyncSocket, flags: set[SocketFlag], op: untyped) = @@ -274,20 +274,12 @@ when defineSsl: ErrClearError() # Call the desired operation. opResult = op - let err = - if opResult < 0: - getSslError(socket, opResult.cint) - else: - SSL_ERROR_NONE - # Send any remaining pending SSL data. - await sendPendingSslData(socket, flags) - # If the operation failed, try to see if SSL has some data to read # or write. if opResult < 0: - let fut = appeaseSsl(socket, flags, err.cint) - yield fut - if not fut.read(): + let err = getSslError(socket, flags, opResult.cint) + let connected = await handleSslFailure(socket, flags, err.cint) + if not connected: # Socket disconnected. if SocketFlag.SafeDisconn in flags: opResult = 0.cint @@ -323,8 +315,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} = discard SSL_set_tlsext_host_name(socket.sslHandle, address) let flags = {SocketFlag.SafeDisconn} - sslSetConnectState(socket.sslHandle) - sslLoop(socket, flags, sslDoHandshake(socket.sslHandle)) + sslLoop(socket, flags, SSL_connect(socket.sslHandle)) template readInto(buf: pointer, size: int, socket: AsyncSocket, flags: set[SocketFlag]): int = @@ -461,7 +452,6 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int, when defineSsl: sslLoop(socket, flags, sslWrite(socket.sslHandle, cast[cstring](buf), size.cint)) - await sendPendingSslData(socket, flags) else: await send(socket.fd.AsyncFD, buf, size, flags) @@ -475,52 +465,9 @@ proc send*(socket: AsyncSocket, data: string, var copy = data sslLoop(socket, flags, sslWrite(socket.sslHandle, cast[cstring](addr copy[0]), copy.len.cint)) - await sendPendingSslData(socket, flags) else: await send(socket.fd.AsyncFD, data, flags) -proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}, - inheritable = defined(nimInheritHandles)): - owned(Future[tuple[address: string, client: AsyncSocket]]) = - ## Accepts a new connection. Returns a future containing the client socket - ## corresponding to that connection and the remote address of the client. - ## - ## If `inheritable` is false (the default), the resulting client socket will - ## not be inheritable by child processes. - ## - ## The future will complete when the connection is successfully accepted. - var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr") - var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable) - fut.callback = - proc (future: Future[tuple[address: string, client: AsyncFD]]) = - assert future.finished - if future.failed: - retFuture.fail(future.readError) - else: - let resultTup = (future.read.address, - newAsyncSocket(future.read.client, socket.domain, - socket.sockType, socket.protocol, socket.isBuffered, inheritable)) - retFuture.complete(resultTup) - return retFuture - -proc accept*(socket: AsyncSocket, - flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) = - ## Accepts a new connection. Returns a future containing the client socket - ## corresponding to that connection. - ## If `inheritable` is false (the default), the resulting client socket will - ## not be inheritable by child processes. - ## The future will complete when the connection is successfully accepted. - var retFut = newFuture[AsyncSocket]("asyncnet.accept") - var fut = acceptAddr(socket, flags) - fut.callback = - proc (future: Future[tuple[address: string, client: AsyncSocket]]) = - assert future.finished - if future.failed: - retFut.fail(future.readError) - else: - retFut.complete(future.read.client) - return retFut - proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string], flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} = ## Reads a line of data from `socket` into `resString`. @@ -776,9 +723,8 @@ when defineSsl: if socket.sslHandle == nil: raiseSSLError() - socket.bioIn = bioNew(bioSMem()) - socket.bioOut = bioNew(bioSMem()) - sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut) + if SSL_set_fd(socket.sslHandle, socket.fd) != 1: + raiseSSLError() socket.sslNoShutdown = true @@ -795,6 +741,8 @@ when defineSsl: ## ## **Disclaimer**: This code is not well tested, may be very unsafe and ## prone to security vulnerabilities. + if socket.isSsl: + return wrapSocket(ctx, socket) case handshake @@ -818,6 +766,48 @@ when defineSsl: else: result = getPeerCertificates(socket.sslHandle) +proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn}, + inheritable = defined(nimInheritHandles)): + owned(Future[tuple[address: string, client: AsyncSocket]]) {.async.} = + ## Accepts a new connection. Returns a future containing the client socket + ## corresponding to that connection and the remote address of the client. + ## + ## If `inheritable` is false (the default), the resulting client socket will + ## not be inheritable by child processes. + ## + ## The future will complete when the connection is successfully accepted. + let (address, fd) = await acceptAddr(socket.fd.AsyncFD, flags, inheritable) + let client = newAsyncSocket(fd, socket.domain, socket.sockType, + socket.protocol, socket.isBuffered, inheritable) + result = (address, client) + if socket.isSsl: + when defineSsl: + if socket.sslContext == nil: + raiseSSLError("The SSL Context is closed/unset") + wrapSocket(socket.sslContext, result.client) + if result.client.sslHandle == nil: + raiseSslHandleError() + let flags = {SocketFlag.SafeDisconn} + sslLoop(result.client, flags, SSL_accept(result.client.sslHandle)) + +proc accept*(socket: AsyncSocket, + flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) = + ## Accepts a new connection. Returns a future containing the client socket + ## corresponding to that connection. + ## If `inheritable` is false (the default), the resulting client socket will + ## not be inheritable by child processes. + ## The future will complete when the connection is successfully accepted. + var retFut = newFuture[AsyncSocket]("asyncnet.accept") + var fut = acceptAddr(socket, flags) + fut.callback = + proc (future: Future[tuple[address: string, client: AsyncSocket]]) = + assert future.finished + if future.failed: + retFut.fail(future.readError) + else: + retFut.complete(future.read.client) + return retFut + proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {. tags: [ReadIOEffect].} = ## Retrieves option `opt` as a boolean value. diff --git a/tests/async/t24895.nim b/tests/async/t24895.nim new file mode 100644 index 0000000000..56d0d1268c --- /dev/null +++ b/tests/async/t24895.nim @@ -0,0 +1,79 @@ +discard """ + cmd: "nim $target --hints:on --define:ssl $options $file" +""" + +{.define: ssl.} + +import std/[asyncdispatch, asyncnet, net, openssl] + +var port0: Port +var checked = 0 + +proc server {.async.} = + let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true) + doAssert sock != nil + defer: sock.close() + let sslCtx = newContext( + protSSLv23, + verifyMode = CVerifyNone, + certFile = "tests/testdata/mycert.pem", + keyFile = "tests/testdata/mycert.pem" + ) + doAssert sslCtx != nil + defer: sslCtx.destroyContext() + wrapSocket(sslCtx, sock) + #sock.bindAddr(Port 8181) + sock.bindAddr() + port0 = getLocalAddr(sock)[1] + sock.listen() + echo "accept" + let clientSocket = await sock.accept() + defer: clientSocket.close() + wrapConnectedSocket( + sslCtx, clientSocket, handshakeAsServer, "localhost" + ) + let sdata = "x" & newString(41) + let sfut = clientSocket.send(sdata) + let rdata = newString(42) + let rfut = clientSocket.recvInto(addr rdata[0], rdata.len) + echo "send" + await sfut + echo "recv" + let rLen = await rfut # it hang here until the client closes the connection or sends more data + doAssert rLen == 42, $rLen + doAssert rdata[0] == 'x', $rdata[0] + echo "ok" + inc checked + +proc client {.async.} = + let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true) + doAssert sock != nil + defer: sock.close() + let sslCtx = newContext( + protSSLv23, + verifyMode = CVerifyNone + ) + doAssert sslCtx != nil + defer: sslCtx.destroyContext() + wrapSocket(sslCtx, sock) + #await sock.connect("127.0.0.1", Port 8181) + await sock.connect("localhost", port0) + let sdata = "x" & newString(41) + echo "send" + await sock.send(sdata) + let rdata = newString(42) + echo "recv" + let rLen = await sock.recvInto(addr rdata[0], rdata.len) + doAssert rLen == 42, $rLen + doAssert rdata[0] == 'x', $rdata[0] + #await sleepAsync(10_000) + #await sock.send("x") + echo "ok" + inc checked + +discard getGlobalDispatcher() +let serverFut = server() +waitFor client() +waitFor serverFut +doAssert checked == 2 +doAssert not hasPendingOperations()