mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-31 02:12:11 +00:00
Merge pull request #7066 from nim-lang/net-fixes
Assert on closed sockets
This commit is contained in:
@@ -277,6 +277,7 @@ template readInto(buf: pointer, size: int, socket: AsyncSocket,
|
||||
flags: set[SocketFlag]): int =
|
||||
## Reads **up to** ``size`` bytes from ``socket`` into ``buf``. Note that
|
||||
## this is a template and not a proc.
|
||||
assert(not socket.closed, "Cannot `recv` on a closed socket")
|
||||
var res = 0
|
||||
if socket.isSsl:
|
||||
when defineSsl:
|
||||
@@ -403,6 +404,7 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
|
||||
## Sends ``size`` bytes from ``buf`` to ``socket``. The returned future will complete once all
|
||||
## data has been sent.
|
||||
assert socket != nil
|
||||
assert(not socket.closed, "Cannot `send` on a closed socket")
|
||||
if socket.isSsl:
|
||||
when defineSsl:
|
||||
sslLoop(socket, flags,
|
||||
|
||||
@@ -868,6 +868,7 @@ proc close*(socket: Socket) =
|
||||
socket.sslHandle = nil
|
||||
|
||||
socket.fd.close()
|
||||
socket.fd = osInvalidSocket
|
||||
|
||||
when defined(posix):
|
||||
from posix import TCP_NODELAY
|
||||
@@ -1005,15 +1006,25 @@ proc select(readfd: Socket, timeout = 500): int =
|
||||
var fds = @[readfd.fd]
|
||||
result = select(fds, timeout)
|
||||
|
||||
proc isClosed(socket: Socket): bool =
|
||||
socket.fd == osInvalidSocket
|
||||
|
||||
proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
|
||||
## Handles SSL and non-ssl recv in a nice package.
|
||||
##
|
||||
## In particular handles the case where socket has been closed properly
|
||||
## for both SSL and non-ssl.
|
||||
result = 0
|
||||
assert(not socket.isClosed, "Cannot `recv` on a closed socket")
|
||||
when defineSsl:
|
||||
if socket.isSsl:
|
||||
return SSLRead(socket.sslHandle, buffer, size)
|
||||
|
||||
return recv(socket.fd, buffer, size, flags)
|
||||
|
||||
proc readIntoBuf(socket: Socket, flags: int32): int =
|
||||
result = 0
|
||||
when defineSsl:
|
||||
if socket.isSSL:
|
||||
result = SSLRead(socket.sslHandle, addr(socket.buffer), int(socket.buffer.high))
|
||||
else:
|
||||
result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
|
||||
else:
|
||||
result = recv(socket.fd, addr(socket.buffer), cint(socket.buffer.high), flags)
|
||||
result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
|
||||
if result < 0:
|
||||
# Save it in case it gets reset (the Nim codegen occasionally may call
|
||||
# Win API functions which reset it).
|
||||
@@ -1059,16 +1070,16 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
|
||||
else:
|
||||
when defineSsl:
|
||||
if socket.isSSL:
|
||||
if socket.sslHasPeekChar:
|
||||
if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv
|
||||
copyMem(data, addr(socket.sslPeekChar), 1)
|
||||
socket.sslHasPeekChar = false
|
||||
if size-1 > 0:
|
||||
var d = cast[cstring](data)
|
||||
result = SSLRead(socket.sslHandle, addr(d[1]), size-1) + 1
|
||||
result = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32) + 1
|
||||
else:
|
||||
result = 1
|
||||
else:
|
||||
result = SSLRead(socket.sslHandle, data, size)
|
||||
result = uniRecv(socket, data, size.cint, 0'i32)
|
||||
else:
|
||||
result = recv(socket.fd, data, size.cint, 0'i32)
|
||||
else:
|
||||
@@ -1145,7 +1156,11 @@ proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
|
||||
##
|
||||
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
|
||||
data.setLen(size)
|
||||
result = recv(socket, cstring(data), size, timeout)
|
||||
result =
|
||||
if timeout == -1:
|
||||
recv(socket, cstring(data), size)
|
||||
else:
|
||||
recv(socket, cstring(data), size, timeout)
|
||||
if result < 0:
|
||||
data.setLen(0)
|
||||
let lastError = getSocketError(socket)
|
||||
@@ -1182,7 +1197,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
|
||||
when defineSsl:
|
||||
if socket.isSSL:
|
||||
if not socket.sslHasPeekChar:
|
||||
result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
|
||||
result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
|
||||
socket.sslHasPeekChar = true
|
||||
|
||||
c = socket.sslPeekChar
|
||||
@@ -1316,6 +1331,7 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
|
||||
##
|
||||
## **Note**: This is a low-level version of ``send``. You likely should use
|
||||
## the version below.
|
||||
assert(not socket.isClosed, "Cannot `send` on a closed socket")
|
||||
when defineSsl:
|
||||
if socket.isSSL:
|
||||
return SSLWrite(socket.sslHandle, cast[cstring](data), size)
|
||||
@@ -1360,6 +1376,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
|
||||
## which is defined below.
|
||||
##
|
||||
## **Note:** This proc is not available for SSL sockets.
|
||||
assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
|
||||
var aiList = getAddrInfo(address, port, af)
|
||||
|
||||
# try all possibilities:
|
||||
|
||||
Reference in New Issue
Block a user