Implement safe flags for socket operations.

This commit is contained in:
Dominik Picheta
2014-07-13 22:32:39 +01:00
parent cdcdab49b7
commit ac8ddb0720
4 changed files with 108 additions and 48 deletions

View File

@@ -11,8 +11,9 @@ include "system/inclrtl"
import os, oids, tables, strutils, macros
import rawsockets
export TPort
import rawsockets, net
export TPort, TSocketFlags
#{.injectStmt: newGcInvariant().}
@@ -353,7 +354,7 @@ when defined(windows) or defined(nimdoc):
return retFuture
proc recv*(socket: TAsyncFD, size: int,
flags: int = 0): PFuture[string] =
flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
## Reads **up to** ``size`` bytes from ``socket``. Returned future will
## complete once all the data requested is read, a part of the data has been
## read, or the socket has disconnected in which case the future will
@@ -373,7 +374,7 @@ when defined(windows) or defined(nimdoc):
dataBuf.len = size
var bytesReceived: DWord
var flagsio = flags.DWord
var flagsio = flags.toOSFlags().DWord
var ol = PCustomOverlapped()
GC_ref(ol)
ol.data = TCompletionData(sock: socket, cb:
@@ -403,7 +404,10 @@ when defined(windows) or defined(nimdoc):
dealloc dataBuf.buf
dataBuf.buf = nil
GC_unref(ol)
retFuture.fail(newException(EOS, osErrorMsg(err)))
if flags.isDisconnectionError(err):
retFuture.complete("")
else:
retFuture.fail(newException(EOS, osErrorMsg(err)))
elif ret == 0 and bytesReceived == 0 and dataBuf.buf[0] == '\0':
# We have to ensure that the buffer is empty because WSARecv will tell
# us immediatelly when it was disconnected, even when there is still
@@ -434,7 +438,8 @@ when defined(windows) or defined(nimdoc):
# free ``ol``.
return retFuture
proc send*(socket: TAsyncFD, data: string): PFuture[void] =
proc send*(socket: TAsyncFD, data: string,
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
## Sends ``data`` to ``socket``. The returned future will complete once all
## data has been sent.
verifyPresence(socket)
@@ -444,7 +449,7 @@ when defined(windows) or defined(nimdoc):
dataBuf.buf = data # since this is not used in a callback, this is fine
dataBuf.len = data.len
var bytesReceived, flags: DWord
var bytesReceived, lowFlags: DWord
var ol = PCustomOverlapped()
GC_ref(ol)
ol.data = TCompletionData(sock: socket, cb:
@@ -457,12 +462,15 @@ when defined(windows) or defined(nimdoc):
)
let ret = WSASend(socket.TSocketHandle, addr dataBuf, 1, addr bytesReceived,
flags, cast[POverlapped](ol), nil)
lowFlags, cast[POverlapped](ol), nil)
if ret == -1:
let err = osLastError()
if err.int32 != ERROR_IO_PENDING:
retFuture.fail(newException(EOS, osErrorMsg(err)))
GC_unref(ol)
if flags.isDisconnectionError(err):
retFuture.complete()
else:
retFuture.fail(newException(EOS, osErrorMsg(err)))
else:
retFuture.complete()
# We don't deallocate ``ol`` here because even though this completed
@@ -706,7 +714,7 @@ else:
return retFuture
proc recv*(socket: TAsyncFD, size: int,
flags: int = 0): PFuture[string] =
flags = {TSocketFlags.SafeDisconn}): PFuture[string] =
var retFuture = newFuture[string]()
var readBuffer = newString(size)
@@ -719,7 +727,10 @@ else:
if res < 0:
let lastError = osLastError()
if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
if flags.isDisconnectionError(lastError):
retFuture.complete("")
else:
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
else:
result = false # We still want this callback to be called.
elif res == 0:
@@ -733,7 +744,8 @@ else:
addRead(socket, cb)
return retFuture
proc send*(socket: TAsyncFD, data: string): PFuture[void] =
proc send*(socket: TAsyncFD, data: string,
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
var retFuture = newFuture[void]()
var written = 0
@@ -747,7 +759,10 @@ else:
if res < 0:
let lastError = osLastError()
if lastError.int32 notin {EINTR, EWOULDBLOCK, EAGAIN}:
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
if flags.isDisconnectionError(lastError):
retFuture.complete("")
else:
retFuture.fail(newException(EOS, osErrorMsg(lastError)))
else:
result = false # We still want this callback to be called.
else:
@@ -1065,7 +1080,7 @@ proc recvLine*(socket: TAsyncFD): PFuture[string] {.async.} =
if c.len == 0:
return ""
if c == "\r":
c = await recv(socket, 1, MSG_PEEK)
c = await recv(socket, 1, {TSocketFlags.SafeDisconn, TSocketFlags.Peek})
if c.len > 0 and c == "\L":
discard await recv(socket, 1)
addNLIfEmpty()

View File

@@ -80,7 +80,8 @@ proc connect*(socket: PAsyncSocket, address: string, port: TPort,
## or an error occurs.
result = connect(socket.fd.TAsyncFD, address, port, af)
proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
proc readIntoBuf(socket: PAsyncSocket,
flags: set[TSocketFlags]): PFuture[int] {.async.} =
var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
if data.len != 0:
copyMem(addr socket.buffer[0], addr data[0], data.len)
@@ -89,7 +90,7 @@ proc readIntoBuf(socket: PAsyncSocket, flags: int): PFuture[int] {.async.} =
result = data.len
proc recv*(socket: PAsyncSocket, size: int,
flags: int = 0): PFuture[string] {.async.} =
flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
## Reads ``size`` bytes from ``socket``. Returned future will complete once
## all of the requested data is read. If socket is disconnected during the
## recv operation then the future may complete with only a part of the
@@ -100,7 +101,7 @@ proc recv*(socket: PAsyncSocket, size: int,
let originalBufPos = socket.currPos
if socket.bufLen == 0:
let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
if res == 0:
result.setLen(0)
return
@@ -108,10 +109,10 @@ proc recv*(socket: PAsyncSocket, size: int,
var read = 0
while read < size:
if socket.currPos >= socket.bufLen:
if (flags and MSG_PEEK) == MSG_PEEK:
if TSocketFlags.Peek in flags:
# We don't want to get another buffer if we're peeking.
break
let res = await socket.readIntoBuf(flags and (not MSG_PEEK))
let res = await socket.readIntoBuf(flags - {TSocketFlags.Peek})
if res == 0:
break
@@ -120,18 +121,19 @@ proc recv*(socket: PAsyncSocket, size: int,
read.inc(chunk)
socket.currPos.inc(chunk)
if (flags and MSG_PEEK) == MSG_PEEK:
if TSocketFlags.Peek in flags:
# Restore old buffer cursor position.
socket.currPos = originalBufPos
result.setLen(read)
else:
result = await recv(socket.fd.TAsyncFD, size, flags)
proc send*(socket: PAsyncSocket, data: string): PFuture[void] =
proc send*(socket: PAsyncSocket, data: string,
flags = {TSocketFlags.SafeDisconn}): PFuture[void] =
## Sends ``data`` to ``socket``. The returned future will complete once all
## data has been sent.
assert socket != nil
result = send(socket.fd.TAsyncFD, data)
result = send(socket.fd.TAsyncFD, data, flags)
proc acceptAddr*(socket: PAsyncSocket):
PFuture[tuple[address: string, client: PAsyncSocket]] =
@@ -166,7 +168,8 @@ proc accept*(socket: PAsyncSocket): PFuture[PAsyncSocket] =
retFut.complete(future.read.client)
return retFut
proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
proc recvLine*(socket: PAsyncSocket,
flags = {TSocketFlags.SafeDisconn}): PFuture[string] {.async.} =
## Reads a line of data from ``socket``. Returned future will complete once
## a full line is read or an error occurs.
##
@@ -179,21 +182,23 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
## If the socket is disconnected in the middle of a line (before ``\r\L``
## is read) then line will be set to ``""``.
## The partial line **will be lost**.
##
## **Warning**: The ``Peek`` flag is not yet implemented.
template addNLIfEmpty(): stmt =
if result.len == 0:
result.add("\c\L")
assert TSocketFlags.Peek notin flags ## TODO:
if socket.isBuffered:
result = ""
if socket.bufLen == 0:
let res = await socket.readIntoBuf(0)
let res = await socket.readIntoBuf(flags)
if res == 0:
return
var lastR = false
while true:
if socket.currPos >= socket.bufLen:
let res = await socket.readIntoBuf(0)
let res = await socket.readIntoBuf(flags)
if res == 0:
result = ""
break
@@ -214,18 +219,16 @@ proc recvLine*(socket: PAsyncSocket): PFuture[string] {.async.} =
result.add socket.buffer[socket.currPos]
socket.currPos.inc()
else:
result = ""
var c = ""
while true:
c = await recv(socket, 1)
c = await recv(socket, 1, flags)
if c.len == 0:
return ""
if c == "\r":
c = await recv(socket, 1, MSG_PEEK)
c = await recv(socket, 1, flags + {TSocketFlags.Peek})
if c.len > 0 and c == "\L":
let dummy = await recv(socket, 1)
let dummy = await recv(socket, 1, flags)
assert dummy == "\L"
addNLIfEmpty()
return

View File

@@ -350,6 +350,30 @@ type
ETimeout* = object of ESynch
TSocketFlags* {.pure.} = enum
Peek,
SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
proc isDisconnectionError*(flags: set[TSocketFlags],
lastError: TOSErrorCode): bool =
## Determines whether ``lastError`` is a disconnection error. Only does this
## if flags contains ``SafeDisconn``.
when useWinVersion:
TSocketFlags.SafeDisconn in flags and
lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
WSAEDISCON}
else:
TSocketFlags.SafeDisconn in flags and
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
proc toOSFlags*(socketFlags: set[TSocketFlags]): cint =
## Converts the flags into the underlying OS representation.
for f in socketFlags:
case f
of TSocketFlags.Peek:
result = result or MSG_PEEK
of TSocketFlags.SafeDisconn: continue
proc createSocket(fd: TSocketHandle, isBuff: bool): PSocket =
assert fd != osInvalidSocket
new(result)
@@ -470,7 +494,8 @@ when defined(ssl):
if SSLSetFd(socket.sslHandle, socket.fd) != 1:
SSLError()
proc socketError*(socket: PSocket, err: int = -1, async = false) =
proc socketError*(socket: PSocket, err: int = -1, async = false,
lastError = (-1).TOSErrorCode) =
## Raises an EOS error based on the error code returned by ``SSLGetError``
## (for SSL sockets) and ``osLastError`` otherwise.
##
@@ -500,17 +525,17 @@ proc socketError*(socket: PSocket, err: int = -1, async = false) =
else: SSLError("Unknown Error")
if err == -1 and not (when defined(ssl): socket.isSSL else: false):
let lastError = osLastError()
let lastE = if lastError.int == -1: osLastError() else: lastError
if async:
when useWinVersion:
if lastError.int32 == WSAEWOULDBLOCK:
if lastE.int32 == WSAEWOULDBLOCK:
return
else: osError(lastError)
else: osError(lastE)
else:
if lastError.int32 == EAGAIN or lastError.int32 == EWOULDBLOCK:
if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
return
else: osError(lastError)
else: osError(lastError)
else: osError(lastE)
else: osError(lastE)
proc listen*(socket: PSocket, backlog = SOMAXCONN) {.tags: [FReadIO].} =
## Marks ``socket`` as accepting connections.
@@ -881,7 +906,8 @@ proc recv*(socket: PSocket, data: pointer, size: int, timeout: int): int {.
result = read
proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
proc recv*(socket: PSocket, data: var string, size: int, timeout = -1,
flags = {TSocketFlags.SafeDisconn}): int =
## Higher-level version of ``recv``.
##
## When 0 is returned the socket's connection has been closed.
@@ -893,11 +919,15 @@ proc recv*(socket: PSocket, data: var string, size: int, timeout = -1): int =
## within the time specified an ETimeout exception will be raised.
##
## **Note**: ``data`` must be initialised.
##
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
data.setLen(size)
result = recv(socket, cstring(data), size, timeout)
if result < 0:
data.setLen(0)
socket.socketError(result)
let lastError = osLastError()
if flags.isDisconnectionError(lastError): return
socket.socketError(result, lastError = lastError)
data.setLen(result)
proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
@@ -920,7 +950,8 @@ proc peekChar(socket: PSocket, c: var char): int {.tags: [FReadIO].} =
return
result = recv(socket.fd, addr(c), 1, MSG_PEEK)
proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1,
flags = {TSocketFlags.SafeDisconn}) {.
tags: [FReadIO, FTime].} =
## Reads a line of data from ``socket``.
##
@@ -934,11 +965,18 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
##
## A timeout can be specified in miliseconds, if data is not received within
## the specified time an ETimeout exception will be raised.
##
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
template addNLIfEmpty(): stmt =
if line.len == 0:
line.add("\c\L")
template raiseSockError(): stmt {.dirty, immediate.} =
let lastError = osLastError()
if flags.isDisconnectionError(lastError): setLen(line.string, 0); return
socket.socketError(n, lastError = lastError)
var waited = 0.0
setLen(line.string, 0)
@@ -946,14 +984,14 @@ proc readLine*(socket: PSocket, line: var TaintedString, timeout = -1) {.
var c: char
discard waitFor(socket, waited, timeout, 1, "readLine")
var n = recv(socket, addr(c), 1)
if n < 0: socket.socketError()
elif n == 0: return
if n < 0: raiseSockError()
elif n == 0: setLen(line.string, 0); return
if c == '\r':
discard waitFor(socket, waited, timeout, 1, "readLine")
n = peekChar(socket, c)
if n > 0 and c == '\L':
discard recv(socket, addr(c), 1)
elif n <= 0: socket.socketError()
elif n <= 0: raiseSockError()
addNLIfEmpty()
return
elif c == '\L':
@@ -1021,11 +1059,14 @@ proc send*(socket: PSocket, data: pointer, size: int): int {.
const MSG_NOSIGNAL = 0
result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
proc send*(socket: PSocket, data: string) {.tags: [FWriteIO].} =
proc send*(socket: PSocket, data: string,
flags = {TSocketFlags.SafeDisconn}) {.tags: [FWriteIO].} =
## sends data to a socket.
let sent = send(socket, cstring(data), data.len)
if sent < 0:
socketError(socket)
let lastError = osLastError()
if flags.isDisconnectionError(lastError): return
socketError(socket, lastError = lastError)
if sent != data.len:
raise newException(EOS, "Could not send all data.")

View File

@@ -21,11 +21,12 @@ const useWinVersion = defined(Windows) or defined(nimdoc)
when useWinVersion:
import winlean
export WSAEWOULDBLOCK
export WSAEWOULDBLOCK, WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
WSAEDISCON
else:
import posix
export fcntl, F_GETFL, O_NONBLOCK, F_SETFL, EAGAIN, EWOULDBLOCK, MSG_NOSIGNAL,
EINTR, EINPROGRESS
EINTR, EINPROGRESS, ECONNRESET, EPIPE, ENETRESET
export TSocketHandle, TSockaddr_in, TAddrinfo, INADDR_ANY, TSockAddr, TSockLen,
inet_ntoa, recv, `==`, connect, send, accept, recvfrom, sendto