mirror of
https://github.com/nim-lang/Nim.git
synced 2026-04-20 14:25:23 +00:00
asyncnet, net: don't attempt SSL_shutdown if a fatal error occurred (#15066)
* asyncnet, net: don't attempt SSL_shutdown if a fatal error occurred Per TLS standard and SSL_shutdown(3ssl). This should prevent errors coming from a close() after a bad event (ie. the other end of the pipe is closed before shutdown can be negotiated). Ref #9867 * tssl: try sending until an error occur * tssl: cleanup * tssl: actually run the test I forgot to make the test run :P * tssl: run the test on ARC, maybe then it'll be happy * tssl: turns off ARC, switch tlsEmulation on for freebsd * tssl: document why tlsEmulation is employed * net: move SafeDisconn handling logic to socketError
This commit is contained in:
@@ -123,6 +123,7 @@ type
|
||||
sslContext: SslContext
|
||||
bioIn: BIO
|
||||
bioOut: BIO
|
||||
sslNoShutdown: bool
|
||||
domain: Domain
|
||||
sockType: SockType
|
||||
protocol: Protocol
|
||||
@@ -200,9 +201,10 @@ proc newAsyncSocket*(domain, sockType, protocol: cint,
|
||||
Protocol(protocol), buffered, inheritable)
|
||||
|
||||
when defineSsl:
|
||||
proc getSslError(handle: SslPtr, err: cint): cint =
|
||||
proc getSslError(socket: AsyncSocket, err: cint): cint =
|
||||
assert socket.isSsl
|
||||
assert err < 0
|
||||
var ret = SSL_get_error(handle, err.cint)
|
||||
var ret = SSL_get_error(socket.sslHandle, err.cint)
|
||||
case ret
|
||||
of SSL_ERROR_ZERO_RETURN:
|
||||
raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
|
||||
@@ -213,6 +215,7 @@ when defineSsl:
|
||||
of SSL_ERROR_WANT_X509_LOOKUP:
|
||||
raiseSSLError("Function for x509 lookup has been called.")
|
||||
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
|
||||
socket.sslNoShutdown = true
|
||||
raiseSSLError()
|
||||
else: raiseSSLError("Unknown Error")
|
||||
|
||||
@@ -265,7 +268,7 @@ when defineSsl:
|
||||
# If the operation failed, try to see if SSL has some data to read
|
||||
# or write.
|
||||
if opResult < 0:
|
||||
let err = getSslError(socket.sslHandle, opResult.cint)
|
||||
let err = getSslError(socket, opResult.cint)
|
||||
let fut = appeaseSsl(socket, flags, err.cint)
|
||||
yield fut
|
||||
if not fut.read():
|
||||
@@ -718,7 +721,7 @@ proc close*(socket: AsyncSocket) =
|
||||
# Don't call SSL_shutdown if the connection has not been fully
|
||||
# established, see:
|
||||
# https://github.com/openssl/openssl/issues/710#issuecomment-253897666
|
||||
if SSL_in_init(socket.sslHandle) == 0:
|
||||
if not socket.sslNoShutdown and SSL_in_init(socket.sslHandle) == 0:
|
||||
ErrClearError()
|
||||
SSL_shutdown(socket.sslHandle)
|
||||
else:
|
||||
@@ -747,6 +750,8 @@ when defineSsl:
|
||||
socket.bioOut = bioNew(bioSMem())
|
||||
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
|
||||
|
||||
socket.sslNoShutdown = true
|
||||
|
||||
proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket,
|
||||
handshake: SslHandshakeType,
|
||||
hostname: string = "") =
|
||||
|
||||
@@ -133,6 +133,7 @@ type
|
||||
sslNoHandshake: bool # True if needs handshake.
|
||||
sslHasPeekChar: bool
|
||||
sslPeekChar: char
|
||||
sslNoShutdown: bool # True if shutdown shouldn't be done.
|
||||
lastError: OSErrorCode ## stores the last error on this socket
|
||||
domain: Domain
|
||||
sockType: SockType
|
||||
@@ -173,7 +174,8 @@ when defined(nimHasStyleChecks):
|
||||
{.pop.}
|
||||
|
||||
proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
lastError = (-1).OSErrorCode): void {.gcsafe.}
|
||||
lastError = (-1).OSErrorCode,
|
||||
flags: set[SocketFlag] = {}): void {.gcsafe.}
|
||||
|
||||
proc isDisconnectionError*(flags: set[SocketFlag],
|
||||
lastError: OSErrorCode): bool =
|
||||
@@ -722,6 +724,7 @@ when defineSsl:
|
||||
socket.sslHandle = SSL_new(socket.sslContext.context)
|
||||
socket.sslNoHandshake = false
|
||||
socket.sslHasPeekChar = false
|
||||
socket.sslNoShutdown = false
|
||||
if socket.sslHandle == nil:
|
||||
raiseSSLError()
|
||||
|
||||
@@ -818,7 +821,8 @@ proc getSocketError*(socket: Socket): OSErrorCode =
|
||||
raiseOSError(result, "No valid socket error code available")
|
||||
|
||||
proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
lastError = (-1).OSErrorCode) =
|
||||
lastError = (-1).OSErrorCode,
|
||||
flags: set[SocketFlag] = {}) =
|
||||
## Raises an OSError based on the error code returned by ``SSL_get_error``
|
||||
## (for SSL sockets) and ``osLastError`` otherwise.
|
||||
##
|
||||
@@ -826,6 +830,9 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
## error was caused by no data being available to be read.
|
||||
##
|
||||
## If ``err`` is not lower than 0 no exception will be raised.
|
||||
##
|
||||
## If ``flags`` contains ``SafeDisconn``, no exception will be raised
|
||||
## when the error was caused by a peer disconnection.
|
||||
when defineSsl:
|
||||
if socket.isSsl:
|
||||
if err <= 0:
|
||||
@@ -844,33 +851,39 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
of SSL_ERROR_WANT_X509_LOOKUP:
|
||||
raiseSSLError("Function for x509 lookup has been called.")
|
||||
of SSL_ERROR_SYSCALL:
|
||||
var errStr = "IO error has occurred "
|
||||
let sslErr = ERR_peek_last_error()
|
||||
if sslErr == 0 and err == 0:
|
||||
errStr.add "because an EOF was observed that violates the protocol"
|
||||
elif sslErr == 0 and err == -1:
|
||||
errStr.add "in the BIO layer"
|
||||
else:
|
||||
let errStr = $ERR_error_string(sslErr, nil)
|
||||
raiseSSLError(errStr & ": " & errStr)
|
||||
# SSL shutdown must not be done if a fatal error occurred.
|
||||
socket.sslNoShutdown = true
|
||||
let osErr = osLastError()
|
||||
raiseOSError(osErr, errStr)
|
||||
if not flags.isDisconnectionError(osErr):
|
||||
var errStr = "IO error has occurred "
|
||||
let sslErr = ERR_peek_last_error()
|
||||
if sslErr == 0 and err == 0:
|
||||
errStr.add "because an EOF was observed that violates the protocol"
|
||||
elif sslErr == 0 and err == -1:
|
||||
errStr.add "in the BIO layer"
|
||||
else:
|
||||
let errStr = $ERR_error_string(sslErr, nil)
|
||||
raiseSSLError(errStr & ": " & errStr)
|
||||
raiseOSError(osErr, errStr)
|
||||
of SSL_ERROR_SSL:
|
||||
# SSL shutdown must not be done if a fatal error occurred.
|
||||
socket.sslNoShutdown = true
|
||||
raiseSSLError()
|
||||
else: raiseSSLError("Unknown Error")
|
||||
|
||||
if err == -1 and not (when defineSsl: socket.isSsl else: false):
|
||||
var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
|
||||
if async:
|
||||
when useWinVersion:
|
||||
if lastE.int32 == WSAEWOULDBLOCK:
|
||||
return
|
||||
else: raiseOSError(lastE)
|
||||
else:
|
||||
if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
|
||||
return
|
||||
else: raiseOSError(lastE)
|
||||
else: raiseOSError(lastE)
|
||||
if not flags.isDisconnectionError(lastE):
|
||||
if async:
|
||||
when useWinVersion:
|
||||
if lastE.int32 == WSAEWOULDBLOCK:
|
||||
return
|
||||
else: raiseOSError(lastE)
|
||||
else:
|
||||
if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
|
||||
return
|
||||
else: raiseOSError(lastE)
|
||||
else: raiseOSError(lastE)
|
||||
|
||||
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
|
||||
## Marks ``socket`` as accepting connections.
|
||||
@@ -1026,7 +1039,7 @@ proc close*(socket: Socket) =
|
||||
# Don't call SSL_shutdown if the connection has not been fully
|
||||
# established, see:
|
||||
# https://github.com/openssl/openssl/issues/710#issuecomment-253897666
|
||||
if SSL_in_init(socket.sslHandle) == 0:
|
||||
if not socket.sslNoShutdown and SSL_in_init(socket.sslHandle) == 0:
|
||||
# As we are closing the underlying socket immediately afterwards,
|
||||
# it is valid, under the TLS standard, to perform a unidirectional
|
||||
# shutdown i.e not wait for the peers "close notify" alert with a second
|
||||
@@ -1312,9 +1325,9 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
|
||||
if result < 0:
|
||||
data.setLen(0)
|
||||
let lastError = getSocketError(socket)
|
||||
if flags.isDisconnectionError(lastError): return
|
||||
socket.socketError(result, lastError = lastError)
|
||||
data.setLen(result)
|
||||
socket.socketError(result, lastError = lastError, flags = flags)
|
||||
else:
|
||||
data.setLen(result)
|
||||
|
||||
proc recv*(socket: Socket, size: int, timeout = -1,
|
||||
flags = {SocketFlag.SafeDisconn}): string {.inline.} =
|
||||
@@ -1388,8 +1401,9 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
|
||||
|
||||
template raiseSockError() {.dirty.} =
|
||||
let lastError = getSocketError(socket)
|
||||
if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
|
||||
socket.socketError(n, lastError = lastError)
|
||||
if flags.isDisconnectionError(lastError):
|
||||
setLen(line.string, 0)
|
||||
socket.socketError(n, lastError = lastError, flags = flags)
|
||||
|
||||
var waited: Duration
|
||||
|
||||
@@ -1520,8 +1534,7 @@ proc send*(socket: Socket, data: string,
|
||||
let sent = send(socket, cstring(data), data.len)
|
||||
if sent < 0:
|
||||
let lastError = osLastError()
|
||||
if flags.isDisconnectionError(lastError): return
|
||||
socketError(socket, lastError = lastError)
|
||||
socketError(socket, lastError = lastError, flags = flags)
|
||||
|
||||
if sent != data.len:
|
||||
raiseOSError(osLastError(), "Could not send all data.")
|
||||
|
||||
63
tests/stdlib/tssl.nim
Normal file
63
tests/stdlib/tssl.nim
Normal file
@@ -0,0 +1,63 @@
|
||||
discard """
|
||||
joinable: false
|
||||
"""
|
||||
|
||||
import net, nativesockets
|
||||
|
||||
when defined(posix): import os, posix
|
||||
|
||||
when not defined(ssl):
|
||||
{.error: "This test must be compiled with -d:ssl".}
|
||||
|
||||
const DummyData = "dummy data\n"
|
||||
|
||||
proc connector(port: Port) {.thread.} =
|
||||
let clientContext = newContext(verifyMode = CVerifyNone)
|
||||
var client = newSocket(buffered = false)
|
||||
clientContext.wrapSocket(client)
|
||||
client.connect("localhost", port)
|
||||
|
||||
discard client.recvLine()
|
||||
client.getFd.close()
|
||||
|
||||
proc main() =
|
||||
let serverContext = newContext(verifyMode = CVerifyNone,
|
||||
certFile = "tests/testdata/mycert.pem",
|
||||
keyFile = "tests/testdata/mycert.pem")
|
||||
|
||||
when defined(posix):
|
||||
var
|
||||
ignoreAction = SigAction(sa_handler: SIG_IGN)
|
||||
oldSigPipeHandler: SigAction
|
||||
if sigemptyset(ignoreAction.sa_mask) == -1:
|
||||
raiseOSError(osLastError(), "Couldn't create an empty signal set")
|
||||
if sigaction(SIGPIPE, ignoreAction, oldSigPipeHandler) == -1:
|
||||
raiseOSError(osLastError(), "Couldn't ignore SIGPIPE")
|
||||
|
||||
block peer_close_without_shutdown:
|
||||
var server = newSocket(buffered = false)
|
||||
defer: server.close()
|
||||
serverContext.wrapSocket(server)
|
||||
server.bindAddr(address = "localhost")
|
||||
let (_, port) = server.getLocalAddr()
|
||||
server.listen()
|
||||
|
||||
var clientThread: Thread[Port]
|
||||
createThread(clientThread, connector, port)
|
||||
|
||||
var peer: Socket
|
||||
try:
|
||||
server.accept(peer)
|
||||
peer.send(DummyData)
|
||||
|
||||
joinThread clientThread
|
||||
|
||||
while true:
|
||||
# Send data until we get EPIPE.
|
||||
peer.send(DummyData, {})
|
||||
except OSError:
|
||||
discard
|
||||
finally:
|
||||
peer.close()
|
||||
|
||||
when isMainModule: main()
|
||||
5
tests/stdlib/tssl.nims
Normal file
5
tests/stdlib/tssl.nims
Normal file
@@ -0,0 +1,5 @@
|
||||
--threads:on
|
||||
--d:ssl
|
||||
when defined(freebsd):
|
||||
# See https://github.com/nim-lang/Nim/pull/15066#issuecomment-665541265
|
||||
--tlsEmulation:off
|
||||
Reference in New Issue
Block a user