mirror of
https://github.com/nim-lang/Nim.git
synced 2026-05-23 05:09:54 +00:00
Under Linux (probably POSIX), the current code tests for timeout, but does not test for connection failure. connectAsync() returns successfully upon an EINPROGRESS result; but at this point, the connection state is still unknown. After selectWrite() is done, we need to test the socket for errors again.
1677 lines
60 KiB
Nim
1677 lines
60 KiB
Nim
#
|
|
#
|
|
# Nim's Runtime Library
|
|
# (c) Copyright 2015 Dominik Picheta
|
|
#
|
|
# See the file "copying.txt", included in this
|
|
# distribution, for details about the copyright.
|
|
#
|
|
|
|
## This module implements a high-level cross-platform sockets interface.
|
|
## The procedures implemented in this module are primarily for blocking sockets.
|
|
## For asynchronous non-blocking sockets use the ``asyncnet`` module together
|
|
## with the ``asyncdispatch`` module.
|
|
##
|
|
## The first thing you will always need to do in order to start using sockets,
|
|
## is to create a new instance of the ``Socket`` type using the ``newSocket``
|
|
## procedure.
|
|
##
|
|
## SSL
|
|
## ====
|
|
##
|
|
## In order to use the SSL procedures defined in this module, you will need to
|
|
## compile your application with the ``-d:ssl`` flag.
|
|
##
|
|
## Examples
|
|
## ========
|
|
##
|
|
## Connecting to a server
|
|
## ----------------------
|
|
##
|
|
## After you create a socket with the ``newSocket`` procedure, you can easily
|
|
## connect it to a server running at a known hostname (or IP address) and port.
|
|
## To do so over TCP, use the example below.
|
|
##
|
|
## .. code-block:: Nim
##   var socket = newSocket()
##   socket.connect("google.com", Port(80))
|
|
##
|
|
## UDP is a connectionless protocol, so UDP sockets don't have to explicitly
|
|
## call the ``connect`` procedure. They can simply start sending data
|
|
## immediately.
|
|
##
|
|
## .. code-block:: Nim
##   var socket = newSocket()
##   socket.sendTo("192.168.0.1", Port(27960), "status\n")
|
|
##
|
|
## Creating a server
|
|
## -----------------
|
|
##
|
|
## After you create a socket with the ``newSocket`` procedure, you can create a
|
|
## TCP server by calling the ``bindAddr`` and ``listen`` procedures.
|
|
##
|
|
## .. code-block:: Nim
##   var socket = newSocket()
##   socket.bindAddr(Port(1234))
##   socket.listen()
|
|
##
|
|
## You can then begin accepting connections using the ``accept`` procedure.
|
|
##
|
|
## .. code-block:: Nim
##   var client = new Socket
##   var address = ""
##   while true:
##     socket.acceptAddr(client, address)
##     echo("Client connected from: ", address)
|
|
##
|
|
## **Note:** The ``client`` variable is initialised with ``new Socket`` **not**
|
|
## ``newSocket()``. The difference is that the latter creates a new file
|
|
## descriptor.
|
|
|
|
{.deadCodeElim: on.}
|
|
import nativesockets, os, strutils, parseutils, times, sets, options
|
|
export Port, `$`, `==`
|
|
export Domain, SockType, Protocol
|
|
|
|
const
  # True when the Windows code paths are selected; also forced on while
  # generating documentation (``nimdoc``).
  useWinVersion = defined(Windows) or defined(nimdoc)
  # True when the module is compiled with ``-d:ssl``; also forced on while
  # generating documentation (``nimdoc``).
  defineSsl = defined(ssl) or defined(nimdoc)

when defineSsl:
  # The OpenSSL wrapper is only pulled in when SSL support is requested.
  import openssl
|
|
|
|
# Note: The enumerations are mapped to Windows constants.
|
|
|
|
when defineSsl:
  type
    SslError* = object of Exception ## Raised for failures in the SSL layer.

    SslCVerifyMode* = enum ## Certificate verification mode for a context.
      CVerifyNone,
      CVerifyPeer

    SslProtVersion* = enum ## Protocol selection accepted by ``newContext``.
      protSSLv2,
      protSSLv3,
      protTLSv1,
      protSSLv23

    SslContext* = ref object ## Wrapper around an OpenSSL context handle.
      context*: SslCtx
      # Indices for which ``setExtraData`` has stored (and GC-ref'd) a value.
      referencedData: HashSet[int]
      # Module-private per-context state (the PSK callbacks).
      extraInternal: SslContextExtraInternal

    SslAcceptResult* = enum ## Outcome of a non-blocking SSL accept.
      AcceptNoClient = 0,
      AcceptNoHandshake,
      AcceptSuccess

    SslHandshakeType* = enum ## Which side of the handshake to perform.
      handshakeAsClient,
      handshakeAsServer

    SslClientGetPskFunc* =
      proc(hint: string): tuple[identity: string, psk: string]
      ## Client-side PSK lookup: maps the identity hint to an
      ## ``(identity, psk)`` pair.

    SslServerGetPskFunc* = proc(identity: string): string
      ## Server-side PSK lookup: maps a client identity to its PSK.

    SslContextExtraInternal = ref object of RootRef
      serverGetPskFunc: SslServerGetPskFunc
      clientGetPskFunc: SslClientGetPskFunc

  {.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
                 TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
                 TSSLAcceptResult: SSLAcceptResult].}
else:
  type
    SslContext* = void # TODO: Workaround #4797.
|
|
|
|
const
  BufferSize*: int = 4000 ## Size of a buffered socket's buffer.
  # NOTE(review): presumably the upper bound on the length of a line read
  # from a socket; its users are not visible in this part of the file.
  MaxLineLength* = 1_000_000
|
|
|
|
type
  SocketImpl* = object ## The state behind a ``Socket`` reference.
    fd: SocketHandle # underlying OS socket handle
    case isBuffered: bool # selects whether the read buffer below exists
    of true:
      buffer: array[0..BufferSize, char]
      currPos: int # current index in ``buffer``
      bufLen: int # current length of ``buffer``
    of false: nil
    when defineSsl:
      case isSsl: bool # set once the socket has been wrapped by a context
      of true:
        sslHandle: SSLPtr
        sslContext: SSLContext
        sslNoHandshake: bool # True if needs handshake.
        sslHasPeekChar: bool
        sslPeekChar: char
      of false: nil
    lastError: OSErrorCode ## The last error stored on this socket.
    domain: Domain
    sockType: SockType
    protocol: Protocol

  Socket* = ref SocketImpl

  SOBool* = enum ## Socket options that take a boolean value.
    OptAcceptConn,
    OptBroadcast,
    OptDebug,
    OptDontRoute,
    OptKeepAlive,
    OptOOBInline,
    OptReuseAddr,
    OptReusePort,
    OptNoDelay

  ReadLineResult* = enum ## Result for readLineAsync.
    ReadFullLine,
    ReadPartialLine,
    ReadDisconnected,
    ReadNone

  TimeoutError* = object of Exception

  SocketFlag* {.pure.} = enum
    Peek,
    SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.

{.deprecated: [TSocketFlags: SocketFlag, ETimeout: TimeoutError,
               TReadLineResult: ReadLineResult, TSOBool: SOBool,
               PSocket: Socket, TSocketImpl: SocketImpl].}
|
|
|
|
type
  IpAddressFamily* {.pure.} = enum ## The kind of an IP address.
    IPv6, ## IPv6 address
    IPv4  ## IPv4 address

  IpAddress* = object ## Holds either an IPv4 or an IPv6 address.
    case family*: IpAddressFamily ## selects which byte array is present
    of IpAddressFamily.IPv6:
      address_v6*: array[0..15, uint8] ## The 16 address bytes of an IPv6
                                       ## address.
    of IpAddressFamily.IPv4:
      address_v4*: array[0..3, uint8] ## The 4 address bytes of an IPv4
                                      ## address.
{.deprecated: [TIpAddress: IpAddress].}
|
|
|
|
|
|
# Forward declaration: the SSL procs further down call ``socketError``
# before its body appears in the file.
proc socketError*(socket: Socket, err: int = -1, async = false,
                  lastError = (-1).OSErrorCode): void {.gcsafe.}
|
|
|
|
proc isDisconnectionError*(flags: set[SocketFlag],
                           lastError: OSErrorCode): bool =
  ## Returns true when ``flags`` contains ``SafeDisconn`` and ``lastError``
  ## is one of the platform's disconnection error codes. Without
  ## ``SafeDisconn`` the result is always false.
  if SocketFlag.SafeDisconn notin flags:
    return false
  when useWinVersion:
    result = lastError.int32 in {WSAECONNRESET, WSAECONNABORTED, WSAENETRESET,
                                 WSAEDISCON, ERROR_NETNAME_DELETED}
  else:
    result = lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
|
|
|
|
proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
  ## Translates ``socketFlags`` into the flag bits passed to the OS.
  ##
  ## Only ``Peek`` has an OS counterpart (``MSG_PEEK``); ``SafeDisconn``
  ## contributes no bits.
  if SocketFlag.Peek in socketFlags:
    result = result or MSG_PEEK
|
|
|
|
proc newSocket*(fd: SocketHandle, domain: Domain = AF_INET,
                sockType: SockType = SOCK_STREAM,
                protocol: Protocol = IPPROTO_TCP, buffered = true): Socket =
  ## Wraps the existing handle ``fd`` in a new ``Socket`` that records the
  ## given domain, socket type and protocol.
  ##
  ## ``fd`` must not be ``osInvalidSocket`` (checked with ``assert``).
  assert fd != osInvalidSocket
  let sock = Socket(
    fd: fd,
    isBuffered: buffered,
    domain: domain,
    sockType: sockType,
    protocol: protocol)
  if buffered:
    sock.currPos = 0
  result = sock

  # Set SO_NOSIGPIPE on OS X.
  when defined(macosx) and not defined(nimdoc):
    setSockOptInt(fd, SOL_SOCKET, SO_NOSIGPIPE, 1)
|
|
|
|
proc newSocket*(domain, sockType, protocol: cint, buffered = true): Socket =
  ## Creates a new OS socket from raw ``cint`` parameters and wraps it.
  ##
  ## Raises an OS error if the native socket cannot be created.
  let handle = createNativeSocket(domain, sockType, protocol)
  if handle == osInvalidSocket:
    raiseOSError(osLastError())
  result = newSocket(handle, Domain(domain), SockType(sockType),
                     Protocol(protocol), buffered)
|
|
|
|
proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
                protocol: Protocol = IPPROTO_TCP, buffered = true): Socket =
  ## Creates a new OS socket for the given domain, type and protocol and
  ## wraps it.
  ##
  ## Raises an OS error if the native socket cannot be created.
  let handle = createNativeSocket(domain, sockType, protocol)
  if handle == osInvalidSocket:
    raiseOSError(osLastError())
  result = newSocket(handle, domain, sockType, protocol, buffered)
|
|
|
|
proc parseIPv4Address(address_str: string): IpAddress =
  ## Parses a dotted-decimal IPv4 address into an ``IpAddress``.
  ##
  ## Raises ``ValueError`` when the text is not four ``.``-separated
  ## decimal numbers in the range 0..255.
  var
    octetIdx = 0        # index of the octet currently being read
    octet: uint16 = 0   # value accumulated for that octet
    sawDigit = false    # true once the current octet has a digit

  result.family = IpAddressFamily.IPv4

  for ch in address_str:
    case ch
    of '0'..'9': # Character is a number
      octet = octet * 10 + cast[uint16](ord(ch) - ord('0'))
      if octet > 255'u16:
        raise newException(ValueError,
          "Invalid IP Address. Value is out of range")
      sawDigit = true
    of '.': # IPv4 address separator
      if not sawDigit or octetIdx >= 3:
        raise newException(ValueError,
          "Invalid IP Address. The address consists of too many groups")
      result.address_v4[octetIdx] = cast[uint8](octet)
      octet = 0
      inc octetIdx
      sawDigit = false
    else:
      raise newException(ValueError,
        "Invalid IP Address. Address contains an invalid character")

  # The last octet has no trailing separator, so it is stored here.
  if octetIdx != 3 or not sawDigit:
    raise newException(ValueError, "Invalid IP Address")
  result.address_v4[octetIdx] = cast[uint8](octet)
|
|
|
|
proc parseIPv6Address(address_str: string): IpAddress =
  ## Parses the textual form of an IPv6 address into an ``IpAddress``.
  ##
  ## Handles a single ``::`` abbreviation and a trailing dotted-decimal
  ## IPv4 part. Raises ``ValueError`` on malformed input.
  result.family = IpAddressFamily.IPv6
  if address_str.len < 2:
    raise newException(ValueError, "Invalid IP Address")

  let lastIdx = high(address_str)
  var
    groups = 0           # number of 16-bit groups stored so far
    groupStart = 0       # index just after the most recent ':'
    acc: uint32 = 0      # value of the group (or IPv4 octet) being read
    mayEndGroup = true   # whether a separator is acceptable right now
    gapAt = -1           # group index at which "::" was seen, -1 if none
    prevColon = false    # true if the previous character was ':'
    v4Start = -1         # start index of an embedded IPv4 part, -1 if none
    octets = 0           # IPv4 octets stored so far

  for idx, ch in address_str:
    case ch
    of ':':
      if not mayEndGroup:
        raise newException(ValueError,
          "Invalid IP Address. Address contains an invalid seperator")
      if prevColon:
        if gapAt != -1:
          raise newException(ValueError,
            "Invalid IP Address. Address contains more than one \"::\" seperator")
        gapAt = groups
        mayEndGroup = false
      elif idx != 0 and idx != lastIdx:
        if groups >= 8:
          raise newException(ValueError,
            "Invalid IP Address. The address consists of too many groups")
        result.address_v6[groups*2] = cast[uint8](acc shr 8)
        result.address_v6[groups*2+1] = cast[uint8](acc and 0xFF)
        acc = 0
        inc groups
        if gapAt != -1: mayEndGroup = false
      elif idx == 0: # only valid if address starts with ::
        if address_str[1] != ':':
          raise newException(ValueError,
            "Invalid IP Address. Address may not start with \":\"")
      else: # idx == lastIdx - only valid if address ends with ::
        if address_str[lastIdx-1] != ':':
          raise newException(ValueError,
            "Invalid IP Address. Address may not end with \":\"")
      prevColon = true
      groupStart = idx + 1
    of '.': # Switch to parse IPv4 mode
      if idx < 3 or not mayEndGroup or groups >= 7:
        raise newException(ValueError, "Invalid IP Address")
      v4Start = groupStart
      acc = 0
      mayEndGroup = false
      break
    of '0'..'9', 'a'..'f', 'A'..'F':
      let nibble =
        if ch in strutils.Digits: ord(ch) - ord('0')        # Normal digit
        elif ch >= 'a' and ch <= 'f': ord(ch) - ord('a') + 10 # Lower case hex
        else: ord(ch) - ord('A') + 10                        # Upper case hex
      acc = (acc shl 4) + cast[uint32](nibble)
      if acc > 65535'u32:
        raise newException(ValueError,
          "Invalid IP Address. Value is out of range")
      prevColon = false
      mayEndGroup = true
    else:
      raise newException(ValueError,
        "Invalid IP Address. Address contains an invalid character")

  if v4Start == -1: # Don't parse v4. Copy the remaining v6 stuff
    if mayEndGroup: # Copy remaining data
      if groups >= 8:
        raise newException(ValueError,
          "Invalid IP Address. The address consists of too many groups")
      result.address_v6[groups*2] = cast[uint8](acc shr 8)
      result.address_v6[groups*2+1] = cast[uint8](acc and 0xFF)
      inc groups
  else: # Must parse IPv4 address
    for ch in address_str[v4Start..lastIdx]:
      if ch in strutils.Digits: # Character is a number
        acc = acc * 10 + cast[uint32](ord(ch) - ord('0'))
        if acc > 255'u32:
          raise newException(ValueError,
            "Invalid IP Address. Value is out of range")
        mayEndGroup = true
      elif ch == '.': # IPv4 address separator
        if not mayEndGroup or octets >= 3:
          raise newException(ValueError, "Invalid IP Address")
        result.address_v6[groups*2 + octets] = cast[uint8](acc)
        acc = 0
        inc octets
        mayEndGroup = false
      else: # Invalid character
        raise newException(ValueError,
          "Invalid IP Address. Address contains an invalid character")

    if octets != 3 or not mayEndGroup:
      raise newException(ValueError, "Invalid IP Address")
    result.address_v6[groups*2 + octets] = cast[uint8](acc)
    # The four IPv4 bytes occupy two 16-bit groups.
    groups += 2

  # Shift and fill zeros in case of ::
  if groups > 8:
    raise newException(ValueError,
      "Invalid IP Address. The address consists of too many groups")
  elif groups < 8: # must fill
    if gapAt == -1:
      raise newException(ValueError,
        "Invalid IP Address. The address consists of too few groups")
    let missing = 8 - groups  # The number of groups to fill
    let tail = groups - gapAt # Nr of known groups after ::
    for k in countup(0, 2*tail - 1): # shift
      result.address_v6[15-k] = result.address_v6[groups*2-k-1]
    for k in countup(0, 2*missing - 1): # fill with 0s
      result.address_v6[gapAt*2+k] = 0
  elif gapAt != -1:
    raise newException(ValueError,
      "Invalid IP Address. The address consists of too many groups")
|
|
|
|
proc parseIpAddress*(address_str: string): IpAddress =
  ## Parses an IPv4 or IPv6 address from its textual form.
  ##
  ## Text containing a ``:`` is parsed as IPv6, anything else as IPv4.
  ## Raises ``ValueError`` if the string is nil or not a valid address.
  if address_str == nil:
    raise newException(ValueError, "IP Address string is nil")
  result =
    if ':' in address_str: parseIPv6Address(address_str)
    else: parseIPv4Address(address_str)
|
|
|
|
proc isIpAddress*(address_str: string): bool {.tags: [].} =
  ## Returns true if ``address_str`` parses as an IP address with
  ## ``parseIpAddress``, false if parsing raises ``ValueError``.
  result = true
  try:
    discard parseIpAddress(address_str)
  except ValueError:
    result = false
|
|
|
|
when defineSsl:
  # One-time OpenSSL set-up, executed when this module is initialised.
  CRYPTO_malloc_init()
  # The library initialiser is expected to report success with 1.
  doAssert SslLibraryInit() == 1
  SslLoadErrorStrings()
  ErrLoadBioStrings()
  OpenSSL_add_all_algorithms()
|
|
|
|
when defineSsl:
  proc raiseSSLError*(s = "") =
    ## Raises an ``SSLError``.
    ##
    ## With a non-empty ``s`` that text is the message. Otherwise the
    ## message comes from OpenSSL's last queued error; if that error code is
    ## -1 an OS error is raised instead.
    if s != "":
      raise newException(SSLError, s)
    let code = ErrPeekLastError()
    if code == 0:
      raise newException(SSLError, "No error reported.")
    elif code == -1:
      raiseOSError(osLastError())
    let detail = $ErrErrorString(code, nil)
    # These two codes get an extra hint asking the user to upgrade OpenSSL.
    let message =
      if code == 336032814 or code == 336032784:
        "Please upgrade your OpenSSL library, it does not support the " &
          "necessary protocols. OpenSSL error is: " & detail
      else:
        detail
    raise newException(SSLError, message)
|
|
|
|
when defineSsl:
  proc getExtraData*(ctx: SSLContext, index: int): RootRef =
    ## Returns the object stored in ``ctx`` under ``index``.
    ##
    ## Raises ``IndexError`` if nothing was stored under ``index`` through
    ## ``setExtraData``, and an SSL error if OpenSSL returns a null pointer.
    if not ctx.referencedData.contains(index):
      raise newException(IndexError, "No data with that index.")
    let raw = SSL_CTX_get_ex_data(ctx.context, cint(index))
    if cast[int](raw) == 0:
      raiseSSLError()
    result = cast[RootRef](raw)
|
|
|
|
when defineSsl:
  proc setExtraData*(ctx: SSLContext, index: int, data: RootRef) =
    ## Stores arbitrary data inside SSLContext. The unique `index`
    ## should be retrieved using getSslContextExtraDataIndex.
    ##
    ## ``data`` is GC-ref'd for as long as it is stored; a value previously
    ## stored under the same ``index`` is un-ref'd once it has been replaced.
    ## Raises an SSL error if OpenSSL rejects the store; in that case the
    ## previously stored value is left untouched.

    # Look up the old value first: it can no longer be retrieved once the
    # slot has been overwritten.
    var previous: RootRef = nil
    if index in ctx.referencedData:
      previous = getExtraData(ctx, index)

    # SSL_CTX_set_ex_data reports success as 1 and failure as 0, so every
    # value other than 1 is treated as an error.
    if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) != 1:
      raiseSSLError()

    # Release the old value only now that it is no longer referenced by
    # the context.
    if previous != nil:
      GC_unref(previous)

    ctx.referencedData.incl(index)
    GC_ref(data)
|
|
|
|
when defineSsl:
  # http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
  proc loadCertificates(ctx: SSL_CTX, certFile, keyFile: string) =
    ## Loads the certificate chain and the private key into ``ctx``.
    ##
    ## An empty file name means "skip this file". Raises ``IOError`` when a
    ## named file does not exist and an SSL error when OpenSSL rejects a
    ## file or the private key check fails.
    let wantCert = certFile != ""
    let wantKey = keyFile != ""

    # Both existence checks run before anything is loaded.
    if wantCert and not existsFile(certFile):
      raise newException(system.IOError, "Certificate file could not be found: " & certFile)
    if wantKey and not existsFile(keyFile):
      raise newException(system.IOError, "Key file could not be found: " & keyFile)

    if wantCert:
      if SSLCTXUseCertificateChainFile(ctx, certFile) != 1:
        raiseSSLError()

    # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
    if wantKey:
      let loaded = SSL_CTX_use_PrivateKey_file(ctx, keyFile, SSL_FILETYPE_PEM)
      if loaded != 1:
        raiseSSLError()

      if SSL_CTX_check_private_key(ctx) != 1:
        raiseSSLError("Verification of private key file failed.")
|
|
|
|
when defineSsl:
  proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
                   certFile = "", keyFile = "", cipherList = "ALL"): SSLContext =
    ## Creates an SSL context.
    ##
    ## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
    ## are available with the addition of ``protSSLv23`` which allows for
    ## compatibility with all of them. Requesting ``protSSLv2`` or
    ## ``protSSLv3`` raises an SSL error because they are no longer secure.
    ##
    ## There are currently only two options for verify mode;
    ## one is ``CVerifyNone`` and with it certificates will not be verified
    ## the other is ``CVerifyPeer`` and certificates will be verified for
    ## it, ``CVerifyPeer`` is the safest choice.
    ##
    ## The last two parameters specify the certificate file path and the key file
    ## path, a server socket will most likely not work without these.
    ## Certificates can be generated using the following command:
    ## ``openssl req -x509 -nodes -days 365 -newkey rsa:1024 -keyout mycert.pem -out mycert.pem``.
    ##
    ## Raises an SSL error if the context cannot be created or configured.
    var newCTX: SSL_CTX
    case protVersion
    of protSSLv23:
      newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support.
    of protSSLv2:
      raiseSslError("SSLv2 is no longer secure and has been deprecated, use protSSLv23")
    of protSSLv3:
      raiseSslError("SSLv3 is no longer secure and has been deprecated, use protSSLv23")
    of protTLSv1:
      newCTX = SSL_CTX_new(TLSv1_method())

    # Check the handle before it is passed to any other OpenSSL call.
    if newCTX == nil:
      raiseSSLError()

    if newCTX.SSLCTXSetCipherList(cipherList) != 1:
      raiseSSLError()
    case verifyMode
    of CVerifyPeer:
      newCTX.SSLCTXSetVerify(SSLVerifyPeer, nil)
    of CVerifyNone:
      newCTX.SSLCTXSetVerify(SSLVerifyNone, nil)

    discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
    newCTX.loadCertificates(certFile, keyFile)

    result = SSLContext(context: newCTX, referencedData: initSet[int](),
                        extraInternal: new(SslContextExtraInternal))
|
|
|
|
when defineSsl:
  proc getExtraInternal(ctx: SSLContext): SslContextExtraInternal =
    ## Returns the module-private state object attached to ``ctx``.
    result = ctx.extraInternal
|
|
|
|
when defineSsl:
  proc destroyContext*(ctx: SSLContext) =
    ## Releases the resources held by ``ctx``: every value stored through
    ## ``setExtraData`` is un-ref'd and the OpenSSL context is freed.

    # We assume here that OpenSSL's internal indexes increase by 1 each time.
    # That means we can assume that the next internal index is the length of
    # extra data indexes.
    for idx in items(ctx.referencedData):
      let stored = getExtraData(ctx, idx).RootRef
      GC_unref(stored)
    SSL_CTX_free(ctx.context)
|
|
|
|
when defineSsl:
  proc `pskIdentityHint=`*(ctx: SSLContext, hint: string) =
    ## Sets the PSK identity hint on ``ctx``.
    ##
    ## Only used in PSK ciphersuites. Raises an SSL error if OpenSSL
    ## rejects the hint.
    let rc = SSL_CTX_use_psk_identity_hint(ctx.context, hint)
    if rc <= 0:
      raiseSSLError()
|
|
|
|
when defineSsl:
  proc clientGetPskFunc*(ctx: SSLContext): SslClientGetPskFunc =
    ## Returns the client PSK lookup function currently set on ``ctx``.
    result = getExtraInternal(ctx).clientGetPskFunc
|
|
|
|
when defineSsl:
  proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring;
                         max_identity_len: cuint; psk: ptr cuchar;
                         max_psk_len: cuint): cuint {.cdecl.} =
    ## C callback installed by ``clientGetPskFunc=``. Asks the user supplied
    ## function for an identity and PSK, copies the identity (including its
    ## terminating zero) into ``identity`` and the key into ``psk``.
    ##
    ## Returns the PSK length in bytes, or 0 when the PSK or the identity
    ## does not fit into the buffers provided by OpenSSL.
    # NOTE(review): this builds a fresh ``SSLContext`` wrapper whose
    # ``extraInternal`` is not set, so the lookup below appears to read
    # through a nil reference -- confirm how the callback is meant to reach
    # the original wrapper.
    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
    let hintString = if hint == nil: nil else: $hint
    let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
    # Compare the length of the key that will be copied, not of the buffer.
    if pskString.len.cuint > max_psk_len:
      return 0
    # ``>=`` leaves room for the terminating zero byte of the identity.
    if identityString.len.cuint >= max_identity_len:
      return 0

    # The identity is copied with its own length, plus the last zero byte.
    copyMem(identity, identityString.cstring, identityString.len + 1)
    copyMem(psk, pskString.cstring, pskString.len)

    return pskString.len.cuint
|
|
|
|
when defineSsl:
  proc `clientGetPskFunc=`*(ctx: SSLContext, fun: SslClientGetPskFunc) =
    ## Installs ``fun`` as the function that supplies the client identity
    ## and PSK for a given identity hint. Passing nil removes the callback.
    ##
    ## Only used in PSK ciphersuites.
    let internal = getExtraInternal(ctx)
    internal.clientGetPskFunc = fun
    if fun == nil:
      SSL_CTX_set_psk_client_callback(ctx.context, nil)
    else:
      SSL_CTX_set_psk_client_callback(ctx.context, pskClientCallback)
|
|
|
|
when defineSsl:
  proc serverGetPskFunc*(ctx: SSLContext): SslServerGetPskFunc =
    ## Returns the server PSK lookup function currently set on ``ctx``.
    result = getExtraInternal(ctx).serverGetPskFunc
|
|
|
|
when defineSsl:
  proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr cuchar;
                         max_psk_len: cint): cuint {.cdecl.} =
    ## C callback installed by ``serverGetPskFunc=``. Asks the user supplied
    ## function for the PSK belonging to ``identity`` and copies it into
    ## ``psk``.
    ##
    ## Returns the PSK length in bytes, or 0 when the PSK does not fit into
    ## the buffer provided by OpenSSL.
    # NOTE(review): this builds a fresh ``SSLContext`` wrapper whose
    # ``extraInternal`` is not set, so the lookup below appears to read
    # through a nil reference -- confirm how the callback is meant to reach
    # the original wrapper.
    let ctx = SSLContext(context: ssl.SSL_get_SSL_CTX)
    let pskString = (ctx.serverGetPskFunc)($identity)
    # Compare the length of the key that will be copied, not of the buffer.
    if pskString.len.cint > max_psk_len:
      return 0
    copyMem(psk, pskString.cstring, pskString.len)

    return pskString.len.cuint
|
|
|
|
when defineSsl:
  proc `serverGetPskFunc=`*(ctx: SSLContext, fun: SslServerGetPskFunc) =
    ## Installs ``fun`` as the function that supplies the PSK for a given
    ## client identity. Passing nil removes the callback.
    ##
    ## Only used in PSK ciphersuites.
    let internal = getExtraInternal(ctx)
    internal.serverGetPskFunc = fun
    if fun == nil:
      SSL_CTX_set_psk_server_callback(ctx.context, nil)
    else:
      SSL_CTX_set_psk_server_callback(ctx.context, pskServerCallback)
|
|
|
|
when defineSsl:
  proc getPskIdentity*(socket: Socket): string =
    ## Returns the PSK identity provided by the client.
    ##
    ## ``socket`` must be an SSL socket (checked with ``assert``).
    assert socket.isSSL
    result = $SSL_get_psk_identity(socket.sslHandle)
|
|
|
|
when defineSsl:
  proc wrapSocket*(ctx: SSLContext, socket: Socket) =
    ## Turns ``socket`` into an SSL socket that uses ``ctx``.
    ##
    ## This must be called on an unconnected socket; an SSL session will
    ## be started when the socket is connected. The socket must not already
    ## be an SSL socket (checked with ``assert``). Raises an SSL error if
    ## the SSL handle cannot be created or bound to the socket's descriptor.
    ##
    ## **Disclaimer**: This code is not well tested, may be very unsafe and
    ## prone to security vulnerabilities.
    assert(not socket.isSSL)
    socket.isSSL = true
    socket.sslContext = ctx
    socket.sslHandle = SSLNew(socket.sslContext.context)
    socket.sslNoHandshake = false
    socket.sslHasPeekChar = false

    let handle = socket.sslHandle
    if handle == nil:
      raiseSSLError()

    let bound = SSLSetFd(handle, socket.fd)
    if bound != 1:
      raiseSSLError()
|
|
|
|
when defineSsl:
  proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket,
                            handshake: SslHandshakeType,
                            hostname: string = nil) =
    ## Turns the already connected ``socket`` into an SSL socket that uses
    ## ``ctx`` and performs the SSL handshake immediately, as a client or as
    ## a server depending on ``handshake``.
    ##
    ## ``hostname`` should be specified so that the client knows which
    ## hostname the server certificate should be validated against. It is
    ## only used for client handshakes and only when it is not an IP
    ## address.
    ##
    ## **Disclaimer**: This code is not well tested, may be very unsafe and
    ## prone to security vulnerabilities.
    wrapSocket(ctx, socket)
    var handshakeResult: cint
    case handshake
    of handshakeAsClient:
      if not hostname.isNil and not isIpAddress(hostname):
        # Discard result in case OpenSSL version doesn't support SNI, or we're
        # not using TLSv1+
        discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
      handshakeResult = SSLConnect(socket.sslHandle)
    of handshakeAsServer:
      handshakeResult = SSLAccept(socket.sslHandle)
    socketError(socket, handshakeResult)
|
|
|
|
proc getSocketError*(socket: Socket): OSErrorCode =
  ## Returns ``osLastError`` if it holds an error code; otherwise the last
  ## error stored in ``socket``.
  ##
  ## Raises an OS error if neither source holds an error code.
  result = osLastError()
  if result != 0.OSErrorCode:
    return
  result = socket.lastError
  if result == 0.OSErrorCode:
    raiseOSError(result, "No valid socket error code available")
|
|
|
|
proc socketError*(socket: Socket, err: int = -1, async = false,
                  lastError = (-1).OSErrorCode) =
  ## Raises an error describing the failure of a socket operation that
  ## returned ``err``.
  ##
  ## For SSL sockets the error is classified with ``SSLGetError`` and is
  ## raised whenever ``err`` is 0 or negative. For plain sockets an
  ## ``OSError`` is raised only when ``err`` is -1; the code is taken from
  ## ``lastError`` if given, otherwise from ``getSocketError``.
  ##
  ## If ``async`` is ``true`` no error will be thrown in the case when the
  ## error was caused by no data being available to be read.
  when defineSsl:
    if socket.isSSL:
      if err <= 0:
        let sslCode = SSLGetError(socket.sslHandle, err.cint)
        case sslCode
        of SSL_ERROR_ZERO_RETURN:
          raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
        of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT,
           SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
          # The operation merely needs to be retried; that is only an
          # error for blocking callers.
          if async:
            return
          else: raiseSSLError("Not enough data on socket.")
        of SSL_ERROR_WANT_X509_LOOKUP:
          raiseSSLError("Function for x509 lookup has been called.")
        of SSL_ERROR_SYSCALL:
          var errStr = "IO error has occurred "
          let sslErr = ErrPeekLastError()
          if sslErr == 0 and err == 0:
            errStr.add "because an EOF was observed that violates the protocol"
          elif sslErr == 0 and err == -1:
            errStr.add "in the BIO layer"
          else:
            # Use a distinct name so the context prefix is not lost to the
            # OpenSSL error text.
            let sslErrStr = $ErrErrorString(sslErr, nil)
            raiseSSLError("IO error has occurred: " & sslErrStr)
          let osErr = osLastError()
          raiseOSError(osErr, errStr)
        of SSL_ERROR_SSL:
          raiseSSLError()
        else: raiseSSLError("Unknown Error")

  if err == -1 and not (when defineSsl: socket.isSSL else: false):
    var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
    if async:
      when useWinVersion:
        if lastE.int32 == WSAEWOULDBLOCK:
          return
        else: raiseOSError(lastE)
      else:
        if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
          return
        else: raiseOSError(lastE)
    else: raiseOSError(lastE)
|
|
|
|
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
  ## Puts ``socket`` into the listening state so that it accepts
  ## connections. ``backlog`` is the maximum length of the queue of pending
  ## connections.
  ##
  ## Raises an OS error upon failure.
  let rc = nativesockets.listen(socket.fd, backlog)
  if rc < 0'i32:
    raiseOSError(osLastError())
|
|
|
|
proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
    tags: [ReadIOEffect].} =
  ## Binds ``address``:``port`` to the socket.
  ##
  ## If ``address`` is "" then ADDR_ANY will be bound (as an ``AF_INET``
  ## address). Otherwise ``address`` is resolved for the socket's domain
  ## and the first result is bound.
  ##
  ## Raises an OS error upon failure.

  if address == "":
    var name: Sockaddr_in
    when useWinVersion:
      name.sin_family = toInt(AF_INET).int16
    else:
      name.sin_family = toInt(AF_INET)
    name.sin_port = htons(port.uint16)
    name.sin_addr.s_addr = htonl(INADDR_ANY)
    if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)),
                sizeof(name).SockLen) < 0'i32:
      raiseOSError(osLastError())
  else:
    var aiList = getAddrInfo(address, port, socket.domain)
    let rc = bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.SockLen)
    # Read the error code before releasing the address list, so that the
    # clean-up call cannot overwrite it.
    let bindErr = if rc < 0'i32: osLastError() else: OSErrorCode(0)
    freeAddrInfo(aiList)
    if rc < 0'i32:
      raiseOSError(bindErr)
|
|
|
|
proc acceptAddr*(server: Socket, client: var Socket, address: var string,
                 flags = {SocketFlag.SafeDisconn}) {.
                 tags: [ReadIOEffect], gcsafe, locks: 0.} =
  ## Blocks until a connection is being made from a client. When a connection
  ## is made sets ``client`` to the client socket and ``address`` to the address
  ## of the connecting client.
  ## This function will raise an OS error if an error occurs.
  ##
  ## The resulting client will inherit any properties of the server socket. For
  ## example: whether the socket is buffered or not.
  ##
  ## **Note**: ``client`` must be initialised (with ``new``), this function
  ## makes no effort to initialise the ``client`` variable.
  ##
  ## The ``accept`` call may result in an error if the connecting socket
  ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
  ## flag is specified then this error will not be raised and instead
  ## accept will be called again.
  assert(client != nil)
  assert client.fd.int <= 0, "Client socket needs to be initialised with " &
                             "`new`, not `newSocket`."
  var accepted = accept(server.fd)
  # Retry for as long as the failure is a disconnection that the caller
  # asked to ignore; any other failure is raised. Once a retry succeeds the
  # earlier error is not raised.
  while accepted[0] == osInvalidSocket:
    let err = osLastError()
    if not flags.isDisconnectionError(err):
      raiseOSError(err)
    accepted = accept(server.fd)

  address = accepted[1]
  client.fd = accepted[0]
  client.isBuffered = server.isBuffered

  # Handle SSL.
  when defineSsl:
    if server.isSSL:
      # We must wrap the client sock in a ssl context.
      server.sslContext.wrapSocket(client)
      let handshakeResult = SSLAccept(client.sslHandle)
      socketError(client, handshakeResult, false)
|
|
|
|
when false: #defineSsl:
  # NOTE(review): this whole section is disabled by ``when false``; it also
  # calls ``acceptAddrPlain``, which is not defined in the visible part of
  # the file.
  proc acceptAddrSSL*(server: Socket, client: var Socket,
                      address: var string): SSLAcceptResult {.
                      tags: [ReadIOEffect].} =
    ## This procedure should only be used for non-blocking **SSL** sockets.
    ## It will immediately return with one of the following values:
    ##
    ## ``AcceptSuccess`` will be returned when a client has been successfully
    ## accepted and the handshake has been successfully performed between
    ## ``server`` and the newly connected client.
    ##
    ## ``AcceptNoHandshake`` will be returned when a client has been accepted
    ## but no handshake could be performed. This can happen when the client
    ## connects but does not yet initiate a handshake. In this case
    ## ``acceptAddrSSL`` should be called again with the same parameters.
    ##
    ## ``AcceptNoClient`` will be returned when no client is currently attempting
    ## to connect.
    template doHandshake(): untyped =
      when defineSsl:
        if server.isSSL:
          client.setBlocking(false)
          # We must wrap the client sock in a ssl context.

          if not client.isSSL or client.sslHandle == nil:
            server.sslContext.wrapSocket(client)
          let rc = SSLAccept(client.sslHandle)
          while rc <= 0:
            let sslErr = SSLGetError(client.sslHandle, rc)
            if sslErr != SSL_ERROR_WANT_ACCEPT:
              case sslErr
              of SSL_ERROR_ZERO_RETURN:
                raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
              of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
                 SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
                client.sslNoHandshake = true
                return AcceptNoHandshake
              of SSL_ERROR_WANT_X509_LOOKUP:
                raiseSSLError("Function for x509 lookup has been called.")
              of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
                raiseSSLError()
              else:
                raiseSSLError("Unknown error")
          client.sslNoHandshake = false

    if client.isSSL and client.sslNoHandshake:
      # A previous call accepted the client but could not finish the
      # handshake; only the handshake is retried.
      doHandshake()
      return AcceptSuccess
    else:
      acceptAddrPlain(AcceptNoClient, AcceptSuccess):
        doHandshake()
|
|
|
|
proc accept*(server: Socket, client: var Socket,
             flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} =
  ## Same as ``acceptAddr`` except that the address of the connecting
  ## client is not reported; only the socket is.
  ##
  ## **Note**: ``client`` must be initialised (with ``new``), this function
  ## makes no effort to initialise the ``client`` variable.
  ##
  ## The ``accept`` call may result in an error if the connecting socket
  ## disconnects during the duration of the ``accept``. If the ``SafeDisconn``
  ## flag is specified then this error will not be raised and instead
  ## accept will be called again.
  var ignoredAddress = ""
  acceptAddr(server, client, ignoredAddress, flags)
|
|
|
|
proc close*(socket: Socket) =
  ## Closes ``socket``: shuts down the SSL session (if any), frees the SSL
  ## handle, closes the descriptor and marks it as ``osInvalidSocket``.
  try:
    when defineSsl:
      if socket.isSSL and socket.sslHandle != nil:
        ErrClearError()
        # The underlying socket is closed right after this, so the TLS
        # standard permits a unidirectional shutdown: we do not wait for the
        # peer's "close notify" alert with a second SSLShutdown call.
        let shutdownRes = SSLShutdown(socket.sslHandle)
        # 0 (shutdown not yet finished) and 1 (finished) are both fine here.
        if shutdownRes != 0 and shutdownRes != 1:
          socketError(socket, shutdownRes)
  finally:
    # Always release resources, even when the shutdown above raised.
    when defineSsl:
      if socket.isSSL and socket.sslHandle != nil:
        SSLFree(socket.sslHandle)
        socket.sslHandle = nil

    socket.fd.close()
    socket.fd = osInvalidSocket
|
|
|
|
when defined(posix):
|
|
from posix import TCP_NODELAY
|
|
else:
|
|
from winlean import TCP_NODELAY
|
|
|
|
proc toCInt*(opt: SOBool): cint =
  ## Maps a ``SOBool`` value onto the ``cint`` constant of the corresponding
  ## socket option.
  result =
    case opt
    of OptAcceptConn: SO_ACCEPTCONN
    of OptBroadcast: SO_BROADCAST
    of OptDebug: SO_DEBUG
    of OptDontRoute: SO_DONTROUTE
    of OptKeepAlive: SO_KEEPALIVE
    of OptOOBInline: SO_OOBINLINE
    of OptReuseAddr: SO_REUSEADDR
    of OptReusePort: SO_REUSEPORT
    of OptNoDelay: TCP_NODELAY
|
|
|
|
proc getSockOpt*(socket: Socket, opt: SOBool, level = SOL_SOCKET): bool {.
  tags: [ReadIOEffect].} =
  ## Reads option ``opt`` at ``level`` and reports it as a boolean: ``true``
  ## when the underlying integer value is non-zero.
  let rawValue = getSockOptInt(socket.fd, cint(level), toCInt(opt))
  rawValue != 0
|
|
|
|
proc getLocalAddr*(socket: Socket): (string, Port) =
  ## Returns the local address and port number of ``socket``.
  ##
  ## High-level interface for `getsockname`:idx:.
  result = getLocalAddr(socket.fd, socket.domain)
|
|
|
|
proc getPeerAddr*(socket: Socket): (string, Port) =
  ## Returns the address and port number of the peer ``socket`` is connected
  ## to.
  ##
  ## High-level interface for `getpeername`:idx:.
  result = getPeerAddr(socket.fd, socket.domain)
|
|
|
|
proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {.
  tags: [WriteIOEffect].} =
  ## Switches option ``opt`` on (``value = true``) or off (``value = false``)
  ## at the given ``level``.
  ##
  ## .. code-block:: Nim
  ##   var socket = newSocket()
  ##   socket.setSockOpt(OptReusePort, true)
  ##   socket.setSockOpt(OptNoDelay, true, level=IPPROTO_TCP.toInt)
  ##
  # The OS API wants an integer flag: 1 enables, 0 disables.
  let flagValue = cint(ord(value))
  setSockOptInt(socket.fd, cint(level), toCInt(opt), flagValue)
|
|
|
|
when defined(posix) and not defined(nimdoc):
  proc makeUnixAddr(path: string): Sockaddr_un =
    ## Builds an ``AF_UNIX`` socket address for ``path``.
    ##
    ## Raises ``ValueError`` when ``path`` (plus its terminating NUL) does not
    ## fit into ``sun_path``.
    if path.len >= Sockaddr_un_path_length:
      raise newException(ValueError, "socket path too long")
    result.sun_family = toInt(AF_UNIX)
    # ``path.len + 1`` also copies the terminating NUL of the cstring.
    copyMem(addr result.sun_path, path.cstring, path.len + 1)
|
|
|
|
when defined(posix):
  proc connectUnix*(socket: Socket, path: string) =
    ## Connects ``socket`` to the Unix socket located at `path`.
    ## Only available on Unix-style systems: Mac OS X, BSD and Linux.
    ##
    ## Raises ``OSError`` when the connection attempt fails.
    when not defined(nimdoc):
      var unixAddr = makeUnixAddr(path)
      let rc = socket.fd.connect(cast[ptr SockAddr](addr unixAddr),
                                 sizeof(unixAddr).Socklen)
      if rc != 0'i32:
        raiseOSError(osLastError())

  proc bindUnix*(socket: Socket, path: string) =
    ## Binds ``socket`` to the Unix socket path `path`.
    ## Only available on Unix-style systems: Mac OS X, BSD and Linux.
    ##
    ## Raises ``OSError`` when binding fails.
    when not defined(nimdoc):
      var unixAddr = makeUnixAddr(path)
      let rc = socket.fd.bindAddr(cast[ptr SockAddr](addr unixAddr),
                                  sizeof(unixAddr).Socklen)
      if rc != 0'i32:
        raiseOSError(osLastError())
|
|
|
|
when defined(ssl):
  proc handshake*(socket: Socket): bool
      {.tags: [ReadIOEffect, WriteIOEffect], deprecated.} =
    ## Performs the SSL handshake. Has to be called on a socket after it has
    ## connected; this only applies when ``connectAsync`` is used.
    ##
    ## Returns ``False`` while the socket is not yet ready for a handshake and
    ## ``True`` once the handshake completed successfully.
    ##
    ## Any other failure raises an ESSL error, as does calling this on a
    ## socket that is not an SSL socket.
    ##
    ## **Note:** This procedure is deprecated since version 0.14.0.
    if not socket.isSSL:
      raiseSSLError("Socket is not an SSL socket.")

    result = true
    let connectRes = SSLConnect(socket.sslHandle)
    if connectRes <= 0:
      case SSLGetError(socket.sslHandle, connectRes)
      of SSL_ERROR_ZERO_RETURN:
        raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
      of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT,
         SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE:
        # The handshake cannot make progress yet; the caller should retry.
        return false
      of SSL_ERROR_WANT_X509_LOOKUP:
        raiseSSLError("Function for x509 lookup has been called.")
      of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
        raiseSSLError()
      else:
        raiseSSLError("Unknown Error")
    socket.sslNoHandshake = false

  proc gotHandshake*(socket: Socket): bool =
    ## Tells whether a handshake has taken place between a client
    ## (``socket``) and the server that ``socket`` is connected to.
    ##
    ## Throws ESSL if ``socket`` is not an SSL socket.
    if not socket.isSSL:
      raiseSSLError("Socket is not an SSL socket.")
    result = not socket.sslNoHandshake
|
|
|
|
proc hasDataBuffered*(s: Socket): bool =
  ## Reports whether ``s`` still holds received data that has not been handed
  ## to the caller: unread bytes in the receive buffer of a buffered socket
  ## or, for SSL sockets, a pending peek character.
  result = s.isBuffered and s.bufLen > 0 and s.currPos != s.bufLen

  when defineSsl:
    if not result and s.isSSL:
      result = s.sslHasPeekChar
|
|
|
|
proc select(readfd: Socket, timeout = 500): int =
  ## Helper for socket operation timeouts. Returns ``1`` straight away when
  ## ``readfd`` already has buffered data; otherwise forwards to the
  ## descriptor-level ``select`` with ``timeout``.
  if readfd.hasDataBuffered:
    result = 1
  else:
    var handles = @[readfd.fd]
    result = select(handles, timeout)
|
|
|
|
proc isClosed(socket: Socket): bool =
  ## True when the socket's descriptor is ``osInvalidSocket``.
  result = socket.fd == osInvalidSocket
|
|
|
|
proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
  ## Single entry point for receiving on both SSL and plain sockets: SSL
  ## sockets are read with ``SSLRead`` (``flags`` is not used on that path),
  ## everything else with the OS-level ``recv``.
  ##
  ## Asserts that the socket has not been closed.
  assert(not socket.isClosed, "Cannot `recv` on a closed socket")
  when defineSsl:
    if socket.isSsl:
      return SSLRead(socket.sslHandle, buffer, size)

  result = recv(socket.fd, buffer, size, flags)
|
|
|
|
proc readIntoBuf(socket: Socket, flags: int32): int =
  ## Refills the socket's receive buffer with one ``uniRecv`` call and resets
  ## the read position. Returns what ``uniRecv`` returned; when that is
  ## ``<= 0`` the buffer is marked empty.
  result = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
  if result < 0:
    # Remember the error right away in case it gets reset (the Nim codegen
    # occasionally may call Win API functions which reset it).
    socket.lastError = osLastError()
  socket.bufLen = max(result, 0)
  socket.currPos = 0
|
|
|
|
template retRead(flags, readBytes: int) {.dirty.} =
  ## Refills the buffer of ``socket``. If that yields nothing (``<= 0``) the
  ## enclosing proc returns: ``readBytes`` when some data was already
  ## delivered, otherwise the result of the failed read.
  let res = socket.readIntoBuf(flags.int32)
  if res <= 0:
    return (if readBytes > 0: readBytes else: res)
|
|
|
|
proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect].} =
  ## Receives up to ``size`` bytes from ``socket`` into ``data`` and returns
  ## the number of bytes stored. ``0`` is returned when ``size`` is 0.
  ##
  ## For buffered sockets the data is served from the internal buffer, which
  ## is refilled as needed; if a refill yields nothing, the bytes copied so
  ## far are returned (or the refill's result if none were copied).
  ##
  ## **Note**: This is a low-level function, you may be interested in the higher
  ## level versions of this function which are also named ``recv``.
  if size == 0: return
  if socket.isBuffered:
    if socket.bufLen == 0:
      retRead(0'i32, 0)

    var read = 0
    while read < size:
      if socket.currPos >= socket.bufLen:
        retRead(0'i32, read)

      let chunk = min(socket.bufLen-socket.currPos, size-read)
      var d = cast[cstring](data)
      assert size-read >= chunk
      copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk)
      read.inc(chunk)
      socket.currPos.inc(chunk)

    result = read
  else:
    when defineSsl:
      if socket.isSSL:
        if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv
          # The previously peeked character is the first byte to deliver.
          copyMem(data, addr(socket.sslPeekChar), 1)
          socket.sslHasPeekChar = false
          if size-1 > 0:
            var d = cast[cstring](data)
            let more = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32)
            # One byte has already been delivered, so a failed or empty
            # follow-up read must still be reported as 1 byte read (not as
            # ``-1 + 1 == 0``, which callers interpret as "closed").
            result = if more > 0: more + 1 else: 1
          else:
            result = 1
        else:
          result = uniRecv(socket, data, size.cint, 0'i32)
      else:
        result = recv(socket.fd, data, size.cint, 0'i32)
    else:
      result = recv(socket.fd, data, size.cint, 0'i32)
    if result < 0:
      # Save the error in case it gets reset.
      socket.lastError = osLastError()
|
|
|
|
proc waitFor(socket: Socket, waited: var float, timeout, size: int,
             funcName: string): int {.tags: [TimeEffect].} =
  ## Determines the amount of characters that can be read. Result will never
  ## be larger than ``size``. For unbuffered sockets this will be ``1``
  ## (or, for SSL sockets, the number of pending SSL bytes capped at
  ## ``size``). For buffered sockets it can be as big as ``BufferSize``.
  ##
  ## ``waited`` accumulates the seconds spent waiting so far; ``timeout`` is
  ## in milliseconds and ``-1`` disables waiting altogether (``size`` is
  ## returned immediately).
  ##
  ## If this function does not determine that there is data on the socket
  ## within ``timeout`` ms, a ``TimeoutError`` is raised; ``funcName`` is
  ## used in its message.
  result = 1
  assert size > 0
  if timeout == -1: return size
  if socket.isBuffered and socket.bufLen != 0 and socket.bufLen != socket.currPos:
    result = min(socket.bufLen - socket.currPos, size)
  else:
    if timeout - int(waited * 1000.0) < 1:
      raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")

    when defineSsl:
      if socket.isSSL:
        if socket.hasDataBuffered:
          # sslPeekChar is present.
          return 1
        let sslPending = SSLPending(socket.sslHandle)
        if sslPending != 0:
          # Clamp to ``size``: callers use the result as the number of bytes
          # to read into a buffer that only has room for ``size`` bytes.
          return min(int(sslPending), size)

    var startTime = epochTime()
    let selRet = select(socket, timeout - int(waited * 1000.0))
    if selRet < 0: raiseOSError(osLastError())
    if selRet != 1:
      raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
    waited += (epochTime() - startTime)
|
|
|
|
proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
  tags: [ReadIOEffect, TimeEffect].} =
  ## Variant of the low-level ``recv`` that takes a ``timeout`` in
  ## milliseconds. Keeps reading until ``size`` bytes arrived or the
  ## connection was closed, and returns the number of bytes read; a negative
  ## result of the underlying ``recv`` is returned as is.
  var waited = 0.0 # number of seconds already waited
  let bytes = cast[cstring](data)

  var total = 0
  while total < size:
    let avail = waitFor(socket, waited, timeout, size-total, "recv")
    assert avail <= size-total
    let got = recv(socket, addr(bytes[total]), avail)
    if got < 0:
      return got
    if got == 0:
      break
    inc(total, got)

  result = total
|
|
|
|
proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
           flags = {SocketFlag.SafeDisconn}): int =
  ## Higher-level version of ``recv`` that receives into the string ``data``;
  ## ``data`` is resized to the number of bytes received.
  ##
  ## When 0 is returned the socket's connection has been closed.
  ##
  ## An OSError exception is raised when an error occurs, unless ``flags``
  ## classifies the error as a disconnection, in which case ``data`` is
  ## emptied and the proc returns.
  ##
  ## A timeout may be specified in milliseconds, if enough data is not received
  ## within the time specified a TimeoutError exception will be raised.
  ##
  ## **Note**: ``data`` must be initialised.
  ##
  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  data.setLen(size)
  if timeout == -1:
    result = recv(socket, cstring(data), size)
  else:
    result = recv(socket, cstring(data), size, timeout)

  if result < 0:
    data.setLen(0)
    let sockErr = getSocketError(socket)
    if flags.isDisconnectionError(sockErr): return
    socket.socketError(result, lastError = sockErr)
  data.setLen(result)
|
|
|
|
proc recv*(socket: Socket, size: int, timeout = -1,
           flags = {SocketFlag.SafeDisconn}): string {.inline.} =
  ## Higher-level version of ``recv`` that hands the received data back as a
  ## string.
  ##
  ## ``""`` means that the socket's connection has been closed.
  ##
  ## An EOS exception is raised when an error occurs.
  ##
  ## A timeout may be specified in milliseconds; if not enough data arrives
  ## within that time an ETimeout exception is raised.
  ##
  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  result = newString(size)
  # The byte count is not needed: the callee resizes ``result`` itself.
  discard socket.recv(result, size, timeout, flags)
|
|
|
|
proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
  ## Stores the next character of ``socket`` in ``c`` without consuming it.
  ##
  ## Returns ``1`` on success; otherwise the (``<= 0``) result of the
  ## underlying read, in which case ``c`` is not meaningful.
  ##
  ## * Buffered sockets peek into the receive buffer, refilling it if needed.
  ## * SSL sockets read one character and park it in ``sslPeekChar`` so that
  ##   the next ``recv`` delivers it first.
  ## * Plain sockets use ``MSG_PEEK``.
  if socket.isBuffered:
    result = 1
    if socket.bufLen == 0 or socket.currPos > socket.bufLen-1:
      var res = socket.readIntoBuf(0'i32)
      if res <= 0:
        result = res

    c = socket.buffer[socket.currPos]
  else:
    when defineSsl:
      if socket.isSSL:
        if socket.sslHasPeekChar:
          # A character is already parked; peeking again succeeds trivially.
          result = 1
        else:
          result = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
          # Only flag the peek char as present when a byte was really read;
          # otherwise a later ``recv`` would deliver a bogus character.
          socket.sslHasPeekChar = result > 0

        c = socket.sslPeekChar
        return
    result = recv(socket.fd, addr(c), 1, MSG_PEEK)
|
|
|
|
proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
               flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.
  tags: [ReadIOEffect, TimeEffect].} =
  ## Reads one line of data from ``socket`` into ``line``.
  ##
  ## The terminating ``\r\L`` is not stored in ``line``; only when the line
  ## consists of nothing but ``\r\L`` is ``line`` set to ``\r\L``.
  ##
  ## ``line`` is set to ``""`` when the socket is disconnected.
  ##
  ## A socket error raises an EOS exception.
  ##
  ## ``timeout`` is in milliseconds; when no data is received within that
  ## time an ETimeout exception is raised.
  ##
  ## ``maxLength`` limits the number of characters that are read; the result
  ## is truncated after that.
  ##
  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.

  template ensureNonEmptyLine() =
    # An empty line is reported as the bare line terminator.
    if line.len == 0:
      line.string.add("\c\L")

  template failWithSocketError() {.dirty.} =
    # Disconnections allowed by ``flags`` end the read with an empty line;
    # everything else is turned into an exception.
    let sockErr = getSocketError(socket)
    if flags.isDisconnectionError(sockErr):
      setLen(line.string, 0)
      return
    socket.socketError(n, lastError = sockErr)

  var waited = 0.0

  setLen(line.string, 0)
  while true:
    var c: char
    discard waitFor(socket, waited, timeout, 1, "readLine")
    var n = recv(socket, addr(c), 1)
    if n < 0:
      failWithSocketError()
    elif n == 0:
      setLen(line.string, 0)
      return

    case c
    of '\r':
      # Swallow the '\L' of a "\r\L" pair, if it follows.
      discard waitFor(socket, waited, timeout, 1, "readLine")
      n = peekChar(socket, c)
      if n > 0 and c == '\L':
        discard recv(socket, addr(c), 1)
      elif n <= 0:
        failWithSocketError()
      ensureNonEmptyLine()
      return
    of '\L':
      ensureNonEmptyLine()
      return
    else:
      discard
    add(line.string, c)

    # Verify that this isn't a DOS attack: #3847.
    if line.string.len > maxLength: break
|
|
|
|
proc recvLine*(socket: Socket, timeout = -1,
               flags = {SocketFlag.SafeDisconn},
               maxLength = MaxLineLength): TaintedString =
  ## Reads one line of data from ``socket`` and returns it.
  ##
  ## The terminating ``\r\L`` is not part of the result; only when the line
  ## consists of nothing but ``\r\L`` is the result set to ``\r\L``.
  ##
  ## The result is ``""`` when the socket is disconnected.
  ##
  ## A socket error raises an EOS exception.
  ##
  ## ``timeout`` is in milliseconds; when no data is received within that
  ## time an ETimeout exception is raised.
  ##
  ## ``maxLength`` limits the number of characters that are read; the result
  ## is truncated after that.
  ##
  ## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
  result = ""
  socket.readLine(result, timeout, flags, maxLength)
|
|
|
|
proc recvFrom*(socket: Socket, data: var string, length: int,
               address: var string, port: var Port, flags = 0'i32): int {.
               tags: [ReadIOEffect].} =
  ## Receives up to ``length`` bytes from ``socket`` into ``data`` and stores
  ## the sender in ``address`` and ``port``. This function should normally be
  ## used with connection-less sockets (UDP sockets).
  ##
  ## An EOS exception is raised when an error occurs; otherwise the length
  ## of the received data is returned.
  ##
  ## **Warning:** This function does not yet have a buffered implementation,
  ## so when ``socket`` is buffered the non-buffered implementation will be
  ## used. Therefore if ``socket`` contains something in its buffer this
  ## function will make no effort to return it.

  # TODO: Buffered sockets
  data.setLen(length)
  var senderAddr: Sockaddr_in
  var senderLen = sizeof(senderAddr).SockLen
  result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint,
                    cast[ptr SockAddr](addr(senderAddr)), addr(senderLen))

  if result == -1:
    raiseOSError(osLastError())
  else:
    data.setLen(result)
    address = $inet_ntoa(senderAddr.sin_addr)
    port = ntohs(senderAddr.sin_port).Port
|
|
|
|
proc skip*(socket: Socket, size: int, timeout = -1) =
  ## Reads and discards ``size`` bytes from ``socket``. Does nothing when
  ## ``size <= 0``.
  ##
  ## An optional timeout can be specified in milliseconds, if skipping the
  ## bytes takes longer than specified an ETimeout exception will be raised.
  ##
  ## Skipping stops early when the connection is closed by the peer; a
  ## receive error is reported through ``socketError``.
  if size <= 0: return
  var waited = 0.0
  let scratch = alloc(size)
  try:
    var bytesSkipped = 0
    while bytesSkipped < size:
      let avail = waitFor(socket, waited, timeout, size-bytesSkipped, "skip")
      let got = recv(socket, scratch, avail)
      if got < 0:
        # Never fold a negative result into the counter.
        socket.socketError(got, lastError = getSocketError(socket))
        break
      if got == 0:
        # Connection closed: no more data will ever arrive.
        break
      bytesSkipped += got
  finally:
    # Release the scratch buffer even when a timeout or socket error raises.
    dealloc(scratch)
|
|
|
|
proc send*(socket: Socket, data: pointer, size: int): int {.
  tags: [WriteIOEffect].} =
  ## Sends ``size`` bytes starting at ``data`` over ``socket`` and returns
  ## the result of the underlying ``SSLWrite``/``send`` call.
  ##
  ## **Note**: This is a low-level version of ``send``. You likely should use
  ## the version below.
  assert(not socket.isClosed, "Cannot `send` on a closed socket")
  when defineSsl:
    if socket.isSSL:
      result = SSLWrite(socket.sslHandle, cast[cstring](data), size)
      return

  when useWinVersion or defined(macosx):
    result = send(socket.fd, data, size.cint, 0'i32)
  else:
    when defined(solaris):
      const MSG_NOSIGNAL = 0
    result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
|
|
|
|
proc send*(socket: Socket, data: string,
           flags = {SocketFlag.SafeDisconn}) {.tags: [WriteIOEffect].} =
  ## Sends the whole string ``data`` over ``socket``.
  ##
  ## A failed send raises through ``socketError``, unless ``flags``
  ## classifies the error as a disconnection, in which case the proc simply
  ## returns. A partial send raises an ``OSError``.
  let bytesSent = send(socket, cstring(data), data.len)
  if bytesSent < 0:
    let sendErr = osLastError()
    if flags.isDisconnectionError(sendErr): return
    socketError(socket, lastError = sendErr)

  if bytesSent != data.len:
    raiseOSError(osLastError(), "Could not send all data.")
|
|
|
|
template `&=`*(socket: Socket; data: typed) =
  ## Operator shorthand for ``send``: ``socket &= data`` sends ``data``.
  socket.send(data)
|
|
|
|
proc trySend*(socket: Socket, data: string): bool {.tags: [WriteIOEffect].} =
  ## Safe alternative to ``send``: instead of raising an EOS on error it
  ## returns ``false``. ``true`` means that all of ``data`` was sent.
  let bytesSent = send(socket, cstring(data), data.len)
  bytesSent == data.len
|
|
|
|
proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
             size: int, af: Domain = AF_INET, flags = 0'i32): int {.
             tags: [WriteIOEffect].} =
  ## This proc sends ``data`` to the specified ``address``,
  ## which may be an IP address or a hostname, if a hostname is specified
  ## this function will try each IP of that hostname.
  ##
  ## Returns the result of the first ``sendto`` call that succeeded. If
  ## every resolved address fails, the result of the last attempt (``-1``)
  ## is returned; no exception is raised for that case.
  ##
  ## **Note:** You may wish to use the high-level version of this function
  ## which is defined below.
  ##
  ## **Note:** This proc is not available for SSL sockets.
  assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
  let aiList = getAddrInfo(address, port, af)

  # try all possibilities:
  var it = aiList
  while it != nil:
    result = sendto(socket.fd, data, size.cint, flags.cint, it.ai_addr,
                    it.ai_addrlen.SockLen)
    if result != -1'i32:
      break
    it = it.ai_next

  freeAddrInfo(aiList)
|
|
|
|
proc sendTo*(socket: Socket, address: string, port: Port,
             data: string): int {.tags: [WriteIOEffect].} =
  ## Sends the string ``data`` to ``address``, which may be an IP address or
  ## a hostname; for a hostname each of its IPs is tried.
  ##
  ## This is the high-level version of the above ``sendTo`` function.
  result = sendTo(socket, address, port, cstring(data), data.len)
|
|
|
|
|
|
proc isSsl*(socket: Socket): bool =
  ## Tells whether ``socket`` is a SSL socket. Always ``false`` when SSL
  ## support is not compiled in.
  result = false
  when defineSsl:
    result = socket.isSSL
|
|
|
|
proc getFd*(socket: Socket): SocketHandle =
  ## Returns the socket's file descriptor
  result = socket.fd
|
|
|
|
proc IPv4_any*(): IpAddress =
  ## The IPv4 "any" address (0.0.0.0), which can be used to listen on all
  ## available network adapters.
  IpAddress(family: IpAddressFamily.IPv4, address_v4: [0'u8, 0, 0, 0])
|
|
|
|
proc IPv4_loopback*(): IpAddress =
  ## The IPv4 loopback address (127.0.0.1).
  IpAddress(family: IpAddressFamily.IPv4, address_v4: [127'u8, 0, 0, 1])
|
|
|
|
proc IPv4_broadcast*(): IpAddress =
  ## The IPv4 broadcast address (255.255.255.255).
  IpAddress(family: IpAddressFamily.IPv4,
            address_v4: [255'u8, 255, 255, 255])
|
|
|
|
proc IPv6_any*(): IpAddress =
  ## The IPv6 "any" address (::0), which can be used to listen on all
  ## available network adapters.
  IpAddress(family: IpAddressFamily.IPv6,
            address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
|
|
|
proc IPv6_loopback*(): IpAddress =
  ## The IPv6 loopback address (::1).
  IpAddress(family: IpAddressFamily.IPv6,
            address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
|
|
|
|
proc `==`*(lhs, rhs: IpAddress): bool =
  ## Equality for IpAddresses: true when both belong to the same address
  ## family and all of their address bytes match.
  if lhs.family != rhs.family:
    return false
  case lhs.family
  of IpAddressFamily.IPv4:
    result = lhs.address_v4 == rhs.address_v4
  of IpAddressFamily.IPv6:
    result = lhs.address_v6 == rhs.address_v6
|
|
|
|
proc `$`*(address: IpAddress): string =
  ## Converts an IpAddress into the textual representation.
  ##
  ## IPv4 addresses are rendered in dotted-decimal form. IPv6 addresses are
  ## rendered as lowercase hexadecimal groups without leading zeros, with the
  ## longest (first, on ties) run of all-zero groups compressed to ``::``.
  ## All-zero groups outside that run are rendered as ``0``.
  const hexDigits = "0123456789abcdef"
  result = ""
  case address.family
  of IpAddressFamily.IPv4:
    for i in 0 .. 3:
      if i != 0:
        result.add('.')
      result.add($address.address_v4[i])
  of IpAddressFamily.IPv6:
    var
      currentZeroStart = -1
      currentZeroCount = 0
      biggestZeroStart = -1
      biggestZeroCount = 0
    # Look for the largest block of zeros
    for i in 0..7:
      let isZero = address.address_v6[i*2] == 0 and address.address_v6[i*2+1] == 0
      if isZero:
        if currentZeroStart == -1:
          currentZeroStart = i
          currentZeroCount = 1
        else:
          currentZeroCount.inc()
        if currentZeroCount > biggestZeroCount:
          biggestZeroCount = currentZeroCount
          biggestZeroStart = currentZeroStart
      else:
        currentZeroStart = -1

    if biggestZeroCount == 8: # Special case ::0
      result.add("::")
    else: # Print address
      var printedLastGroup = false
      for i in 0..7:
        # Each group is a big-endian 16-bit word.
        let word = (uint16(address.address_v6[i*2]) shl 8) or
                   uint16(address.address_v6[i*2+1])

        if biggestZeroCount != 0 and # Check if group is in skip group
          (i >= biggestZeroStart and i < (biggestZeroStart + biggestZeroCount)):
          if i == biggestZeroStart: # skip start
            result.add("::")
          printedLastGroup = false
        else:
          if printedLastGroup:
            result.add(':')
          var afterLeadingZeros = false
          for shift in [12'u16, 8, 4, 0]:
            let nibble = int((word shr shift) and 0xF'u16)
            if nibble != 0 or afterLeadingZeros:
              result.add(hexDigits[nibble])
              afterLeadingZeros = true
          if not afterLeadingZeros:
            # An all-zero group that is not part of the compressed run must
            # still be printed, otherwise the separators collapse.
            result.add('0')
          printedLastGroup = true
|
|
|
|
proc dial*(address: string, port: Port,
           protocol = IPPROTO_TCP, buffered = true): Socket
           {.tags: [ReadIOEffect, WriteIOEffect].} =
  ## Establishes connection to the specified ``address``:``port`` pair via the
  ## specified protocol. The procedure iterates through possible
  ## resolutions of the ``address`` until it succeeds, meaning that it
  ## seamlessly works with both IPv4 and IPv6.
  ## Returns Socket ready to send or receive data; ``buffered`` selects
  ## whether the returned socket is buffered.
  ##
  ## Raises ``OSError`` when a native socket cannot be created or when no
  ## resolved address could be connected to, and ``IOError`` when none of the
  ## resolved addresses belongs to a known domain. No descriptors are left
  ## open on failure.
  let sockType = protocol.toSockType()

  let aiList = getAddrInfo(address, port, AF_UNSPEC, sockType, protocol)

  # At most one native socket is created per address family; it is reused for
  # every candidate address of that family.
  var fdPerDomain: array[low(Domain).ord..high(Domain).ord, SocketHandle]
  for i in low(fdPerDomain)..high(fdPerDomain):
    fdPerDomain[i] = osInvalidSocket
  template closeUnusedFds(domainToKeep = -1) {.dirty.} =
    for i, fd in fdPerDomain:
      if fd != osInvalidSocket and i != domainToKeep:
        fd.close()

  var success = false
  var lastError: OSErrorCode
  var it = aiList
  var domain: Domain
  var lastFd: SocketHandle
  while it != nil:
    let domainOpt = it.ai_family.toKnownDomain()
    if domainOpt.isNone:
      it = it.ai_next
      continue
    domain = domainOpt.unsafeGet()
    lastFd = fdPerDomain[ord(domain)]
    if lastFd == osInvalidSocket:
      lastFd = createNativeSocket(domain, sockType, protocol)
      if lastFd == osInvalidSocket:
        # we always raise if socket creation failed, because it means a
        # network system problem (e.g. not enough FDs), and not an unreachable
        # address.
        let err = osLastError()
        freeAddrInfo(aiList)
        closeUnusedFds()
        raiseOSError(err)
      fdPerDomain[ord(domain)] = lastFd
    if connect(lastFd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32:
      success = true
      break
    lastError = osLastError()
    it = it.ai_next
  freeAddrInfo(aiList)

  if success:
    # Keep only the descriptor that actually connected.
    closeUnusedFds(ord(domain))
    result = newSocket(lastFd, domain, sockType, protocol,
                       buffered = buffered)
  else:
    # Nothing connected: close every descriptor we created before raising.
    closeUnusedFds()
    if lastError != 0.OSErrorCode:
      raiseOSError(lastError)
    else:
      raise newException(IOError, "Couldn't resolve address: " & address)
|
|
|
|
proc connect*(socket: Socket, address: string,
    port = Port(0)) {.tags: [ReadIOEffect].} =
  ## Connects ``socket`` to ``address``:``port``. ``address`` may be an IP
  ## address or a host name; for a host name every IP it resolves to is
  ## tried in turn. ``htons`` is already performed on ``port`` so you must
  ## not do it.
  ##
  ## Raises ``OSError`` (with the error of the last attempt) when no address
  ## could be connected to.
  ##
  ## If ``socket`` is an SSL socket a handshake will be automatically performed.
  var aiList = getAddrInfo(address, port, socket.domain)
  # Walk the resolved addresses until one of them accepts the connection.
  var connected = false
  var lastError: OSErrorCode
  var candidate = aiList
  while candidate != nil:
    if connect(socket.fd, candidate.ai_addr,
               candidate.ai_addrlen.SockLen) == 0'i32:
      connected = true
      break
    lastError = osLastError()
    candidate = candidate.ai_next

  freeAddrInfo(aiList)
  if not connected: raiseOSError(lastError)

  when defineSsl:
    if socket.isSSL:
      # RFC3546 for SNI specifies that IP addresses are not allowed.
      if not isIpAddress(address):
        # Discard result in case OpenSSL version doesn't support SNI, or we're
        # not using TLSv1+
        discard SSL_set_tlsext_host_name(socket.sslHandle, address)

      let handshakeRes = SSLConnect(socket.sslHandle)
      socketError(socket, handshakeRes)
|
|
|
|
proc connectAsync(socket: Socket, name: string, port = Port(0),
                  af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
  ## A variant of ``connect`` for non-blocking sockets.
  ##
  ## Returns immediately instead of blocking until the connection is made:
  ## an attempt that is still in progress counts as success here. It is up
  ## to the caller to make sure the connection has been established by
  ## checking (using ``select``) whether the socket is writeable.
  ##
  ## Raises ``OSError`` (with the error of the last attempt) when every
  ## resolved address fails outright.
  ##
  ## **Note**: For SSL sockets, the ``handshake`` procedure must be called
  ## whenever the socket successfully connects to a server.
  var aiList = getAddrInfo(name, port, af)
  # try all possibilities:
  var started = false
  var lastError: OSErrorCode
  var candidate = aiList
  while candidate != nil:
    if connect(socket.fd, candidate.ai_addr,
               candidate.ai_addrlen.SockLen) == 0'i32:
      started = true
      break

    lastError = osLastError()
    # A pending connection attempt is reported through a platform specific
    # error code.
    let inProgress =
      when useWinVersion:
        # Windows EINTR doesn't behave same as POSIX.
        lastError.int32 == WSAEWOULDBLOCK
      else:
        lastError.int32 == EINTR or lastError.int32 == EINPROGRESS
    if inProgress:
      started = true
      break

    candidate = candidate.ai_next

  freeAddrInfo(aiList)
  if not started: raiseOSError(lastError)
|
|
|
|
proc connect*(socket: Socket, address: string, port = Port(0),
    timeout: int) {.tags: [ReadIOEffect, WriteIOEffect].} =
  ## Connects to the server given by ``address`` and ``port``.
  ##
  ## The ``timeout`` parameter specifies the time in milliseconds to allow for
  ## the connection to the server to be made; a ``TimeoutError`` is raised
  ## when it elapses. A connection attempt that fails is reported as an
  ## ``OSError``.
  socket.fd.setBlocking(false)

  socket.connectAsync(address, port, socket.domain)
  var pending = @[socket.fd]
  if selectWrite(pending, timeout) != 1:
    raise newException(TimeoutError, "Call to 'connect' timed out.")

  # Becoming writeable only means the attempt has finished; SO_ERROR tells
  # whether it actually succeeded.
  let connectErr = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR)
  if connectErr != 0:
    raiseOSError(OSErrorCode(connectErr))
  when defineSsl and not defined(nimdoc):
    if socket.isSSL:
      socket.fd.setBlocking(true)
      {.warning[Deprecated]: off.}
      doAssert socket.handshake()
      {.warning[Deprecated]: on.}
  socket.fd.setBlocking(true)
|