mirror of
https://github.com/nim-lang/Nim.git
synced 2026-01-02 11:12:37 +00:00
Implement safe flags for socket operations.
This commit is contained in:
@@ -11,8 +11,9 @@ include "system/inclrtl"
|
||||
|
||||
import os, oids, tables, strutils, macros
|
||||
|
||||
import rawsockets
|
||||
export TPort
|
||||
import rawsockets, net
|
||||
|
||||
export TPort, TSocketFlags
|
||||
|
||||
#{.injectStmt: newGcInvariant().}
|
||||
|
||||
@@ -353,7 +354,7 @@ when defined(windows) or defined(nimdoc):
|
||||
return retFuture
|
||||
|
||||
proc recv*(socket: TAsyncFD, size: int,
|
||||
flags: int = 0): PFuture[string] =
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
|
||||
## Reads **up to** ``size`` bytes from ``socket``. Returned future will
|
||||
## complete once all the data requested is read, a part of the data has been
|
||||
## read, or the socket has disconnected in which case the future will
|
||||
@@ -373,7 +374,7 @@ when defined(windows) or defined(nimdoc):
|
||||
dataBuf.len = size
|
||||
|
||||
var bytesReceived: DWord
|
||||
var flagsio = flags.DWord
|
||||
var flagsio = flags.toOSFlags().DWord
|
||||
var ol = PCustomOverlapped()
|
||||
GC_ref(ol)
|
||||
ol.data = TCompletionData(sock: socket, cb:
|
||||
@@ -403,7 +404,10 @@ when defined(windows) or defined(nimdoc):
|
||||
dealloc dataBuf.buf
|
||||
dataBuf.buf = nil
|
||||
GC_unref(ol)
|
||||
retFuture.fail(newException(EOS, osErrorMsg(err)))
|
||||
if flags.isDisconnectionError(err):
|
||||
retFuture.complete("")
|
||||
else:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(err)))
|
||||
elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0':
|
||||
# We have to ensure that the buffer is empty because WSARecv will tell
|
||||
# us immediatelly when it was disconnected, even when there is still
|
||||
@@ -434,7 +438,8 @@ when defined(windows) or defined(nimdoc):
|
||||
# free ``ol``.
|
||||
return retFuture
|
||||
|
||||
proc send*(socket: TAsyncFD, data: string): PFuture[void] =
|
||||
proc send*(socket: TAsyncFD, data: string,
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
|
||||
## Sends ``data`` to ``socket``. The returned future will complete once all
|
||||
## data has been sent.
|
||||
verifyPresence(socket)
|
||||
@@ -444,7 +449,7 @@ when defined(windows) or defined(nimdoc):
|
||||
dataBuf.buf = data # since this is not used in a callback, this is fine
|
||||
dataBuf.len = data.len
|
||||
|
||||
var bytesReceived, flags: DWord
|
||||
var bytesReceived, lowFlags: DWord
|
||||
var ol = PCustomOverlapped()
|
||||
GC_ref(ol)
|
||||
ol.data = TCompletionData(sock: socket, cb:
|
||||
@@ -457,12 +462,15 @@ when defined(windows) or defined(nimdoc):
|
||||
)
|
||||
|
||||
let ret = WSASend(socket.TSocketHandle, addr dataBuf, 1, addr bytesReceived,
|
||||
flags, cast[POverlapped](ol), nil)
|
||||
lowFlags, cast[POverlapped](ol), nil)
|
||||
if ret == -1:
|
||||
let err = osLastError()
|
||||
if err.int32 != ERROR_IO_PENDING:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(err)))
|
||||
GC_unref(ol)
|
||||
if flags.isDisconnectionError(err):
|
||||
retFuture.complete()
|
||||
else:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(err)))
|
||||
else:
|
||||
retFuture.complete()
|
||||
# We don't deallocate ``ol`` here because even though this completed
|
||||
@@ -706,7 +714,7 @@ else:
|
||||
return retFuture
|
||||
|
||||
proc recv*(socket: TAsyncFD, size: int,
|
||||
flags: int = 0): PFuture[string] =
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
|
||||
var retFuture = newFuture[string]()
|
||||
|
||||
var readBuffer = newString(size)
|
||||
@@ -719,7 +727,10 @@ else:
|
||||
if res < 0:
|
||||
let lastError = osLastError()
|
||||
if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
|
||||
if flags.isDisconnectionError(lastError):
|
||||
retFuture.complete("")
|
||||
else:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
|
||||
else:
|
||||
result = false # We still want this callback to be called.
|
||||
elif res == 0:
|
||||
@@ -733,7 +744,8 @@ else:
|
||||
addRead(socket, cb)
|
||||
return retFuture
|
||||
|
||||
proc send*(socket: TAsyncFD, data: string): PFuture[void] =
|
||||
proc send*(socket: TAsyncFD, data: string,
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
|
||||
var retFuture = newFuture[void]()
|
||||
|
||||
var written = 0
|
||||
@@ -747,7 +759,10 @@ else:
|
||||
if res < 0:
|
||||
let lastError = osLastError()
|
||||
if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
|
||||
if flags.isDisconnectionError(lastError):
|
||||
retFuture.complete("")
|
||||
else:
|
||||
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
|
||||
else:
|
||||
result = false # We still want this callback to be called.
|
||||
else:
|
||||
@@ -1065,7 +1080,7 @@ proc recvLine*(socket: TAsyncFD): PFuture[string] {.async.} =
|
||||
if c.len == 0:
|
||||
return ""
|
||||
if c == "\r":
|
||||
c = await recv(socket, 1, MSG_PEEK)
|
||||
c = await recv(socket, 1, {TSocketFlags.SafeDisconn, TSocketFlags.Peek})
|
||||
if c.len > 0 and c == "\L":
|
||||
discard await recv(socket, 1)
|
||||
addNLIfEmpty()
|
||||
|
||||
@@ -80,7 +80,8 @@ proc connect*(socket: PAsyncSocket, address: string, port: TPort,
|
||||
## or an error occurs.
|
||||
result = connect(socket.fd.TAsyncFD, address, port, af)
|
||||
|
||||
proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
|
||||
proc readIntoBuf(socket: PAsyncSocket,
|
||||
flags: set[TSocketFlags]): PFuture[int] {.async.} =
|
||||
var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
|
||||
if data.len != 0:
|
||||
copyMem(addr socket.buffer[0], addr data[0], data.len)
|
||||
@@ -89,7 +90,7 @@ proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
|
||||
result = data.len
|
||||
|
||||
proc recv*(socket: PAsyncSocket, size: int,
|
||||
flags: int = 0): PFuture[string] {.async.} =
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
|
||||
## Reads ``size`` bytes from ``socket``. Returned future will complete once
|
||||
## all of the requested data is read. If socket is disconnected during the
|
||||
## recv operation then the future may complete with only a part of the
|
||||
@@ -100,7 +101,7 @@ proc recv*(socket: PAsyncSocket, size: int,
|
||||
let originalBufPos = socket.currPos
|
||||
|
||||
if socket.bufLen == 0:
|
||||
let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
|
||||
let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
|
||||
if res == 0:
|
||||
result.setLen(0)
|
||||
return
|
||||
@@ -108,10 +109,10 @@ proc recv*(socket: PAsyncSocket, size: int,
|
||||
var read = 0
|
||||
while read < size:
|
||||
if socket.currPos >= socket.bufLen:
|
||||
if (flags and MSG_PEEK) == MSG_PEEK:
|
||||
if TSocketFlags.Peek in flags:
|
||||
# We don't want to get another buffer if we're peeking.
|
||||
break
|
||||
let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
|
||||
let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
|
||||
if res == 0:
|
||||
break
|
||||
|
||||
@@ -120,18 +121,19 @@ proc recv*(socket: PAsyncSocket, size: int,
|
||||
read.inc(chunk)
|
||||
socket.currPos.inc(chunk)
|
||||
|
||||
if (flags and MSG_PEEK) == MSG_PEEK:
|
||||
if TSocketFlags.Peek in flags:
|
||||
# Restore old buffer cursor position.
|
||||
socket.currPos = originalBufPos
|
||||
result.setLen(read)
|
||||
else:
|
||||
result = await recv(socket.fd.TAsyncFD, size, flags)
|
||||
|
||||
proc send*(socket: PAsyncSocket, data: string): PFuture[void] =
|
||||
proc send*(socket: PAsyncSocket, data: string,
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
|
||||
## Sends ``data`` to ``socket``. The returned future will complete once all
|
||||
## data has been sent.
|
||||
assert socket != nil
|
||||
result = send(socket.fd.TAsyncFD, data)
|
||||
result = send(socket.fd.TAsyncFD, data, flags)
|
||||
|
||||
proc acceptAddr*(socket: PAsyncSocket):
|
||||
PFuture[tuple[address: string, client: PAsyncSocket]] =
|
||||
@@ -166,7 +168,8 @@ proc accept*(socket: PAsyncSocket): PFuture[PAsyncSocket] =
|
||||
retFut.complete(future.read.client)
|
||||
return retFut
|
||||
|
||||
proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
|
||||
proc recvLine*(socket: PAsyncSocket,
|
||||
flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
|
||||
## Reads a line of data from ``socket``. Returned future will complete once
|
||||
## a full line is read or an error occurs.
|
||||
##
|
||||
@@ -179,21 +182,23 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
|
||||
## If the socket is disconnected in the middle of a line (before ``\r\L``
|
||||
## is read) then line will be set to ``""``.
|
||||
## The partial line **will be lost**.
|
||||
##
|
||||
## **Warning**: The ``Peek`` flag is not yet implemented.
|
||||
template addNLIfEmpty(): stmt =
|
||||
if result.len == 0:
|
||||
result.add("\c\L")
|
||||
|
||||
assert TSocketFlags.Peek notin flags ## TODO:
|
||||
if socket.isBuffered:
|
||||
result = ""
|
||||
if socket.bufLen == 0:
|
||||
let res = await socket.readIntoBuf(0)
|
||||
let res = await socket.readIntoBuf(flags)
|
||||
if res == 0:
|
||||
return
|
||||
|
||||
var lastR = false
|
||||
while true:
|
||||
if socket.currPos >= socket.bufLen:
|
||||
let res = await socket.readIntoBuf(0)
|
||||
let res = await socket.readIntoBuf(flags)
|
||||
if res == 0:
|
||||
result = ""
|
||||
break
|
||||
@@ -214,18 +219,16 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
|
||||
result.add socket.buffer[socket.currPos]
|
||||
socket.currPos.inc()
|
||||
else:
|
||||
|
||||
|
||||
result = ""
|
||||
var c = ""
|
||||
while true:
|
||||
c = await recv(socket, 1)
|
||||
c = await recv(socket, 1, flags)
|
||||
if c.len == 0:
|
||||
return ""
|
||||
if c == "\r":
|
||||
c = await recv(socket, 1, MSG_PEEK)
|
||||
c = await recv(socket, 1, flags + {TSocketFlags.Peek})
|
||||
if c.len > 0 and c == "\L":
|
||||
let dummy = await recv(socket, 1)
|
||||
let dummy = await recv(socket, 1, flags)
|
||||
assert dummy == "\L"
|
||||
addNLIfEmpty()
|
||||
return
|
||||
|
||||
@@ -350,6 +350,30 @@ type
|
||||
|
||||
ETimeout* = object of ESynch
|
||||
|
||||
TSocketFlags* {.pure.} = enum
|
||||
Peek,
|
||||
SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
|
||||
|
||||
proc isDisconnectionError*(flags: set[TSocketFlags],
|
||||
lastError: TOSErrorCode): bool =
|
||||
## Determines whether ``lastError`` is a disconnection error. Only does this
|
||||
## if flags contains ``SafeDisconn``.
|
||||
when useWinVersion:
|
||||
TSocketFlags.SafeDisconn in flags and
|
||||
lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
|
||||
WSAEDISCON}
|
||||
else:
|
||||
TSocketFlags.SafeDisconn in flags and
|
||||
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
|
||||
|
||||
proc toOSFlags*(socketFlags: set[TSocketFlags]): cint =
|
||||
## Converts the flags into the underlying OS representation.
|
||||
for f in socketFlags:
|
||||
case f
|
||||
of TSocketFlags.Peek:
|
||||
result = result or MSG_PEEK
|
||||
of TSocketFlags.SafeDisconn: continue
|
||||
|
||||
proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket =
|
||||
assert fd != osInvalidSocket
|
||||
new(result)
|
||||
@@ -470,7 +494,8 @@ when defined(ssl):
|
||||
if SSLSetFd(socket.sslHandle, socket.fd) != 1:
|
||||
SSLError()
|
||||
|
||||
proc socketError*(socket: PSocket, err: int = -1, async = false) =
|
||||
proc socketError*(socket: PSocket, err: int = -1, async = false,
|
||||
lastError = (-1).TOSErrorCode) =
|
||||
## Raises an EOS error based on the error code returned by ``SSLGetError``
|
||||
## (for SSL sockets) and ``osLastError`` otherwise.
|
||||
##
|
||||
@@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) =
|
||||
else: SSLError("Unknown Error")
|
||||
|
||||
if err == -1 and not (when defined(ssl): socket.isSSL else: false):
|
||||
let lastError = osLastError()
|
||||
let lastE = if lastError.int == -1: osLastError() else: lastError
|
||||
if async:
|
||||
when useWinVersion:
|
||||
if lastError.int32 == WSAEWOULDBLOCK:
|
||||
if lastE.int32 == WSAEWOULDBLOCK:
|
||||
return
|
||||
else: osError(lastError)
|
||||
else: osError(lastE)
|
||||
else:
|
||||
if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK:
|
||||
if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
|
||||
return
|
||||
else: osError(lastError)
|
||||
else: osError(lastError)
|
||||
else: osError(lastE)
|
||||
else: osError(lastE)
|
||||
|
||||
proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} =
|
||||
## Marks ``socket`` as accepting connections.
|
||||
@@ -881,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {.
|
||||
|
||||
result = read
|
||||
|
||||
proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
|
||||
proc recv*(socket: PSocket, data: var string, size: int, timeout = -1,
|
||||
flags = {TSocketFlags.SafeDisconn}): int =
|
||||
## Higher-level version of ``recv``.
|
||||
##
|
||||
## When 0 is returned the socket's connection has been closed.
|
||||
@@ -893,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
|
||||
## within the time specified an ETimeout exception will be raised.
|
||||
##
|
||||
## **Note**: ``data`` must be initialised.
|
||||
##
|
||||
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
|
||||
data.setLen(size)
|
||||
result = recv(socket, cstring(data), size, timeout)
|
||||
if result < 0:
|
||||
data.setLen(0)
|
||||
socket.socketError(result)
|
||||
let lastError = osLastError()
|
||||
if flags.isDisconnectionError(lastError): return
|
||||
socket.socketError(result, lastError = lastError)
|
||||
data.setLen(result)
|
||||
|
||||
proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
|
||||
@@ -920,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
|
||||
return
|
||||
result = recv(socket.fd, addr(c), 1, MSG_PEEK)
|
||||
|
||||
proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
|
||||
proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1,
|
||||
flags = {TSocketFlags.SafeDisconn}) {.
|
||||
tags: [FReadIO, FTime].} =
|
||||
## Reads a line of data from ``socket``.
|
||||
##
|
||||
@@ -934,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
|
||||
##
|
||||
## A timeout can be specified in miliseconds, if data is not received within
|
||||
## the specified time an ETimeout exception will be raised.
|
||||
##
|
||||
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
|
||||
|
||||
template addNLIfEmpty(): stmt =
|
||||
if line.len == 0:
|
||||
line.add("\c\L")
|
||||
|
||||
template raiseSockError(): stmt {.dirty, immediate.} =
|
||||
let lastError = osLastError()
|
||||
if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
|
||||
socket.socketError(n, lastError = lastError)
|
||||
|
||||
var waited = 0.0
|
||||
|
||||
setLen(line.string, 0)
|
||||
@@ -946,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
|
||||
var c: char
|
||||
discard waitFor(socket, waited, timeout, 1, "readLine")
|
||||
var n = recv(socket, addr(c), 1)
|
||||
if n < 0: socket.socketError()
|
||||
elif n == 0: return
|
||||
if n < 0: raiseSockError()
|
||||
elif n == 0: setLen(line.string, 0); return
|
||||
if c == '\r':
|
||||
discard waitFor(socket, waited, timeout, 1, "readLine")
|
||||
n = peekChar(socket, c)
|
||||
if n > 0 and c == '\L':
|
||||
discard recv(socket, addr(c), 1)
|
||||
elif n <= 0: socket.socketError()
|
||||
elif n <= 0: raiseSockError()
|
||||
addNLIfEmpty()
|
||||
return
|
||||
elif c == '\L':
|
||||
@@ -1021,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {.
|
||||
const MSG_NOSIGNAL = 0
|
||||
result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
|
||||
|
||||
proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} =
|
||||
proc send*(socket: PSocket, data: string,
|
||||
flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} =
|
||||
## sends data to a socket.
|
||||
let sent = send(socket, cstring(data), data.len)
|
||||
if sent < 0:
|
||||
socketError(socket)
|
||||
let lastError = osLastError()
|
||||
if flags.isDisconnectionError(lastError): return
|
||||
socketError(socket, lastError = lastError)
|
||||
|
||||
if sent != data.len:
|
||||
raise newException(EOS, "Could not send all data.")
|
||||
|
||||
@@ -21,11 +21,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc)
|
||||
|
||||
when useWinVersion:
|
||||
import winlean
|
||||
export WSAEWOULDBLOCK
|
||||
export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
|
||||
WSAEDISCON
|
||||
else:
|
||||
import posix
|
||||
export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL,
|
||||
EINTR, EINPROGRESS
|
||||
EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET
|
||||
|
||||
export TSocketHandle, TSockaddr_in, TAddrinfo, INADDR_ANY, TSockAddr, TSockLen,
|
||||
inet_ntoa, recv, `==`, connect, send, accept, recvfrom, sendto
|
||||
|
||||
Reference in New Issue
Block a user