From e5e445f042f9037e77a0bc798960549acd4bf99a Mon Sep 17 00:00:00 2001 From: ringabout <43030857+ringabout@users.noreply.github.com> Date: Wed, 24 Aug 2022 00:58:08 +0800 Subject: [PATCH] fixes #19973; switch to poll on posix (#20212) * fixes #19973; switch to poll on posix * it is fd * exclude lwip * fixes lwip * rename select to timeoutRead * refactor into timeoutRead/timeoutWrite * refactor common parts Co-authored-by: xflywind <43030857+xflywind@users.noreply.github.com> (cherry picked from commit 2b8f0a797149329158f8874067e18badd54dd76a) --- lib/pure/net.nim | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/lib/pure/net.nim b/lib/pure/net.nim index a98f0c674e..bd022e3918 100644 --- a/lib/pure/net.nim +++ b/lib/pure/net.nim @@ -204,6 +204,30 @@ type when defined(nimHasStyleChecks): {.pop.} + +when defined(posix) and not defined(lwip): + from posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds + + template monitorPollEvent(x: var SocketHandle, y: cint, timeout: int): int = + var tpollfd: TPollfd + tpollfd.fd = cast[cint](x) + tpollfd.events = y + posix.poll(addr(tpollfd), Tnfds(1), timeout) + +proc timeoutRead(fd: var SocketHandle, timeout = 500): int = + when defined(windows) or defined(lwip): + var fds = @[fd] + selectRead(fds, timeout) + else: + monitorPollEvent(fd, POLLIN or POLLPRI, timeout) + +proc timeoutWrite(fd: var SocketHandle, timeout = 500): int = + when defined(windows) or defined(lwip): + var fds = @[fd] + selectWrite(fds, timeout) + else: + monitorPollEvent(fd, POLLOUT or POLLWRBAND, timeout) + proc socketError*(socket: Socket, err: int = -1, async = false, lastError = (-1).OSErrorCode, flags: set[SocketFlag] = {}) {.gcsafe.} @@ -1308,14 +1332,6 @@ proc hasDataBuffered*(s: Socket): bool = if s.isSsl and not result: result = s.sslHasPeekChar -proc select(readfd: Socket, timeout = 500): int = - ## Used for socket operation timeouts. - if readfd.hasDataBuffered: - return 1 - - var fds = @[readfd.fd] - result = selectRead(fds, timeout) - proc isClosed(socket: Socket): bool = socket.fd == osInvalidSocket @@ -1429,7 +1445,9 @@ proc waitFor(socket: Socket, waited: var Duration, timeout, size: int, return min(sslPending, size) var startTime = getMonoTime() - let selRet = select(socket, (timeout - waited.inMilliseconds).int) + let selRet = if socket.hasDataBuffered: 1 + else: + timeoutRead(socket.fd, (timeout - waited.inMilliseconds).int) if selRet < 0: raiseOSError(osLastError()) if selRet != 1: raise newException(TimeoutError, "Call to '" & funcName & "' timed out.") @@ -2109,8 +2127,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), socket.fd.setBlocking(false) socket.connectAsync(address, port, socket.domain) - var s = @[socket.fd] - if selectWrite(s, timeout) != 1: + if timeoutWrite(socket.fd, timeout) != 1: raise newException(TimeoutError, "Call to 'connect' timed out.") else: let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR)