Files
Nim/lib/pure/net.nim
bptato c1bb144fdc Fix broken poll and nfds_t bindings (#24331)
This fixes several cases of the Nim binding of nfds_t being inconsistent
with the target platform signedness and/or size.

Additionally, it fixes poll's third argument (timeout) being declared as Nim
"int" when it should have been "cint".

The former is the same issue that #23045 had attempted to fix, but
failed because it only considered Linux. (Also, it was only applied to
version 2.0, so the two branches now have incompatible versions of the
same bug.)

Notes:

* SVR4's original "unsigned long" definition is cloned by Linux and
Haiku. Nim got this right for Haiku and Linux-amd64, but it was wrong on
non-amd64 Linux.
* Zephyr does not have nfds_t, but simply uses (signed) "int". This was
already correctly reflected by Nim.
* OpenBSD poll.h uses "unsigned int", and other BSD derivatives follow
suit. This being the most commonly copied definition, the fallback case
now returns cuint. (This also seems to be correct for the OS X headers I
could find on the web.)
* This changes Nintendo Switch nfds_t to cuint from culong. It is
purportedly a FreeBSD derivative, so I *think* this is correct, but I
can't tell because I don't have access to the Nintendo Switch headers.

I have also moved the platform-specific Tnfds to posix.nim so that we
can reuse the fallback logic on all platforms. (e.g. specifying the size
in posix_linux_amd64 only to then use when defined(linux) in posix_other
seems redundant.)

(cherry picked from commit 67442471ae)
2024-12-17 14:35:25 +01:00

2176 lines
80 KiB
Nim

#
#
# Nim's Runtime Library
# (c) Copyright 2015 Dominik Picheta
#
# See the file "copying.txt", included in this
# distribution, for details about the copyright.
#
## This module implements a high-level cross-platform sockets interface.
## The procedures implemented in this module are primarily for blocking sockets.
## For asynchronous non-blocking sockets use the `asyncnet` module together
## with the `asyncdispatch` module.
##
## The first thing you will always need to do in order to start using sockets,
## is to create a new instance of the `Socket` type using the `newSocket`
## procedure.
##
## SSL
## ====
##
## In order to use the SSL procedures defined in this module, you will need to
## compile your application with the `-d:ssl` flag. See the
## `newContext<net.html#newContext%2Cstring%2Cstring%2Cstring%2Cstring>`_
## procedure for additional details.
##
##
## SSL on Windows
## ==============
##
## On Windows the SSL library checks for valid certificates.
## It uses the `cacert.pem` file for this purpose which was extracted
## from `https://curl.se/ca/cacert.pem`. Besides
## the OpenSSL DLLs (e.g. libssl-1_1-x64.dll, libcrypto-1_1-x64.dll) you
## also need to ship `cacert.pem` with your `.exe` file.
##
##
## Examples
## ========
##
## Connecting to a server
## ----------------------
##
## After you create a socket with the `newSocket` procedure, you can easily
## connect it to a server running at a known hostname (or IP address) and port.
## To do so over TCP, use the example below.
runnableExamples("-r:off"):
  let socket = newSocket()
  socket.connect("google.com", Port(80))

## For SSL, use the following example:
runnableExamples("-r:off -d:ssl"):
  let socket = newSocket()
  let ctx = newContext()
  wrapSocket(ctx, socket)
  socket.connect("google.com", Port(443))

## UDP is a connectionless protocol, so UDP sockets don't have to explicitly
## call the `connect <net.html#connect%2CSocket%2Cstring>`_ procedure. They can
## simply start sending data immediately.
runnableExamples("-r:off"):
  let socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)
  socket.sendTo("192.168.0.1", Port(27960), "status\n")

runnableExamples("-r:off"):
  let socket = newSocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)
  let ip = parseIpAddress("192.168.0.1")
  # "status\c\l" is 8 bytes long, so a complete send returns 8.
  doAssert socket.sendTo(ip, Port(27960), "status\c\l") == 8

## Creating a server
## -----------------
##
## After you create a socket with the `newSocket` procedure, you can create a
## TCP server by calling the `bindAddr` and `listen` procedures.
runnableExamples("-r:off"):
  let socket = newSocket()
  socket.bindAddr(Port(1234))
  socket.listen()

  # You can then begin accepting connections using `accept`, or, as here,
  # `acceptAddr`, which also reports the address of the connected client.
  var client: Socket
  var address = ""
  while true:
    socket.acceptAddr(client, address)
    echo "Client connected from: ", address
import std/private/since
when defined(nimPreviewSlimSystem):
import std/assertions
import nativesockets
import os, strutils, times, sets, options, std/monotimes
import ssl_config
export nativesockets.Port, nativesockets.`$`, nativesockets.`==`
export Domain, SockType, Protocol
const
  # Selects the Winsock (Windows) code paths; also taken for `nimdoc` so the
  # Windows-only symbols show up in the generated documentation.
  useWinVersion = defined(windows) or defined(nimdoc)
  # True for `-d:nimNetLite` and for the FreeRTOS, Zephyr and NuttX targets.
  useNimNetLite = defined(nimNetLite) or defined(freertos) or
    defined(zephyr) or defined(nuttx)
  # Enables the SSL/TLS parts of this module (`-d:ssl`, or doc generation).
  defineSsl = defined(ssl) or defined(nimdoc)
when useWinVersion:
from winlean import WSAESHUTDOWN
when defineSsl:
import openssl
when not defined(nimDisableCertificateValidation):
from ssl_certs import scanSSLCertificates
# Note: The enumerations are mapped to Windows' constants.
# NOTE(review): it is unclear which enumerations this refers to; none of the
# enums below are given explicit Windows values — confirm or remove.
when defineSsl:
  type
    Certificate* = string ## DER encoded certificate
    SslError* = object of CatchableError ## Raised for SSL/TLS failures (see `raiseSSLError`).
    SslCVerifyMode* = enum ## How peer certificates are verified; see `newContext`.
      CVerifyNone, CVerifyPeer, CVerifyPeerUseEnvVars
    SslProtVersion* = enum ## Requested protocol; `newContext` only reads it with `-d:openssl10`.
      protSSLv2, protSSLv3, protTLSv1, protSSLv23
    SslContext* = ref object ## Nim-side wrapper around an OpenSSL `SSL_CTX`.
      context*: SslCtx
      referencedData: HashSet[int] # ex-data indices whose values `setExtraData` has GC_ref'd
      extraInternal: SslContextExtraInternal # PSK callbacks set via `clientGetPskFunc=`/`serverGetPskFunc=`
    SslAcceptResult* = enum
      AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
    SslHandshakeType* = enum ## Whether `wrapConnectedSocket` runs SSL_connect or SSL_accept.
      handshakeAsClient, handshakeAsServer
    SslClientGetPskFunc* = proc(hint: string): tuple[identity: string, psk: string]
      ## Maps the server's identity hint to the client identity and PSK.
    SslServerGetPskFunc* = proc(identity: string): string
      ## Maps the client's identity to the PSK.
    SslContextExtraInternal = ref object of RootRef
      serverGetPskFunc: SslServerGetPskFunc
      clientGetPskFunc: SslClientGetPskFunc
else:
  type
    SslContext* = ref object # TODO: Workaround #4797.
      # presumably a placeholder so signatures mentioning SslContext compile
      # without -d:ssl — confirm against #4797
const
  BufferSize*: int = 4000 ## size of a buffered socket's buffer
  MaxLineLength* = 1_000_000
    ## Upper bound on line length; presumably enforced by the line-reading
    ## procs elsewhere in this module — confirm.
type
  SocketImpl* = object ## socket type
    fd: SocketHandle # the underlying OS socket handle
    isBuffered: bool # determines whether this socket is buffered.
    buffer: array[0..BufferSize, char] # inclusive range: BufferSize + 1 bytes
    currPos: int # current index in buffer
    bufLen: int # current length of buffer
    when defineSsl:
      isSsl: bool # set by `wrapSocket`
      sslHandle: SslPtr # OpenSSL handle created by `wrapSocket`
      sslContext: SslContext
      sslNoHandshake: bool # True if needs handshake.
        # NOTE(review): the field name suggests the opposite of the comment
        # above — confirm against its uses.
      sslHasPeekChar: bool
      sslPeekChar: char
      sslNoShutdown: bool # True if shutdown shouldn't be done.
    lastError: OSErrorCode ## stores the last error on this socket
    domain: Domain
    sockType: SockType
    protocol: Protocol
  Socket* = ref SocketImpl ## Reference to a `SocketImpl`; created by `newSocket`.
  SOBool* = enum ## Boolean socket options.
    OptAcceptConn, OptBroadcast, OptDebug, OptDontRoute, OptKeepAlive,
    OptOOBInline, OptReuseAddr, OptReusePort, OptNoDelay
  ReadLineResult* = enum ## result for readLineAsync
    ReadFullLine, ReadPartialLine, ReadDisconnected, ReadNone
  TimeoutError* = object of CatchableError
  SocketFlag* {.pure.} = enum
    Peek, ## Maps to `MSG_PEEK` (see `toOSFlags`).
    SafeDisconn ## Ensures disconnection exceptions (ECONNRESET, EPIPE etc) are not thrown.
# Style checks are suspended around these declarations, presumably because
# of the underscore field names `address_v6` / `address_v4`.
when defined(nimHasStyleChecks):
  {.push styleChecks: off.}
type
  IpAddressFamily* {.pure.} = enum ## Describes the type of an IP address
    IPv6, ## IPv6 address
    IPv4 ## IPv4 address
  IpAddress* = object ## stores an arbitrary IP address
    case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
    of IpAddressFamily.IPv6:
      address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
                                       ## case of IPv6 (most significant byte first)
    of IpAddressFamily.IPv4:
      address_v4*: array[0..3, uint8] ## Contains the IP address in bytes in
                                      ## case of IPv4 (first octet at index 0)
when defined(nimHasStyleChecks):
  {.pop.}
when defined(posix) and not defined(lwip):
  from posix import TPollfd, POLLIN, POLLPRI, POLLOUT, POLLWRBAND, Tnfds

  template monitorPollEvent(x: var SocketHandle, y, timeout: cint): int =
    ## Polls the single descriptor `x` for the event mask `y`, waiting at
    ## most `timeout` milliseconds, and yields `poll`'s return value.
    var pfd: TPollfd
    pfd.fd = cast[cint](x)   # pollfd wants a plain C int descriptor
    pfd.events = y
    posix.poll(addr(pfd), Tnfds(1), timeout)
proc timeoutRead(fd: var SocketHandle, timeout = 500): int =
  ## Waits up to `timeout` milliseconds for `fd` to become readable (or to
  ## have priority data on POSIX) and returns the underlying select/poll
  ## result: the number of ready descriptors, 0 on timeout.
  when defined(windows) or defined(lwip):
    var handles = @[fd]
    result = selectRead(handles, timeout)
  else:
    result = monitorPollEvent(fd, POLLIN or POLLPRI, cint(timeout))
proc timeoutWrite(fd: var SocketHandle, timeout = 500): int =
  ## Waits up to `timeout` milliseconds for `fd` to become writable and
  ## returns the underlying select/poll result: the number of ready
  ## descriptors, 0 on timeout.
  when defined(windows) or defined(lwip):
    var handles = @[fd]
    result = selectWrite(handles, timeout)
  else:
    result = monitorPollEvent(fd, POLLOUT or POLLWRBAND, cint(timeout))
# Forward declaration: the implementation appears further down in this module,
# but the SSL helpers above it (e.g. `wrapConnectedSocket`) already call it.
proc socketError*(socket: Socket, err: int = -1, async = false,
    lastError = (-1).OSErrorCode,
    flags: set[SocketFlag] = {}) {.gcsafe.}
proc isDisconnectionError*(flags: set[SocketFlag],
    lastError: OSErrorCode): bool =
  ## Determines whether `lastError` is a disconnection error. Only does this
  ## if flags contains `SafeDisconn`.
  ##
  ## Without `SafeDisconn` the answer is always false. With it, connection
  ## reset/abort, network reset, broken pipe (POSIX) and the Windows-specific
  ## disconnect/shutdown codes are treated as disconnections.
  if SocketFlag.SafeDisconn notin flags:
    return false
  let code = lastError.int32
  when useWinVersion:
    result = code == WSAECONNRESET or code == WSAECONNABORTED or
      code == WSAENETRESET or code == WSAEDISCON or
      code == WSAESHUTDOWN or code == ERROR_NETNAME_DELETED
  else:
    result = code == ECONNRESET or code == EPIPE or code == ENETRESET
proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
  ## Converts the flags into the underlying OS representation.
  ##
  ## Only `SocketFlag.Peek` has an OS counterpart (`MSG_PEEK`);
  ## `SafeDisconn` is interpreted by this module and contributes no bits.
  result = 0
  if SocketFlag.Peek in socketFlags:
    result = result or MSG_PEEK
proc newSocket*(fd: SocketHandle, domain: Domain = AF_INET,
    sockType: SockType = SOCK_STREAM,
    protocol: Protocol = IPPROTO_TCP, buffered = true): owned(Socket) =
  ## Creates a new socket as specified by the params.
  ##
  ## Wraps the already-open native handle `fd` (which must not be
  ## `osInvalidSocket`) and records `domain`, `sockType`, `protocol` and
  ## whether the socket is `buffered`. The buffer position starts at 0.
  assert fd != osInvalidSocket
  result = Socket(fd: fd, isBuffered: buffered, domain: domain,
                  sockType: sockType, protocol: protocol)
  when defined(macosx) and not defined(nimdoc):
    # macOS has no MSG_NOSIGNAL; suppress SIGPIPE per socket instead.
    setSockOptInt(fd, SOL_SOCKET, SO_NOSIGPIPE, 1)
proc newSocket*(domain, sockType, protocol: cint, buffered = true,
    inheritable = defined(nimInheritHandles)): owned(Socket) =
  ## Creates a new socket from raw OS constants for `domain`, `sockType`
  ## and `protocol`.
  ##
  ## The SocketHandle associated with the resulting Socket will not be
  ## inheritable by child processes by default. This can be changed via
  ## the `inheritable` parameter.
  ##
  ## If an error occurs OSError will be raised.
  let handle = createNativeSocket(domain, sockType, protocol, inheritable)
  if handle == osInvalidSocket:
    raiseOSError(osLastError())
  result = newSocket(handle, Domain(domain), SockType(sockType),
                     Protocol(protocol), buffered)
proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
    protocol: Protocol = IPPROTO_TCP, buffered = true,
    inheritable = defined(nimInheritHandles)): owned(Socket) =
  ## Creates a new socket.
  ##
  ## The SocketHandle associated with the resulting Socket will not be
  ## inheritable by child processes by default. This can be changed via
  ## the `inheritable` parameter.
  ##
  ## If an error occurs OSError will be raised.
  let handle = createNativeSocket(domain, sockType, protocol, inheritable)
  if handle != osInvalidSocket:
    return newSocket(handle, domain, sockType, protocol, buffered)
  raiseOSError(osLastError())
proc parseIPv4Address(addressStr: string): IpAddress =
  ## Parses a dotted-quad IPv4 address in strict form: exactly four decimal
  ## octets, each in 0..255, without leading zeros.
  ##
  ## Raises ValueError on errors.
  result = IpAddress(family: IpAddressFamily.IPv4)
  var
    octetIdx = 0           # index of the octet currently being read
    octet: uint16 = 0      # value accumulated for the current octet
    sawDigit = false       # a digit was read since the last '.'
    zeroLead = false       # current octet is exactly "0" so far
  for ch in addressStr:
    if ch == '.':
      if not sawDigit or octetIdx >= 3:
        raise newException(ValueError,
          "Invalid IP Address. The address consists of too many groups")
      result.address_v4[octetIdx] = cast[uint8](octet)
      inc octetIdx
      octet = 0
      sawDigit = false
      zeroLead = false
    elif ch in strutils.Digits:
      if zeroLead:
        raise newException(ValueError,
          "Invalid IP address. Octal numbers are not allowed")
      octet = octet * 10 + cast[uint16](ord(ch) - ord('0'))
      if octet == 0'u16:
        zeroLead = true
      elif octet > 255'u16:
        raise newException(ValueError,
          "Invalid IP Address. Value is out of range")
      sawDigit = true
    else:
      raise newException(ValueError,
        "Invalid IP Address. Address contains an invalid character")
  if octetIdx != 3 or not sawDigit:
    raise newException(ValueError, "Invalid IP Address")
  result.address_v4[octetIdx] = cast[uint8](octet)
proc parseIPv6Address(addressStr: string): IpAddress =
  ## Parses an IPv6 address in textual form.
  ##
  ## Accepts up to eight colon-separated groups of hexadecimal digits (each
  ## group's value at most 0xFFFF), at most one `::` standing for one or more
  ## all-zero groups, and an optional trailing dotted-quad IPv4 part that
  ## fills the last two groups (e.g. `::ffff:1.2.3.4`). The IPv4 part rejects
  ## leading zeros and octets above 255.
  ##
  ## Raises ValueError on errors
  # NOTE(review): group length is only checked by value, so groups with more
  # than four digits such as "00001" are accepted; RFC 4291 allows at most
  # four hex digits per group — confirm whether that leniency is intended.
  result = IpAddress(family: IpAddressFamily.IPv6)
  if addressStr.len < 2:
    raise newException(ValueError, "Invalid IP Address")
  var
    groupCount = 0         # number of 16-bit groups already written
    currentGroupStart = 0  # index just after the most recent ':'
    currentShort: uint32 = 0 # value of the group being parsed
    separatorValid = true  # whether a ':' may appear at this point
    dualColonGroup = -1    # group index where "::" occurred, -1 if none
    lastWasColon = false
    v4StartPos = -1        # start of the embedded IPv4 part, -1 if none
    byteCount = 0          # octets written by the IPv4 part
  for i, c in addressStr:
    if c == ':':
      if not separatorValid:
        raise newException(ValueError,
          "Invalid IP Address. Address contains an invalid separator")
      if lastWasColon:
        # Second ':' in a row: this is the "::" compression marker.
        if dualColonGroup != -1:
          raise newException(ValueError,
            "Invalid IP Address. Address contains more than one \"::\" separator")
        dualColonGroup = groupCount
        separatorValid = false
      elif i != 0 and i != high(addressStr):
        # Ordinary separator: commit the finished group (big-endian).
        if groupCount >= 8:
          raise newException(ValueError,
            "Invalid IP Address. The address consists of too many groups")
        result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
        result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
        currentShort = 0
        groupCount.inc()
        if dualColonGroup != -1: separatorValid = false
      elif i == 0: # only valid if address starts with ::
        if addressStr[1] != ':':
          raise newException(ValueError,
            "Invalid IP Address. Address may not start with \":\"")
      else: # i == high(addressStr) - only valid if address ends with ::
        if addressStr[high(addressStr)-1] != ':':
          raise newException(ValueError,
            "Invalid IP Address. Address may not end with \":\"")
      lastWasColon = true
      currentGroupStart = i + 1
    elif c == '.': # Switch to parse IPv4 mode
      # The embedded IPv4 part needs two free groups, hence groupCount < 7.
      if i < 3 or not separatorValid or groupCount >= 7:
        raise newException(ValueError, "Invalid IP Address")
      # Re-parse the current group as decimal, starting after the last ':'.
      v4StartPos = currentGroupStart
      currentShort = 0
      separatorValid = false
      break
    elif c in strutils.HexDigits:
      if c in strutils.Digits: # Normal digit
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
      elif c >= 'a' and c <= 'f': # Lower case hex
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
      else: # Upper case hex
        currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
      if currentShort > 65535'u32:
        raise newException(ValueError,
          "Invalid IP Address. Value is out of range")
      lastWasColon = false
      separatorValid = true
    else:
      raise newException(ValueError,
        "Invalid IP Address. Address contains an invalid character")
  if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
    if separatorValid: # Copy remaining data
      if groupCount >= 8:
        raise newException(ValueError,
          "Invalid IP Address. The address consists of too many groups")
      result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
      result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
      groupCount.inc()
  else: # Must parse IPv4 address
    var leadingZero = false
    for i, c in addressStr[v4StartPos..high(addressStr)]:
      if c in strutils.Digits: # Character is a number
        if leadingZero:
          raise newException(ValueError,
            "Invalid IP address. Octal numbers not allowed")
        currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
        if currentShort == 0'u32:
          leadingZero = true
        elif currentShort > 255'u32:
          raise newException(ValueError,
            "Invalid IP Address. Value is out of range")
        separatorValid = true
      elif c == '.': # IPv4 address separator
        if not separatorValid or byteCount >= 3:
          raise newException(ValueError, "Invalid IP Address")
        result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
        currentShort = 0
        byteCount.inc()
        separatorValid = false
        leadingZero = false
      else: # Invalid character
        raise newException(ValueError,
          "Invalid IP Address. Address contains an invalid character")
    if byteCount != 3 or not separatorValid:
      raise newException(ValueError, "Invalid IP Address")
    result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
    # The four IPv4 octets occupy two 16-bit groups.
    groupCount += 2
  # Shift and fill zeros in case of ::
  if groupCount > 8:
    raise newException(ValueError,
      "Invalid IP Address. The address consists of too many groups")
  elif groupCount < 8: # must fill
    if dualColonGroup == -1:
      raise newException(ValueError,
        "Invalid IP Address. The address consists of too few groups")
    var toFill = 8 - groupCount # The number of groups to fill
    var toShift = groupCount - dualColonGroup # Nr of known groups after ::
    # Move the groups after "::" to the end, copying back-to-front so the
    # overlapping source bytes are read before being overwritten.
    for i in 0..2*toShift-1: # shift
      result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
    for i in 0..2*toFill-1: # fill with 0s
      result.address_v6[dualColonGroup*2+i] = 0
  elif dualColonGroup != -1:
    # All eight groups are explicit, so "::" would stand for zero groups.
    raise newException(ValueError,
      "Invalid IP Address. The address consists of too many groups")
proc parseIpAddress*(addressStr: string): IpAddress =
  ## Parses an IP address
  ##
  ## Raises ValueError on error.
  ##
  ## Any input containing ':' is parsed as IPv6, everything else as IPv4.
  ## For IPv4 addresses, only the strict form as
  ## defined in RFC 6943 is considered valid, see
  ## https://datatracker.ietf.org/doc/html/rfc6943#section-3.1.1.
  if addressStr.len == 0:
    raise newException(ValueError, "IP Address string is empty")
  result =
    if ':' in addressStr: parseIPv6Address(addressStr)
    else: parseIPv4Address(addressStr)
proc isIpAddress*(addressStr: string): bool {.tags: [].} =
  ## Checks if a string is an IP address
  ## Returns true if it is, false otherwise
  ##
  ## Uses exactly the validation rules of `parseIpAddress`.
  result = true
  try:
    discard parseIpAddress(addressStr)
  except ValueError:
    result = false
proc toSockAddr*(address: IpAddress, port: Port, sa: var Sockaddr_storage,
    sl: var SockLen) =
  ## Converts `IpAddress` and `Port` to `SockAddr` and `SockLen`
  ##
  ## Writes a `Sockaddr_in` (IPv4) or `Sockaddr_in6` (IPv6) into `sa`, with
  ## the port converted to network byte order, and sets `sl` to the size of
  ## the structure written.
  let portBE = htons(uint16(port))
  if address.family == IpAddressFamily.IPv4:
    let v4 = cast[ptr Sockaddr_in](addr sa)
    sl = sizeof(Sockaddr_in).SockLen
    v4.sin_family = typeof(v4.sin_family)(toInt(AF_INET))
    v4.sin_port = portBE
    copyMem(addr v4.sin_addr, unsafeAddr address.address_v4[0],
            sizeof(v4.sin_addr))
  else:
    let v6 = cast[ptr Sockaddr_in6](addr sa)
    sl = sizeof(Sockaddr_in6).SockLen
    v6.sin6_family = typeof(v6.sin6_family)(toInt(AF_INET6))
    v6.sin6_port = portBE
    copyMem(addr v6.sin6_addr, unsafeAddr address.address_v6[0],
            sizeof(v6.sin6_addr))
proc fromSockAddrAux(sa: ptr Sockaddr_storage, sl: SockLen,
    address: var IpAddress, port: var Port) =
  ## Decodes an IPv4 or IPv6 socket address into `address` and `port`.
  ##
  ## The family field of `sa` selects the layout, and `sl` must equal the
  ## size of that layout; otherwise ValueError is raised.
  let family = sa.ss_family.cint
  if family == toInt(AF_INET) and sl == sizeof(Sockaddr_in).SockLen:
    let v4 = cast[ptr Sockaddr_in](sa)
    address = IpAddress(family: IpAddressFamily.IPv4)
    copyMem(addr address.address_v4[0], addr v4.sin_addr,
            sizeof(address.address_v4))
    port = Port(ntohs(v4.sin_port))
  elif family == toInt(AF_INET6) and sl == sizeof(Sockaddr_in6).SockLen:
    let v6 = cast[ptr Sockaddr_in6](sa)
    address = IpAddress(family: IpAddressFamily.IPv6)
    copyMem(addr address.address_v6[0], addr v6.sin6_addr,
            sizeof(address.address_v6))
    port = Port(ntohs(v6.sin6_port))
  else:
    raise newException(ValueError, "Neither IPv4 nor IPv6")
proc fromSockAddr*(sa: Sockaddr_storage | SockAddr | Sockaddr_in | Sockaddr_in6,
    sl: SockLen, address: var IpAddress, port: var Port) {.inline.} =
  ## Converts `SockAddr` and `SockLen` to `IpAddress` and `Port`. Raises
  ## `ValueError` if `sa` is neither an IPv4 nor an IPv6 address, or if `sl`
  ## does not equal the size of the matching `Sockaddr_in`/`Sockaddr_in6`.
  # `sa` is reinterpreted as Sockaddr_storage; its address-family field then
  # decides which concrete layout is read.
  fromSockAddrAux(cast[ptr Sockaddr_storage](unsafeAddr sa), sl, address, port)
when defineSsl:
# OpenSSL >= 1.1.0 does not need explicit init.
# These legacy initialisation calls run once, at module initialisation,
# unless the build targets OpenSSL 3 (`useOpenssl3`).
# NOTE(review): the comment above mentions 1.1.0, but the guard only skips
# OpenSSL 3 builds — confirm these calls are harmless on OpenSSL 1.1.x.
when not useOpenssl3:
  CRYPTO_malloc_init()
  doAssert SslLibraryInit() == 1 # abort early if the library cannot be initialised
  SSL_load_error_strings()
  ERR_load_BIO_strings()
  OpenSSL_add_all_algorithms()
proc sslHandle*(self: Socket): SslPtr =
  ## Returns the OpenSSL handle attached to `self` (nil until the socket has
  ## been wrapped). Useful for interfacing with `openssl` directly.
  result = self.sslHandle
proc raiseSSLError*(s = "") {.raises: [SslError].}=
  ## Raises a new SSL error.
  ##
  ## A non-empty `s` is used verbatim as the message. Otherwise the message
  ## is taken from the most recent entry of OpenSSL's error queue (peeked,
  ## not removed), or is "No error reported." when that queue is empty. Two
  ## error codes that signal missing protocol support get an upgrade hint.
  if s.len > 0:
    raise newException(SslError, s)
  let code = ERR_peek_last_error()
  if code == 0:
    raise newException(SslError, "No error reported.")
  let details = $ERR_error_string(code, nil)
  if code == 336032814 or code == 336032784:
    raise newException(SslError,
      "Please upgrade your OpenSSL library, it does not support the " &
      "necessary protocols. OpenSSL error is: " & details)
  raise newException(SslError, details)
proc getExtraData*(ctx: SslContext, index: int): RootRef =
  ## Retrieves arbitrary data stored inside SslContext.
  ##
  ## Raises `IndexDefect` when nothing was stored at `index` through
  ## `setExtraData`, and `SslError` when OpenSSL returns a nil pointer.
  if index notin ctx.referencedData:
    raise newException(IndexDefect, "No data with that index.")
  let stored = ctx.context.SSL_CTX_get_ex_data(index.cint)
  if stored.isNil:
    raiseSSLError()
  result = cast[RootRef](stored)
proc setExtraData*(ctx: SslContext, index: int, data: RootRef) =
  ## Stores arbitrary data inside SslContext. The unique `index`
  ## should be retrieved using getSslContextExtraDataIndex.
  ##
  ## `data` is kept alive (`GC_ref`) until it is replaced by another call
  ## with the same `index` or released by `destroyContext`.
  ##
  ## Raises `SslError` if OpenSSL cannot store the value; the previously
  ## stored value (if any) is then left in place and still referenced.
  let previous =
    if index in ctx.referencedData: getExtraData(ctx, index)
    else: nil
  # SSL_CTX_set_ex_data returns 1 on success and 0 on failure (never -1).
  if ctx.context.SSL_CTX_set_ex_data(index.cint, cast[pointer](data)) != 1:
    raiseSSLError()
  # Only drop the old reference once OpenSSL no longer points at it.
  if previous != nil:
    GC_unref(previous)
  ctx.referencedData.incl(index)
  GC_ref(data)
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(ctx: SslCtx, certFile, keyFile: string) =
  ## Loads the PEM certificate chain `certFile` and private key `keyFile`
  ## into `ctx`; an empty path skips that file.
  ##
  ## Raises `IOError` for a missing file and `SslError` when OpenSSL rejects
  ## the certificate, the key, or the key/certificate pairing.
  let haveCert = certFile.len > 0
  let haveKey = keyFile.len > 0
  if haveCert and not fileExists(certFile):
    raise newException(system.IOError,
      "Certificate file could not be found: " & certFile)
  if haveKey and not fileExists(keyFile):
    raise newException(system.IOError, "Key file could not be found: " & keyFile)
  if haveCert and SSL_CTX_use_certificate_chain_file(ctx, certFile) != 1:
    raiseSSLError()
  # TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
  if haveKey:
    if SSL_CTX_use_PrivateKey_file(ctx, keyFile, SSL_FILETYPE_PEM) != 1:
      raiseSSLError()
    if SSL_CTX_check_private_key(ctx) != 1:
      raiseSSLError("Verification of private key file failed.")
proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
    certFile = "", keyFile = "", cipherList = CiphersIntermediate,
    caDir = "", caFile = "", ciphersuites = CiphersModern): SslContext =
  ## Creates an SSL context.
  ##
  ## Protocol version is currently ignored by default and TLS is used.
  ## With `-d:openssl10`, only SSLv23 and TLSv1 may be used.
  ##
  ## There are three options for verify mode:
  ## `CVerifyNone`: certificates are not verified;
  ## `CVerifyPeer`: certificates are verified;
  ## `CVerifyPeerUseEnvVars`: certificates are verified and the optional
  ## environment variables SSL_CERT_FILE and SSL_CERT_DIR are also used to
  ## locate certificates
  ##
  ## The `nimDisableCertificateValidation` define overrides verifyMode and
  ## disables certificate verification globally!
  ##
  ## When certificates are verified, CA certificates are loaded from:
  ##
  ## - the `caFile` and `caDir` parameters, if either is set; the locations
  ##   below are then not consulted
  ## - otherwise, the first loadable location reported by
  ##   `ssl_certs <ssl_certs.html>`_, which also considers the SSL_CERT_FILE
  ##   and SSL_CERT_DIR environment variables when `verifyMode` is
  ##   `CVerifyPeerUseEnvVars`
  ##
  ## `IOError` is raised when no CA certificate can be loaded.
  ##
  ## `certFile` and `keyFile` specify the certificate file path and the key
  ## file path; a server socket will most likely not work without these.
  ##
  ## Failures raise `SslError` or `IOError`. Once the OpenSSL context has
  ## been allocated, it is freed again before the exception propagates.
  ##
  ## Certificates can be generated using the following command:
  ## - `openssl req -x509 -nodes -days 365 -newkey rsa:4096 -keyout mykey.pem -out mycert.pem`
  ## or using ECDSA:
  ## - `openssl ecparam -out mykey.pem -name secp256k1 -genkey`
  ## - `openssl req -new -key mykey.pem -x509 -nodes -days 365 -out mycert.pem`
  var mtd: PSSL_METHOD
  when defined(openssl10):
    case protVersion
    of protSSLv23:
      mtd = SSLv23_method()
    of protSSLv2:
      raiseSSLError("SSLv2 is no longer secure and has been deprecated, use protSSLv23")
    of protSSLv3:
      raiseSSLError("SSLv3 is no longer secure and has been deprecated, use protSSLv23")
    of protTLSv1:
      mtd = TLSv1_method()
  else:
    mtd = TLS_method()
  if mtd == nil:
    raiseSSLError("Failed to create TLS context")
  let newCTX = SSL_CTX_new(mtd)
  if newCTX == nil:
    raiseSSLError("Failed to create TLS context")

  # Every step below can fail. Release the native context on all error paths
  # so a failed newContext call does not leak it.
  try:
    if newCTX.SSL_CTX_set_cipher_list(cipherList) != 1:
      raiseSSLError()
    when not defined(openssl10) and not defined(libressl):
      let sslVersion = getOpenSSLVersion()
      if sslVersion >= 0x010101000 and sslVersion != 0x020000000:
        # In OpenSSL >= 1.1.1, TLSv1.3 cipher suites can only be configured via
        # this API.
        if newCTX.SSL_CTX_set_ciphersuites(ciphersuites) != 1:
          raiseSSLError()
    # Automatically select the best ECDH curve for client exchange. Without
    # this, ECDH ciphers will be ignored by the server.
    #
    # From OpenSSL >= 1.1.0, this setting is set by default and can't be
    # overridden.
    if newCTX.SSL_CTX_set_ecdh_auto(1) != 1:
      raiseSSLError()

    when defined(nimDisableCertificateValidation):
      newCTX.SSL_CTX_set_verify(SSL_VERIFY_NONE, nil)
    else:
      case verifyMode
      of CVerifyPeer, CVerifyPeerUseEnvVars:
        newCTX.SSL_CTX_set_verify(SSL_VERIFY_PEER, nil)
      of CVerifyNone:
        newCTX.SSL_CTX_set_verify(SSL_VERIFY_NONE, nil)

    discard newCTX.SSLCTXSetMode(SSL_MODE_AUTO_RETRY)
    newCTX.loadCertificates(certFile, keyFile)

    when not defined(nimDisableCertificateValidation):
      const VerifySuccess = 1 # SSL_CTX_load_verify_locations returns 1 on success.
      if verifyMode != CVerifyNone:
        if caDir != "" or caFile != "":
          # Explicit CA locations take precedence over any scanning.
          let fileArg: cstring = if caFile == "": nil else: caFile.cstring
          let dirArg: cstring = if caDir == "": nil else: caDir.cstring
          if newCTX.SSL_CTX_load_verify_locations(fileArg, dirArg) != VerifySuccess:
            raise newException(IOError, "Failed to load SSL/TLS CA certificate(s).")
        else:
          # Scan for certs in known locations. For CVerifyPeerUseEnvVars also scan
          # the SSL_CERT_FILE and SSL_CERT_DIR env vars. The first location that
          # loads successfully wins.
          var found = false
          let useEnvVars = verifyMode == CVerifyPeerUseEnvVars
          for fn in scanSSLCertificates(useEnvVars = useEnvVars):
            if fn.extractFilename == "":
              # Path ends in a separator: treat it as a certificate directory.
              if newCTX.SSL_CTX_load_verify_locations(nil, cstring(fn.normalizePathEnd(false))) == VerifySuccess:
                found = true
                break
            elif newCTX.SSL_CTX_load_verify_locations(cstring(fn), nil) == VerifySuccess:
              found = true
              break
          if not found:
            raise newException(IOError, "No SSL/TLS CA certificates found.")
  except CatchableError:
    SSL_CTX_free(newCTX)
    raise

  result = SslContext(context: newCTX, referencedData: initHashSet[int](),
    extraInternal: new(SslContextExtraInternal))
proc getExtraInternal(ctx: SslContext): SslContextExtraInternal =
  ## Accessor for the per-context storage holding the PSK callbacks.
  ctx.extraInternal
proc destroyContext*(ctx: SslContext) =
  ## Free memory referenced by SslContext.
  ##
  ## Releases (`GC_unref`) every value stored through `setExtraData`, then
  ## frees the underlying OpenSSL context with `SSL_CTX_free`. `ctx` must not
  ## be used afterwards, since its `context` pointer is dangling.
  # `referencedData` records exactly the indices whose values hold a GC_ref
  # taken by setExtraData, so exactly those references are dropped here.
  for i in ctx.referencedData:
    GC_unref(getExtraData(ctx, i))
  ctx.context.SSL_CTX_free()
proc `pskIdentityHint=`*(ctx: SslContext, hint: string) =
  ## Sets the identity hint passed to server.
  ##
  ## Only used in PSK ciphersuites. Raises `SslError` if OpenSSL rejects the
  ## hint.
  let rc = ctx.context.SSL_CTX_use_psk_identity_hint(hint)
  if rc <= 0:
    raiseSSLError()
proc clientGetPskFunc*(ctx: SslContext): SslClientGetPskFunc =
  ## Returns the client PSK function set with `clientGetPskFunc=` (nil if
  ## none has been set).
  ctx.getExtraInternal().clientGetPskFunc
proc pskClientCallback(ssl: SslPtr; hint: cstring; identity: cstring;
    max_identity_len: cuint; psk: ptr uint8;
    max_psk_len: cuint): cuint {.cdecl.} =
  ## OpenSSL client-side PSK callback, installed by `clientGetPskFunc=`.
  ##
  ## Passes the server's identity `hint` (empty when OpenSSL passes nil) to
  ## the user's `SslClientGetPskFunc`. It then copies the NUL-terminated
  ## identity into `identity`, copies the key into `psk`, and returns the key
  ## length. It returns 0 (no PSK) when the key exceeds `max_psk_len` bytes
  ## or the identity plus its terminator exceeds `max_identity_len` bytes.
  # NOTE(review): `SslContext(context: ...)` creates a fresh wrapper whose
  # `extraInternal` field is nil. `ctx.clientGetPskFunc` reads
  # `ctx.getExtraInternal().clientGetPskFunc`, so this dereferences nil
  # instead of reaching the function stored on the user's SslContext. The
  # callback needs a way back to the original SslContext (e.g. via OpenSSL
  # ex data) — confirm and fix.
  # NOTE(review): exceptions raised by the user function would unwind through
  # OpenSSL's C frames; consider catching them and returning 0.
  let ctx = SslContext(context: ssl.SSL_get_SSL_CTX)
  let hintString = if hint == nil: "" else: $hint
  let (identityString, pskString) = (ctx.clientGetPskFunc)(hintString)
  if pskString.len.cuint > max_psk_len:
    return 0
  if identityString.len.cuint >= max_identity_len:
    return 0
  copyMem(identity, identityString.cstring, identityString.len + 1) # with the last zero byte
  copyMem(psk, pskString.cstring, pskString.len)
  return pskString.len.cuint
proc `clientGetPskFunc=`*(ctx: SslContext, fun: SslClientGetPskFunc) =
  ## Sets function that returns the client identity and the PSK based on identity
  ## hint from the server.
  ##
  ## Only used in PSK ciphersuites. Passing nil uninstalls the OpenSSL
  ## callback.
  ctx.getExtraInternal().clientGetPskFunc = fun
  if fun == nil:
    ctx.context.SSL_CTX_set_psk_client_callback(nil)
  else:
    ctx.context.SSL_CTX_set_psk_client_callback(pskClientCallback)
proc serverGetPskFunc*(ctx: SslContext): SslServerGetPskFunc =
  ## Returns the server PSK function set with `serverGetPskFunc=` (nil if
  ## none has been set).
  ctx.getExtraInternal().serverGetPskFunc
proc pskServerCallback(ssl: SslCtx; identity: cstring; psk: ptr uint8;
    max_psk_len: cint): cuint {.cdecl.} =
  ## OpenSSL server-side PSK callback, installed by `serverGetPskFunc=`.
  ##
  ## Asks the user's `SslServerGetPskFunc` for the key belonging to the
  ## client's `identity`, copies it into `psk` and returns its length. It
  ## returns 0 (no PSK) when the key is longer than `max_psk_len` bytes.
  # NOTE(review): `SslContext(context: ...)` creates a fresh wrapper whose
  # `extraInternal` field is nil. `ctx.serverGetPskFunc` reads
  # `ctx.getExtraInternal().serverGetPskFunc`, so this dereferences nil
  # instead of reaching the function stored on the user's SslContext —
  # confirm and fix (e.g. by storing the context in OpenSSL ex data).
  # NOTE(review): exceptions raised by the user function would unwind through
  # OpenSSL's C frames; consider catching them and returning 0.
  let ctx = SslContext(context: ssl.SSL_get_SSL_CTX)
  let pskString = (ctx.serverGetPskFunc)($identity)
  if pskString.len.cint > max_psk_len:
    return 0
  copyMem(psk, pskString.cstring, pskString.len)
  return pskString.len.cuint
proc `serverGetPskFunc=`*(ctx: SslContext, fun: SslServerGetPskFunc) =
  ## Sets function that returns PSK based on the client identity.
  ##
  ## Only used in PSK ciphersuites. Passing nil uninstalls the OpenSSL
  ## callback.
  ctx.getExtraInternal().serverGetPskFunc = fun
  if fun == nil:
    ctx.context.SSL_CTX_set_psk_server_callback(nil)
  else:
    ctx.context.SSL_CTX_set_psk_server_callback(pskServerCallback)
proc getPskIdentity*(socket: Socket): string =
  ## Gets the PSK identity provided by the client.
  ##
  ## `socket` must already be an SSL socket.
  assert socket.isSsl
  result = $socket.sslHandle.SSL_get_psk_identity()
proc wrapSocket*(ctx: SslContext, socket: Socket) =
  ## Wraps a socket in an SSL context. This function effectively turns
  ## `socket` into an SSL socket.
  ##
  ## This must be called on an unconnected socket; an SSL session will
  ## be started when the socket is connected.
  ##
  ## A new OpenSSL handle is created from `ctx` and bound to the socket's
  ## descriptor; `SslError` is raised if either step fails.
  ##
  ## FIXME:
  ## **Disclaimer**: This code is not well tested, may be very unsafe and
  ## prone to security vulnerabilities.
  assert(not socket.isSsl)
  socket.isSsl = true
  socket.sslContext = ctx
  let handle = SSL_new(ctx.context)
  socket.sslHandle = handle
  socket.sslNoHandshake = false
  socket.sslHasPeekChar = false
  socket.sslNoShutdown = false
  if handle.isNil:
    raiseSSLError()
  if SSL_set_fd(handle, socket.fd) != 1:
    raiseSSLError()
proc checkCertName(socket: Socket, hostname: string) {.raises: [SslError], tags:[RootEffect].} =
  ## Check if the certificate Subject Alternative Name (SAN) or Subject CommonName (CN) matches hostname.
  ## Wildcards match only in the left-most label.
  ## When name starts with a dot it will be matched by a certificate valid for any subdomain
  ##
  ## Raises `SslError` when the peer sent no certificate, the name does not
  ## match, or `X509_check_host` cannot be loaded ("SSL import failed").
  ## Does nothing when built with `nimDisableCertificateValidation` or on
  ## Windows.
  when not defined(nimDisableCertificateValidation) and not defined(windows):
    assert socket.isSsl
    try:
      let certificate = socket.sslHandle.SSL_get_peer_certificate()
      if certificate.isNil:
        raiseSSLError("No SSL certificate found.")
      # https://www.openssl.org/docs/man1.1.1/man3/SSL_get_peer_certificate.html
      # The returned certificate must be freed by the caller. `defer` also
      # covers the path where X509_check_host raises LibraryError, which
      # previously leaked it.
      defer: X509_free(certificate)
      const X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = 0x1.cuint
      # https://www.openssl.org/docs/man1.1.1/man3/X509_check_host.html
      let match = certificate.X509_check_host(hostname.cstring, hostname.len.cint,
        X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT, nil)
      if match != 1:
        raiseSSLError("SSL Certificate check failed.")
    except LibraryError:
      raiseSSLError("SSL import failed")
proc wrapConnectedSocket*(ctx: SslContext, socket: Socket,
    handshake: SslHandshakeType,
    hostname: string = "") =
  ## Wraps a connected socket in an SSL context. This function effectively
  ## turns `socket` into an SSL socket.
  ## `hostname` should be specified so that the client knows which hostname
  ## the server certificate should be validated against.
  ##
  ## This should be called on a connected socket, and will perform
  ## an SSL handshake immediately.
  ##
  ## As a client, a non-IP `hostname` is sent via SNI and (outside Windows
  ## and `nimDisableCertificateValidation`) checked against the peer
  ## certificate after the handshake.
  ##
  ## FIXME:
  ## **Disclaimer**: This code is not well tested, may be very unsafe and
  ## prone to security vulnerabilities.
  wrapSocket(ctx, socket)
  if handshake == handshakeAsServer:
    ErrClearError()
    socketError(socket, SSL_accept(socket.sslHandle))
    return

  # handshakeAsClient
  let namedHost = hostname.len > 0 and not isIpAddress(hostname)
  if namedHost:
    # Discard result in case OpenSSL version doesn't support SNI, or we're
    # not using TLSv1+
    discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
  ErrClearError()
  socketError(socket, SSL_connect(socket.sslHandle))
  when not defined(nimDisableCertificateValidation) and not defined(windows):
    # FIXME: this should be skipped on CVerifyNone
    if namedHost:
      socket.checkCertName(hostname)
proc getPeerCertificates*(sslHandle: SslPtr): seq[Certificate] {.since: (1, 1).} =
  ## Returns the certificate chain received by the peer we are connected to
  ## through the OpenSSL connection represented by `sslHandle`.
  ## The handshake must have been completed and the certificate chain must
  ## have been verified successfully or else an empty sequence is returned.
  ## The chain is ordered from leaf certificate to root certificate.
  ## Each entry is the DER encoding of one certificate.
  result = @[]
  if SSL_get_verify_result(sslHandle) == X509_V_OK:
    let chain = SSL_get0_verified_chain(sslHandle)
    if chain != nil:
      let count = OPENSSL_sk_num(chain)
      for idx in 0 ..< count:
        result.add(i2d_X509(cast[PX509](OPENSSL_sk_value(chain, idx))))
proc getPeerCertificates*(socket: Socket): seq[Certificate] {.since: (1, 1).} =
  ## Returns the verified certificate chain presented by the peer connected
  ## through `socket`, ordered from leaf to root certificate.
  ##
  ## Non-SSL sockets yield an empty sequence. For SSL sockets the handshake
  ## must have completed and the chain must have verified successfully,
  ## otherwise the result is empty as well.
  if socket.isSsl:
    result = getPeerCertificates(socket.sslHandle)
  else:
    result = @[]
proc `sessionIdContext=`*(ctx: SslContext, sidCtx: string) =
  ## Sets the session id context in which a session can be reused.
  ## Used for permitting clients to reuse a session id instead of
  ## doing a new handshake.
  ##
  ## TLS clients might attempt to resume a session using the session id context,
  ## thus it must be set if verifyMode is set to CVerifyPeer or CVerifyPeerUseEnvVars,
  ## otherwise the connection will fail and SslError will be raised if resumption occurs.
  ##
  ## - Only useful if set server-side.
  ## - Should be unique per-application to prevent clients from malfunctioning.
  ## - `sidCtx` must be at most 32 characters in length; longer values raise
  ##   `SslError`.
  # 32 matches OpenSSL's SSL_MAX_SID_CTX_LENGTH. Exactly 32 is accepted, so the
  # message says "at most" rather than "shorter than".
  if sidCtx.len > 32:
    raiseSSLError("sessionIdContext must be at most 32 characters")
  SSL_CTX_set_session_id_context(ctx.context, sidCtx, sidCtx.len)
proc getSocketError*(socket: Socket): OSErrorCode =
  ## Returns the most recent OS error relevant to `socket`.
  ##
  ## `osLastError()` is preferred. If it has been reset to 0, the code saved
  ## in `socket.lastError` is used instead. Raises `OSError` when neither
  ## holds a non-zero code.
  let current = osLastError()
  if current != 0.OSErrorCode:
    return current
  result = socket.lastError
  if result == 0.OSErrorCode:
    raiseOSError(result, "No valid socket error code available")
proc socketError*(socket: Socket, err: int = -1, async = false,
                  lastError = (-1).OSErrorCode,
                  flags: set[SocketFlag] = {}) =
  ## Raises an OSError based on the error code returned by `SSL_get_error`
  ## (for SSL sockets) and `osLastError` otherwise.
  ##
  ## For SSL sockets the error is inspected whenever `err <= 0`. For other
  ## sockets it is inspected only when `err == -1`. Any other `err` value
  ## raises nothing.
  ##
  ## If `async` is `true` no error will be thrown in the case when the
  ## error was caused by no data being available to be read.
  ##
  ## For non-SSL sockets, a `lastError` other than `-1` is used instead of
  ## querying `getSocketError`.
  ##
  ## If `flags` contains `SafeDisconn`, no exception will be raised
  ## when the error was caused by a peer disconnection.
  when defineSsl:
    if socket.isSsl:
      if err <= 0:
        let ret = SSL_get_error(socket.sslHandle, err.cint)
        case ret
        of SSL_ERROR_ZERO_RETURN:
          raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
        of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
          if async:
            return
          else: raiseSSLError("Not enough data on socket.")
        of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
          if async:
            return
          else: raiseSSLError("Not enough data on socket.")
        of SSL_ERROR_WANT_X509_LOOKUP:
          raiseSSLError("Function for x509 lookup has been called.")
        of SSL_ERROR_SYSCALL:
          # SSL shutdown must not be done if a fatal error occurred.
          socket.sslNoShutdown = true
          let osErr = osLastError()
          if not flags.isDisconnectionError(osErr):
            var errStr = "IO error has occurred "
            let sslErr = ERR_peek_last_error()
            if sslErr == 0 and err == 0:
              errStr.add "because an EOF was observed that violates the protocol"
            elif sslErr == 0 and err == -1:
              errStr.add "in the BIO layer"
            else:
              # Keep the context built above and append the OpenSSL reason
              # once (do not shadow `errStr`).
              raiseSSLError(errStr & "(" & $ERR_error_string(sslErr, nil) & ")")
            raiseOSError(osErr, errStr)
        of SSL_ERROR_SSL:
          # SSL shutdown must not be done if a fatal error occurred.
          socket.sslNoShutdown = true
          raiseSSLError()
        else: raiseSSLError("Unknown Error")

  if err == -1 and not (when defineSsl: socket.isSsl else: false):
    let lastE = if lastError.int == -1: getSocketError(socket) else: lastError
    if not flags.isDisconnectionError(lastE):
      if async:
        when useWinVersion:
          if lastE.int32 == WSAEWOULDBLOCK:
            return
          else: raiseOSError(lastE)
        else:
          if lastE.int32 == EAGAIN or lastE.int32 == EWOULDBLOCK:
            return
          else: raiseOSError(lastE)
      else: raiseOSError(lastE)
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
  ## Puts `socket` into listening mode so it can accept connections.
  ## `backlog` caps the length of the queue of pending connections.
  ##
  ## Raises `OSError` if the underlying `listen` call fails.
  let rc = nativesockets.listen(socket.fd, backlog)
  if rc < 0'i32:
    raiseOSError(osLastError())
proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
  tags: [ReadIOEffect].} =
  ## Binds `address`:`port` to the socket.
  ##
  ## If `address` is "" then ADDR_ANY will be bound ("::" for AF_INET6,
  ## "0.0.0.0" for AF_INET). For any other domain an empty `address` raises
  ## `ValueError`. Raises `OSError` if the bind fails.
  var realaddr = address
  if realaddr == "":
    case socket.domain
    of AF_INET6: realaddr = "::"
    of AF_INET: realaddr = "0.0.0.0"
    else:
      raise newException(ValueError,
        "Unknown socket address family and no address specified to bindAddr")

  var aiList = getAddrInfo(realaddr, port, socket.domain)
  if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.SockLen) < 0'i32:
    # Capture the error before freeAddrInfo gets a chance to overwrite it.
    let err = osLastError()
    freeAddrInfo(aiList)
    var address2: string
    address2.addQuoted address
    raiseOSError(err, "address: $# port: $#" % [address2, $port])
  freeAddrInfo(aiList)
proc acceptAddr*(server: Socket, client: var owned(Socket), address: var string,
                 flags = {SocketFlag.SafeDisconn},
                 inheritable = defined(nimInheritHandles)) {.
                 tags: [ReadIOEffect], gcsafe.} =
  ## Blocks until a connection is being made from a client. When a connection
  ## is made sets `client` to the client socket and `address` to the address
  ## of the connecting client.
  ## This function will raise OSError if an error occurs.
  ##
  ## The resulting client will inherit any properties of the server socket. For
  ## example: whether the socket is buffered or not.
  ##
  ## The SocketHandle associated with the resulting client will not be
  ## inheritable by child processes by default. This can be changed via
  ## the `inheritable` parameter.
  ##
  ## The `accept` call may result in an error if the connecting socket
  ## disconnects during the duration of the `accept`. If the `SafeDisconn`
  ## flag is specified then this error will not be raised and instead
  ## accept will be called again.
  if client.isNil:
    new(client)
  var ret = accept(server.fd, inheritable)
  while ret[0] == osInvalidSocket:
    let err = osLastError()
    if not flags.isDisconnectionError(err):
      raiseOSError(err)
    # The peer went away mid-accept and SafeDisconn asks us to ignore it:
    # try again (iteratively, and without falling through to a raise).
    ret = accept(server.fd, inheritable)

  let sock = ret[0]
  address = ret[1]
  client.fd = sock
  client.domain = getSockDomain(sock)
  client.isBuffered = server.isBuffered

  # Handle SSL.
  when defineSsl:
    if server.isSsl:
      # We must wrap the client sock in a ssl context.
      server.sslContext.wrapSocket(client)
      ErrClearError()
      let sslRet = SSL_accept(client.sslHandle)
      socketError(client, sslRet, false)
# Disabled code: `when false` means this block is parsed but never compiled.
# It is kept for reference only.
when false: #defineSsl:
  proc acceptAddrSSL*(server: Socket, client: var Socket,
                      address: var string): SSL_acceptResult {.
                      tags: [ReadIOEffect].} =
    ## This procedure should only be used for non-blocking **SSL** sockets.
    ## It will immediately return with one of the following values:
    ##
    ## `AcceptSuccess` will be returned when a client has been successfully
    ## accepted and the handshake has been successfully performed between
    ## `server` and the newly connected client.
    ##
    ## `AcceptNoHandshake` will be returned when a client has been accepted
    ## but no handshake could be performed. This can happen when the client
    ## connects but does not yet initiate a handshake. In this case
    ## `acceptAddrSSL` should be called again with the same parameters.
    ##
    ## `AcceptNoClient` will be returned when no client is currently attempting
    ## to connect.
    # Performs (or resumes) the server-side SSL handshake on `client`,
    # returning AcceptNoHandshake from the enclosing proc when the handshake
    # cannot complete yet.
    template doHandshake(): untyped =
      when defineSsl:
        if server.isSsl:
          client.setBlocking(false)
          # We must wrap the client sock in a ssl context.

          if not client.isSsl or client.sslHandle == nil:
            server.sslContext.wrapSocket(client)
          ErrClearError()
          let ret = SSL_accept(client.sslHandle)
          # NOTE(review): `ret` is a `let` and is never updated, so if
          # SSL_get_error reports SSL_ERROR_WANT_ACCEPT this loop never
          # terminates. Fix before re-enabling this code.
          while ret <= 0:
            let err = SSL_get_error(client.sslHandle, ret)
            if err != SSL_ERROR_WANT_ACCEPT:
              case err
              of SSL_ERROR_ZERO_RETURN:
                raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
              of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
                 SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
                client.sslNoHandshake = true
                return AcceptNoHandshake
              of SSL_ERROR_WANT_X509_LOOKUP:
                raiseSSLError("Function for x509 lookup has been called.")
              of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
                raiseSSLError()
              else:
                raiseSSLError("Unknown error")
          client.sslNoHandshake = false

    # A client accepted earlier whose handshake is still pending: resume it.
    if client.isSsl and client.sslNoHandshake:
      doHandshake()
      return AcceptSuccess
    else:
      # NOTE(review): `acceptAddrPlain` is not defined in the visible code.
      # Presumably it was a helper that has since been removed; verify
      # before re-enabling.
      acceptAddrPlain(AcceptNoClient, AcceptSuccess):
        doHandshake()
proc accept*(server: Socket, client: var owned(Socket),
             flags = {SocketFlag.SafeDisconn},
             inheritable = defined(nimInheritHandles))
            {.tags: [ReadIOEffect].} =
  ## Equivalent to `acceptAddr` but doesn't return the address, only the
  ## socket.
  ##
  ## The SocketHandle associated with the resulting client will not be
  ## inheritable by child processes by default. This can be changed via
  ## the `inheritable` parameter.
  ##
  ## The `accept` call may result in an error if the connecting socket
  ## disconnects during the duration of the `accept`. If the `SafeDisconn`
  ## flag is specified then this error will not be raised and instead
  ## accept will be called again.
  var addrDummy = ""
  # Forward `inheritable`; otherwise the caller's choice would be silently
  # replaced by acceptAddr's default.
  acceptAddr(server, client, addrDummy, flags, inheritable)
when defined(posix) and not defined(lwip):
from posix import Sigset, sigwait, sigismember, sigemptyset, sigaddset,
sigprocmask, pthread_sigmask, SIGPIPE, SIG_BLOCK, SIG_UNBLOCK
template blockSigpipe(body: untyped): untyped =
  ## Temporarily blocks SIGPIPE within the provided code block. If SIGPIPE is
  ## raised for the duration of the code block, it will be queued and will be
  ## raised once the block ends.
  ##
  ## Within the block a `selectSigpipe()` template is provided which can be
  ## used to remove SIGPIPE from the queue. Note that if SIGPIPE is **not**
  ## raised at the time of call, it will block until SIGPIPE is raised.
  ##
  ## If SIGPIPE has already been blocked at the time of execution, the
  ## signal mask is left as-is and `selectSigpipe()` will become a no-op.
  ##
  ## For convenience, this template is also available for non-POSIX systems,
  ## where `body` will be executed as-is.
  when not defined(posix) or defined(lwip):
    body
  else:
    template sigmask(how: cint, set, oset: var Sigset): untyped {.gensym.} =
      ## Alias for pthread_sigmask or sigprocmask depending on the status
      ## of --threads
      # NOTE(review): POSIX pthread_sigmask returns an error number rather
      # than -1, so the `== -1` checks below may not detect its failures
      # when threads are enabled. Confirm against the posix bindings.
      when compileOption("threads"):
        pthread_sigmask(how, set, oset)
      else:
        sigprocmask(how, set, oset)

    # `watchSet` holds only SIGPIPE; `oldSet` receives the mask in effect
    # before blocking.
    var oldSet, watchSet: Sigset
    if sigemptyset(oldSet) == -1:
      raiseOSError(osLastError())
    if sigemptyset(watchSet) == -1:
      raiseOSError(osLastError())

    if sigaddset(watchSet, SIGPIPE) == -1:
      raiseOSError(osLastError(), "Couldn't add SIGPIPE to Sigset")

    if sigmask(SIG_BLOCK, watchSet, oldSet) == -1:
      raiseOSError(osLastError(), "Couldn't block SIGPIPE")

    # If SIGPIPE was already blocked by someone else, leave the mask alone
    # on exit and make selectSigpipe() a no-op.
    let alreadyBlocked = sigismember(oldSet, SIGPIPE) == 1

    # Consumes a pending SIGPIPE with sigwait so it is not delivered once
    # the mask is restored.
    template selectSigpipe(): untyped {.used.} =
      if not alreadyBlocked:
        var signal: cint
        let err = sigwait(watchSet, signal)
        if err != 0:
          raiseOSError(err.OSErrorCode, "Couldn't select SIGPIPE")
        assert signal == SIGPIPE

    try:
      body
    finally:
      # Restore SIGPIPE delivery only if this template was the one that
      # blocked it.
      if not alreadyBlocked:
        if sigmask(SIG_UNBLOCK, watchSet, oldSet) == -1:
          raiseOSError(osLastError(), "Couldn't unblock SIGPIPE")
proc close*(socket: Socket, flags = {SocketFlag.SafeDisconn}) =
  ## Closes a socket.
  ##
  ## If `socket` is an SSL/TLS socket, this proc will also send a closure
  ## notification to the peer. If `SafeDisconn` is in `flags`, failure to do so
  ## due to disconnections will be ignored. This is generally safe in
  ## practice. See
  ## `here <https://security.stackexchange.com/a/82044>`_ for more details.
  ##
  ## The SSL handle (if any) is freed and the OS handle closed in a `finally`
  ## clause, so this happens even if the SSL shutdown raises. Afterwards
  ## `socket.fd` is set to `osInvalidSocket`.
  try:
    when defineSsl:
      if socket.isSsl and socket.sslHandle != nil:
        # Don't call SSL_shutdown if the connection has not been fully
        # established, see:
        # https://github.com/openssl/openssl/issues/710#issuecomment-253897666
        if not socket.sslNoShutdown and SSL_in_init(socket.sslHandle) == 0:
          # As we are closing the underlying socket immediately afterwards,
          # it is valid, under the TLS standard, to perform a unidirectional
          # shutdown i.e not wait for the peers "close notify" alert with a second
          # call to SSL_shutdown
          # SIGPIPE is blocked so a write to a disconnected peer surfaces as
          # EPIPE instead of killing the process.
          blockSigpipe:
            ErrClearError()
            let res = SSL_shutdown(socket.sslHandle)
            # 0 and 1 are both acceptable here (see the unidirectional
            # shutdown note above); anything else is an error.
            if res == 0:
              discard
            elif res != 1:
              let
                err = osLastError()
                sslError = SSL_get_error(socket.sslHandle, res)

              # If a close notification is received, failures outside of the
              # protocol will be returned as SSL_ERROR_ZERO_RETURN instead
              # of SSL_ERROR_SYSCALL. This fact is deduced by digging into
              # SSL_get_error() source code.
              if sslError == SSL_ERROR_ZERO_RETURN or
                 sslError == SSL_ERROR_SYSCALL:
                when defined(posix) and not defined(macosx) and
                   not defined(nimdoc):
                  if err == EPIPE.OSErrorCode:
                    # Clear the SIGPIPE that's been raised due to
                    # the disconnection.
                    selectSigpipe()
                else:
                  discard
                if not flags.isDisconnectionError(err):
                  socketError(socket, res, lastError = err, flags = flags)
              else:
                socketError(socket, res, lastError = err, flags = flags)
  finally:
    when defineSsl:
      if socket.isSsl and socket.sslHandle != nil:
        SSL_free(socket.sslHandle)
        socket.sslHandle = nil

    socket.fd.close()
    # Mark the socket as closed (checked by isClosed).
    socket.fd = osInvalidSocket
when defined(posix):
from posix import TCP_NODELAY
else:
from winlean import TCP_NODELAY
proc toCInt*(opt: SOBool): cint =
  ## Maps a `SOBool` option to the native socket-option constant used with
  ## `getsockopt`/`setsockopt`.
  result =
    case opt
    of OptNoDelay: TCP_NODELAY
    of OptAcceptConn: SO_ACCEPTCONN
    of OptBroadcast: SO_BROADCAST
    of OptDebug: SO_DEBUG
    of OptDontRoute: SO_DONTROUTE
    of OptKeepAlive: SO_KEEPALIVE
    of OptOOBInline: SO_OOBINLINE
    of OptReuseAddr: SO_REUSEADDR
    of OptReusePort: SO_REUSEPORT
proc getSockOpt*(socket: Socket, opt: SOBool, level = SOL_SOCKET): bool {.
  tags: [ReadIOEffect].} =
  ## Reads the boolean socket option `opt` at `level`. Any non-zero native
  ## value is reported as `true`.
  getSockOptInt(socket.fd, cint(level), toCInt(opt)) != 0
proc getLocalAddr*(socket: Socket): (string, Port) =
  ## Returns the local address and port number of `socket`.
  ##
  ## This is a high-level interface for `getsockname`:idx:.
  result = getLocalAddr(socket.fd, socket.domain)
when not useNimNetLite:
  proc getPeerAddr*(socket: Socket): (string, Port) =
    ## Returns the address and port number of the peer `socket` is
    ## connected to.
    ##
    ## This is a high-level interface for `getpeername`:idx:.
    result = getPeerAddr(socket.fd, socket.domain)
proc setSockOpt*(socket: Socket, opt: SOBool, value: bool,
                 level = SOL_SOCKET) {.tags: [WriteIOEffect].} =
  ## Sets the boolean socket option `opt` at `level` to `value`, which is
  ## passed natively as 1 (`true`) or 0 (`false`).
  runnableExamples("-r:off"):
    let socket = newSocket()
    socket.setSockOpt(OptReusePort, true)
    socket.setSockOpt(OptNoDelay, true, level = IPPROTO_TCP.cint)
  let native: cint = if value: 1 else: 0
  setSockOptInt(socket.fd, cint(level), toCInt(opt), native)
when defined(nimdoc) or (defined(posix) and not useNimNetLite):
  proc connectUnix*(socket: Socket, path: string) =
    ## Connects `socket` to the Unix domain socket at `path`.
    ## This only works on Unix-style systems: Mac OS X, BSD and Linux
    when not defined(nimdoc):
      var unixAddr = makeUnixAddr(path)
      let addrLen = (sizeof(unixAddr.sun_family) + path.len).SockLen
      if socket.fd.connect(cast[ptr SockAddr](addr unixAddr), addrLen) != 0'i32:
        raiseOSError(osLastError())

  proc bindUnix*(socket: Socket, path: string) =
    ## Binds `socket` to the Unix domain socket path `path`.
    ## This only works on Unix-style systems: Mac OS X, BSD and Linux
    when not defined(nimdoc):
      var unixAddr = makeUnixAddr(path)
      let addrLen = (sizeof(unixAddr.sun_family) + path.len).SockLen
      if socket.fd.bindAddr(cast[ptr SockAddr](addr unixAddr), addrLen) != 0'i32:
        raiseOSError(osLastError())
when defineSsl:
  proc gotHandshake*(socket: Socket): bool =
    ## Reports whether the SSL handshake between the client `socket` and the
    ## server it is connected to has taken place.
    ##
    ## Raises `SslError` when `socket` is not an SSL socket.
    if not socket.isSsl:
      raiseSSLError("Socket is not an SSL socket.")
    result = not socket.sslNoHandshake
proc hasDataBuffered*(s: Socket): bool =
  ## Reports whether `s` holds received data that has not been consumed yet.
  ## This is either unread bytes in its internal buffer or, for SSL sockets,
  ## a pending peeked character.
  result = s.isBuffered and s.bufLen > 0 and s.currPos != s.bufLen
  when defineSsl:
    if not result and s.isSsl:
      result = s.sslHasPeekChar
proc isClosed(socket: Socket): bool =
  ## True when `socket` no longer refers to a valid OS handle.
  result = socket.fd == osInvalidSocket
proc uniRecv(socket: Socket, buffer: pointer, size, flags: cint): int =
  ## Reads up to `size` bytes into `buffer`, using `SSL_read` for SSL sockets
  ## (where `flags` is not used) and plain `recv` with `flags` otherwise.
  ## Returns the raw result of whichever call was made.
  assert(not socket.isClosed, "Cannot `recv` on a closed socket")
  when defineSsl:
    if socket.isSsl:
      ErrClearError()
      return SSL_read(socket.sslHandle, buffer, size)
  result = recv(socket.fd, buffer, size, flags)
proc readIntoBuf(socket: Socket, flags: int32): int =
  ## Refills `socket.buffer` with a single read and rewinds the read cursor.
  ##
  ## Returns the number of bytes read, which is also stored in `bufLen`. When
  ## the read yields <= 0, that value is returned and the buffer is marked
  ## empty.
  # NOTE(review): at most `buffer.high` (capacity - 1) bytes are requested;
  # confirm the spare byte is intentional.
  let got = uniRecv(socket, addr(socket.buffer), socket.buffer.high, flags)
  if got < 0:
    # Remember the error: generated code may call Win API functions that
    # reset it before the caller looks.
    socket.lastError = osLastError()
  socket.currPos = 0
  socket.bufLen = if got > 0: got else: 0
  result = got
template retRead(flags, readBytes: int) {.dirty.} =
  ## Refills the buffer of the caller's `socket`. If that read produces
  ## nothing (<= 0), the enclosing proc returns `readBytes` when some bytes
  ## were already copied, otherwise the read's own result.
  let res = socket.readIntoBuf(flags.int32)
  if res <= 0:
    return (if readBytes > 0: readBytes else: res)
proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [
    ReadIOEffect].} =
  ## Receives up to `size` bytes from `socket` into `data`.
  ##
  ## Returns the number of bytes stored, 0 when `size` is 0 or the connection
  ## was closed, or a negative value on error. On error the OS error is saved
  ## in `socket.lastError`.
  ##
  ## **Note**: This is a low-level function, you may be interested in the higher
  ## level versions of this function which are also named `recv`.
  if size == 0: return
  if socket.isBuffered:
    if socket.bufLen == 0:
      retRead(0'i32, 0)

    var read = 0
    while read < size:
      if socket.currPos >= socket.bufLen:
        retRead(0'i32, read)

      let chunk = min(socket.bufLen-socket.currPos, size-read)
      var d = cast[cstring](data)
      assert size-read >= chunk
      copyMem(addr(d[read]), addr(socket.buffer[socket.currPos]), chunk)
      read.inc(chunk)
      socket.currPos.inc(chunk)

    result = read
  else:
    when defineSsl:
      if socket.isSsl:
        if socket.sslHasPeekChar: # TODO: Merge this peek char mess into uniRecv
          copyMem(data, addr(socket.sslPeekChar), 1)
          socket.sslHasPeekChar = false
          if size-1 > 0:
            var d = cast[cstring](data)
            let rest = uniRecv(socket, addr(d[1]), cint(size-1), 0'i32)
            # The peeked byte was already delivered into `data`. A failed or
            # empty follow-up read must not turn that into 0 ("closed") or a
            # negative value; any real error resurfaces on the next read.
            result = if rest > 0: rest + 1 else: 1
          else:
            result = 1
        else:
          result = uniRecv(socket, data, size.cint, 0'i32)
      else:
        result = recv(socket.fd, data, size.cint, 0'i32)
    else:
      result = recv(socket.fd, data, size.cint, 0'i32)
    if result < 0:
      # Save the error in case it gets reset.
      socket.lastError = osLastError()
proc waitFor(socket: Socket, waited: var Duration, timeout, size: int,
             funcName: string): int {.tags: [TimeEffect].} =
  ## Determines how many characters can be read without blocking. The result
  ## never exceeds `size`. For unbuffered sockets it is `1`; for buffered
  ## sockets it can be as big as `BufferSize`. With `timeout == -1` it is
  ## `size` and no waiting happens.
  ##
  ## If this function does not determine that there is data on the socket
  ## within `timeout` ms (minus time already recorded in `waited`), a
  ## TimeoutError error will be raised. Time spent waiting is added to
  ## `waited`.
  if size <= 0: assert false
  if timeout == -1: return size

  let bufferHasData = socket.isBuffered and socket.bufLen != 0 and
                      socket.bufLen != socket.currPos
  if bufferHasData:
    return min(socket.bufLen - socket.currPos, size)

  let remaining = timeout - waited.inMilliseconds
  if remaining < 1:
    raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")

  when defineSsl:
    if socket.isSsl:
      if socket.hasDataBuffered:
        # sslPeekChar is present.
        return 1
      let pending = SSL_pending(socket.sslHandle)
      if pending != 0:
        return min(pending, size)

  let started = getMonoTime()
  let ready =
    if socket.hasDataBuffered: 1
    else: timeoutRead(socket.fd, remaining.int)
  if ready < 0: raiseOSError(osLastError())
  if ready != 1:
    raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
  waited += getMonoTime() - started
  result = 1
proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
    tags: [ReadIOEffect, TimeEffect].} =
  ## Overload with a `timeout` parameter in milliseconds.
  ##
  ## Keeps reading until `size` bytes have arrived or the peer closes the
  ## connection, and returns the total read. A negative low-level `recv`
  ## result is returned immediately.
  var waited: Duration # time already spent waiting
  var dest = cast[cstring](data)
  var total = 0
  while total < size:
    let avail = waitFor(socket, waited, timeout, size - total, "recv")
    assert avail <= size - total
    let got = recv(socket, addr(dest[total]), avail)
    if got < 0:
      return got
    if got == 0:
      break
    total += got
  result = total
proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
           flags = {SocketFlag.SafeDisconn}): int =
  ## Higher-level version of `recv`.
  ##
  ## Reads **up to** `size` bytes from `socket` into `data`.
  ##
  ## For buffered sockets this function will attempt to read all the requested
  ## data. It will read this data in `BufferSize` chunks.
  ##
  ## For unbuffered sockets this function makes no effort to read
  ## all the data requested. It will return as much data as the operating system
  ## gives it.
  ##
  ## When 0 is returned the socket's connection has been closed. This also
  ## covers disconnections that `SafeDisconn` suppresses.
  ##
  ## This function will throw an OSError exception when an error occurs. A value
  ## lower than 0 is never returned.
  ##
  ## A timeout may be specified in milliseconds, if enough data is not received
  ## within the time specified a TimeoutError exception will be raised.
  ##
  ## .. warning:: Only the `SafeDisconn` flag is currently supported.
  data.setLen(size)
  result =
    if timeout == -1:
      recv(socket, cstring(data), size)
    else:
      recv(socket, cstring(data), size, timeout)
  if result < 0:
    data.setLen(0)
    let lastError = getSocketError(socket)
    socket.socketError(result, lastError = lastError, flags = flags)
    # socketError only returns when the error is a disconnection that
    # `flags` asks to ignore. Report that as a closed connection so the
    # documented "never negative" contract holds.
    result = 0
  else:
    data.setLen(result)
proc recv*(socket: Socket, size: int, timeout = -1,
           flags = {SocketFlag.SafeDisconn}): string {.inline.} =
  ## Higher-level version of `recv` which returns a string.
  ##
  ## Reads **up to** `size` bytes from `socket` into the result.
  ##
  ## For buffered sockets this function will attempt to read all the requested
  ## data. It will read this data in `BufferSize` chunks.
  ##
  ## For unbuffered sockets this function makes no effort to read
  ## all the data requested. It will return as much data as the operating system
  ## gives it.
  ##
  ## When `""` is returned the socket's connection has been closed.
  ##
  ## This function will throw an OSError exception when an error occurs.
  ##
  ## A timeout may be specified in milliseconds, if enough data is not received
  ## within the time specified a TimeoutError exception will be raised.
  ##
  ## .. warning:: Only the `SafeDisconn` flag is currently supported.
  var buffer = newString(size)
  discard socket.recv(buffer, size, timeout, flags)
  result = buffer
proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
  ## Stores the next unread byte of `socket` in `c` without consuming it.
  ##
  ## Returns 1 on success, 0 if the connection was closed or a negative value
  ## on error. On failure the content of `c` must not be relied upon.
  if socket.isBuffered:
    if socket.bufLen == 0 or socket.currPos > socket.bufLen-1:
      let res = socket.readIntoBuf(0'i32)
      if res <= 0:
        # Nothing was buffered; do not read stale bytes into `c`.
        return res
    c = socket.buffer[socket.currPos]
    return 1

  when defineSsl:
    if socket.isSsl:
      if not socket.sslHasPeekChar:
        let res = uniRecv(socket, addr(socket.sslPeekChar), 1, 0'i32)
        if res <= 0:
          # No byte arrived, so there is no peeked char to remember;
          # flagging one would make the next recv return garbage.
          return res
        socket.sslHasPeekChar = true
      # A byte is available (either just read or peeked earlier).
      c = socket.sslPeekChar
      return 1

  result = recv(socket.fd, addr(c), 1, MSG_PEEK)
proc readLine*(socket: Socket, line: var string, timeout = -1,
               flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.
  tags: [ReadIOEffect, TimeEffect].} =
  ## Reads a line of data from `socket`.
  ##
  ## If a full line is read `\r\L` is not
  ## added to `line`, however if solely `\r\L` is read then `line`
  ## will be set to it.
  ##
  ## If the socket is disconnected, `line` will be set to `""`.
  ##
  ## An OSError exception will be raised in the case of a socket error.
  ##
  ## A timeout can be specified in milliseconds, if data is not received within
  ## the specified time a TimeoutError exception will be raised.
  ##
  ## The `maxLength` parameter determines the maximum amount of characters
  ## that can be read. The result is truncated after that.
  ##
  ## .. warning:: Only the `SafeDisconn` flag is currently supported.

  # An empty line is reported as "\c\L" so it can be told apart from a
  # disconnection (which yields "").
  template addNLIfEmpty() =
    if line.len == 0:
      line.add("\c\L")

  # Clears `line` on a disconnection, lets socketError raise (or swallow,
  # per `flags`) and returns from readLine. Reads the caller's `n`.
  template raiseSockError() {.dirty.} =
    let lastError = getSocketError(socket)
    if flags.isDisconnectionError(lastError):
      setLen(line, 0)
    socket.socketError(n, lastError = lastError, flags = flags)
    return

  # Total time spent waiting; shared across all waitFor calls so `timeout`
  # bounds the whole line, not each character.
  var waited: Duration

  setLen(line, 0)
  while true:
    var c: char
    discard waitFor(socket, waited, timeout, 1, "readLine")
    var n = recv(socket, addr(c), 1)
    if n < 0: raiseSockError()
    elif n == 0: setLen(line, 0); return
    if c == '\r':
      discard waitFor(socket, waited, timeout, 1, "readLine")
      # Peek so a non-'\L' byte after '\r' is left for the next read.
      n = peekChar(socket, c)
      if n > 0 and c == '\L':
        discard recv(socket, addr(c), 1)
      # NOTE(review): a clean close right after '\r' (n == 0) also goes
      # through raiseSockError, which calls getSocketError; confirm that is
      # the intended handling.
      elif n <= 0: raiseSockError()
      addNLIfEmpty()
      return
    elif c == '\L':
      addNLIfEmpty()
      return
    add(line, c)

    # Verify that this isn't a DOS attack: #3847.
    # NOTE(review): the check runs after appending, so `line` can reach
    # maxLength + 1 characters before the loop stops.
    if line.len > maxLength: break
proc recvLine*(socket: Socket, timeout = -1,
               flags = {SocketFlag.SafeDisconn},
               maxLength = MaxLineLength): string =
  ## Reads a line of data from `socket` and returns it (a convenience
  ## wrapper around `readLine`).
  ##
  ## If a full line is read `\r\L` is not
  ## added to the result, however if solely `\r\L` is read then the result
  ## will be set to it.
  ##
  ## If the socket is disconnected, the result will be set to `""`.
  ##
  ## An OSError exception will be raised in the case of a socket error.
  ##
  ## A timeout can be specified in milliseconds, if data is not received within
  ## the specified time a TimeoutError exception will be raised.
  ##
  ## The `maxLength` parameter determines the maximum amount of characters
  ## that can be read. The result is truncated after that.
  ##
  ## .. warning:: Only the `SafeDisconn` flag is currently supported.
  var line = ""
  socket.readLine(line, timeout, flags, maxLength)
  result = line
proc recvFrom*[T: string | IpAddress](socket: Socket, data: var string, length: int,
               address: var T, port: var Port, flags = 0'i32): int {.
               tags: [ReadIOEffect].} =
  ## Receives data from `socket`. This function should normally be used with
  ## connection-less sockets (UDP sockets). The source address of the data
  ## packet is stored in the `address` argument as either a string or an IpAddress.
  ##
  ## If an error occurs an OSError exception will be raised. Otherwise the return
  ## value will be the length of data received.
  ##
  ## .. warning:: This function does not yet have a buffered implementation,
  ##   so when `socket` is buffered the non-buffered implementation will be
  ##   used. Therefore if `socket` contains something in its buffer this
  ##   function will make no effort to return it.

  # Performs the recvfrom into a domain-specific sockaddr and converts the
  # sender into `address`/`port`.
  template adaptRecvFromToDomain(sockAddress: untyped, domain: Domain) =
    var addrLen = SockLen(sizeof(sockAddress))
    result = recvfrom(socket.fd, cstring(data), length.cint, flags.cint,
                      cast[ptr SockAddr](addr(sockAddress)), addr(addrLen))
    if result == -1:
      raiseOSError(osLastError())
    data.setLen(result)
    when typeof(address) is string:
      address = getAddrString(cast[ptr SockAddr](addr(sockAddress)))
      when domain == AF_INET6:
        port = ntohs(sockAddress.sin6_port).Port
      else:
        port = ntohs(sockAddress.sin_port).Port
    else:
      sockAddress.fromSockAddr(addrLen, address, port)

  assert(socket.protocol != IPPROTO_TCP, "Cannot `recvFrom` on a TCP socket")
  # TODO: Buffered sockets
  data.setLen(length)

  case socket.domain
  of AF_INET6:
    var v6Addr: Sockaddr_in6
    adaptRecvFromToDomain(v6Addr, AF_INET6)
  of AF_INET:
    var v4Addr: Sockaddr_in
    adaptRecvFromToDomain(v4Addr, AF_INET)
  else:
    raise newException(ValueError, "Unknown socket address family")
proc skip*(socket: Socket, size: int, timeout = -1) =
  ## Reads and discards `size` bytes from `socket`.
  ##
  ## An optional timeout can be specified in milliseconds, if skipping the
  ## bytes takes longer than specified a TimeoutError exception will be raised.
  ##
  ## Stops early if the peer closes the connection. Raises `OSError` (or
  ## `SslError`) if a read fails.
  # Bounded scratch space instead of allocating `size` bytes up front. It
  # also cannot leak when waitFor raises.
  const scratchSize = 4096
  var waited: Duration
  var scratch: array[scratchSize, char]
  var bytesSkipped = 0
  while bytesSkipped < size:
    let avail = waitFor(socket, waited, timeout, size - bytesSkipped, "skip")
    let got = recv(socket, addr scratch[0], min(avail, scratchSize))
    if got <= 0:
      # 0: peer closed, nothing more to skip. <0: surface the error instead
      # of subtracting it from the progress counter (which looped forever).
      if got < 0:
        socketError(socket, got)
      break
    bytesSkipped += got
proc send*(socket: Socket, data: pointer, size: int): int {.
  tags: [WriteIOEffect].} =
  ## Writes up to `size` bytes from `data` to `socket` and returns the raw
  ## result of the underlying `SSL_write`/`send` call.
  ##
  ## **Note**: This is a low-level version of `send`. You likely should use
  ## the version below.
  assert(not socket.isClosed, "Cannot `send` on a closed socket")
  when defineSsl:
    if socket.isSsl:
      ErrClearError()
      return SSL_write(socket.sslHandle, cast[cstring](data), size)

  when useWinVersion or defined(macosx):
    result = send(socket.fd, data, size.cint, 0'i32)
  elif defined(solaris):
    # No MSG_NOSIGNAL flag is passed on this platform.
    result = send(socket.fd, data, size, 0'i32)
  else:
    result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
proc send*(socket: Socket, data: string,
           flags = {SocketFlag.SafeDisconn}, maxRetries = 100) {.tags: [WriteIOEffect].} =
  ## Sends data to a socket. Will try to send all the data by handling interrupts
  ## and incomplete writes up to `maxRetries`.
  ##
  ## After a partial write, sending resumes from the first unsent byte. If
  ## `SafeDisconn` is in `flags` and the peer disconnects, the unsent
  ## remainder is dropped without raising.
  var buf = cstring(data)
  var written = 0
  var attempts = 0
  while written < data.len:
    # Send only what is still outstanding, starting at the unsent offset.
    let sent = send(socket, addr(buf[written]), data.len - written)
    if sent < 0:
      let lastError = osLastError()
      let isBlockingErr =
        when defined(nimdoc):
          false
        elif useWinVersion:
          lastError.int32 == WSAEINTR or
          lastError.int32 == WSAEWOULDBLOCK
        else:
          lastError.int32 == EINTR or
          lastError.int32 == EWOULDBLOCK or
          lastError.int32 == EAGAIN

      if not isBlockingErr:
        socketError(socket, sent, lastError = lastError, flags = flags)
        # socketError only returns for disconnections that `flags` asks to
        # ignore; retrying would spin forever on a dead connection.
        return
      else:
        attempts.inc()
        if attempts > maxRetries:
          raiseOSError(lastError, "Could not send all data.")
    else:
      written.inc(sent)
template `&=`*(socket: Socket; data: typed) =
  ## Shorthand for `send(socket, data)`.
  socket.send(data)
proc trySend*(socket: Socket, data: string): bool {.tags: [WriteIOEffect].} =
  ## Non-raising alternative to `send`. Makes a single low-level write of
  ## `data` and returns `true` only if every byte was accepted, `false`
  ## otherwise.
  let sent = send(socket, cstring(data), data.len)
  result = sent == data.len
proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
             size: int, af: Domain = AF_INET, flags = 0'i32) {.
             tags: [WriteIOEffect].} =
  ## This proc sends `data` to the specified `address`,
  ## which may be an IP address or a hostname, if a hostname is specified
  ## this function will try each IP of that hostname. This function
  ## should normally be used with connection-less sockets (UDP sockets).
  ##
  ## If an error occurs an OSError exception will be raised.
  ##
  ## **Note:** You may wish to use the high-level version of this function
  ## which is defined below.
  ##
  ## **Note:** This proc is not available for SSL sockets.
  assert(socket.protocol != IPPROTO_TCP, "Cannot `sendTo` on a TCP socket")
  assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")
  let aiList = getAddrInfo(address, port, af, socket.sockType, socket.protocol)
  # Walk the resolved addresses until one of them accepts the datagram.
  var delivered = false
  var candidate = aiList
  while candidate != nil and not delivered:
    let rc = sendto(socket.fd, data, size.cint, flags.cint, candidate.ai_addr,
                    candidate.ai_addrlen.SockLen)
    if rc != -1'i32:
      delivered = true
    else:
      candidate = candidate.ai_next

  # Read the error before freeAddrInfo can overwrite it.
  let osError = osLastError()
  freeAddrInfo(aiList)

  if not delivered:
    raiseOSError(osError)
proc sendTo*(socket: Socket, address: string, port: Port,
             data: string) {.tags: [WriteIOEffect].} =
  ## This proc sends `data` to the specified `address`,
  ## which may be an IP address or a hostname, if a hostname is specified
  ## this function will try each IP of that hostname.
  ##
  ## Generally for use with connection-less (UDP) sockets.
  ##
  ## If an error occurs an OSError exception will be raised.
  ##
  ## This is the high-level version of the above `sendTo` function. The
  ## socket's own domain is used for address resolution.
  sendTo(socket, address, port, cstring(data), data.len, socket.domain)
proc sendTo*(socket: Socket, address: IpAddress, port: Port,
             data: string, flags = 0'i32): int {.
              discardable, tags: [WriteIOEffect].} =
  ## This proc sends `data` to the specified `IpAddress` and returns
  ## the number of bytes written.
  ##
  ## Generally for use with connection-less (UDP) sockets.
  ##
  ## If an error occurs an OSError exception will be raised.
  ##
  ## This is the high-level version of the above `sendTo` function.
  assert(socket.protocol != IPPROTO_TCP, "Cannot `sendTo` on a TCP socket")
  assert(not socket.isClosed, "Cannot `sendTo` on a closed socket")

  var storage: Sockaddr_storage
  var storageLen: SockLen
  toSockAddr(address, port, storage, storageLen)
  result = sendto(socket.fd, cstring(data), data.len().cint, flags.cint,
                  cast[ptr SockAddr](addr storage), storageLen)
  if result == -1'i32:
    raiseOSError(osLastError())
proc isSsl*(socket: Socket): bool =
  ## Reports whether `socket` is an SSL socket. Always `false` when SSL
  ## support is not compiled in.
  when defineSsl:
    return socket.isSsl
  else:
    return false
proc getFd*(socket: Socket): SocketHandle =
  ## Returns the socket's file descriptor.
  result = socket.fd
when defined(zephyr) or defined(nimNetSocketExtras): # Remove in future
  proc getDomain*(socket: Socket): Domain =
    ## Returns the socket's domain (address family).
    result = socket.domain

  proc getType*(socket: Socket): SockType =
    ## Returns the socket's type.
    result = socket.sockType

  proc getProtocol*(socket: Socket): Protocol =
    ## Returns the socket's protocol.
    result = socket.protocol
# The well-known-address constructors below use underscores in their names,
# so style checks are suspended around them.
when defined(nimHasStyleChecks):
  {.push styleChecks: off.}

proc IPv4_any*(): IpAddress =
  ## Returns the IPv4 any address (0.0.0.0), which can be used to listen on
  ## all available network adapters.
  IpAddress(family: IpAddressFamily.IPv4, address_v4: [0'u8, 0, 0, 0])

proc IPv4_loopback*(): IpAddress =
  ## Returns the IPv4 loopback address (127.0.0.1).
  IpAddress(family: IpAddressFamily.IPv4, address_v4: [127'u8, 0, 0, 1])

proc IPv4_broadcast*(): IpAddress =
  ## Returns the IPv4 broadcast address (255.255.255.255).
  IpAddress(family: IpAddressFamily.IPv4,
            address_v4: [255'u8, 255, 255, 255])

proc IPv6_any*(): IpAddress =
  ## Returns the IPv6 any address (::0), which can be used to listen on all
  ## available network adapters.
  IpAddress(family: IpAddressFamily.IPv6,
            address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

proc IPv6_loopback*(): IpAddress =
  ## Returns the IPv6 loopback address (::1).
  IpAddress(family: IpAddressFamily.IPv6,
            address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

when defined(nimHasStyleChecks):
  {.pop.}
proc `==`*(lhs, rhs: IpAddress): bool =
  ## Compares two IpAddresses for equality: both must have the same family
  ## and byte-for-byte identical addresses.
  if lhs.family == rhs.family:
    result =
      case lhs.family
      of IpAddressFamily.IPv4: lhs.address_v4 == rhs.address_v4
      of IpAddressFamily.IPv6: lhs.address_v6 == rhs.address_v6
proc `$`*(address: IpAddress): string =
  ## Converts an IpAddress into its textual representation.
  ##
  ## IPv4 addresses are written in dotted decimal (`a.b.c.d`). IPv6
  ## addresses are written as eight colon-separated groups of lowercase hex
  ## digits without leading zeros. The first longest run of all-zero groups
  ## is replaced by `::`, even when that run is a single group. An all-zero
  ## address is written as `::`.
  const hexDigits = "0123456789abcdef"
  case address.family
  of IpAddressFamily.IPv4:
    result = newStringOfCap(15)
    for idx in 0 .. 3:
      if idx > 0:
        result.add '.'
      result.addInt address.address_v4[idx]
  of IpAddressFamily.IPv6:
    result = newStringOfCap(39)
    # Combine byte pairs into eight big-endian 16-bit groups.
    var groups: array[8, uint16]
    for g in 0 .. 7:
      groups[g] = (uint16(address.address_v6[2 * g]) shl 8) or
                  uint16(address.address_v6[2 * g + 1])
    # Locate the longest run of zero groups; on ties the earlier run wins.
    var
      gapStart = -1
      gapLen = 0
      runStart = -1
    for g in 0 .. 7:
      if groups[g] == 0:
        if runStart < 0:
          runStart = g
        let runLen = g - runStart + 1
        if runLen > gapLen:
          gapStart = runStart
          gapLen = runLen
      else:
        runStart = -1
    if gapLen == 8:
      result.add "::"
    else:
      var needColon = false
      for g in 0 .. 7:
        if gapLen > 0 and g >= gapStart and g < gapStart + gapLen:
          # Inside the compressed run: emit `::` once, at its start.
          if g == gapStart:
            result.add "::"
            needColon = false
          continue
        if needColon:
          result.add ':'
        # Emit the group in hex, skipping leading zero nibbles.
        var emitted = false
        var shift = 12
        while shift >= 0:
          let nibble = int((groups[g] shr shift) and 0xF'u16)
          if nibble != 0 or emitted:
            result.add hexDigits[nibble]
            emitted = true
          shift -= 4
        if not emitted:
          result.add '0'
        needColon = true
proc dial*(address: string, port: Port,
           protocol = IPPROTO_TCP, buffered = true): owned(Socket)
           {.tags: [ReadIOEffect, WriteIOEffect].} =
  ## Establishes connection to the specified `address`:`port` pair via the
  ## specified protocol. The procedure iterates through possible
  ## resolutions of the `address` until it succeeds, meaning that it
  ## seamlessly works with both IPv4 and IPv6.
  ## Returns Socket ready to send or receive data.
  ##
  ## At most one native socket is created per address family and reused for
  ## all resolved addresses of that family. Every native socket that does not
  ## end up in the returned `Socket` is closed before this proc returns or
  ## raises.
  ##
  ## Raises `OSError` if a native socket cannot be created, or with the last
  ## connect error if every attempt fails. Raises `IOError` when no connect
  ## attempt recorded an error, e.g. because none of the resolved addresses
  ## had a supported address family.
  let sockType = protocol.toSockType()
  let aiList = getAddrInfo(address, port, AF_UNSPEC, sockType, protocol)

  # One lazily-created native socket per address family.
  var fdPerDomain: array[low(Domain).ord..high(Domain).ord, SocketHandle]
  for i in low(fdPerDomain)..high(fdPerDomain):
    fdPerDomain[i] = osInvalidSocket
  template closeUnusedFds(domainToKeep = -1) {.dirty.} =
    for i, fd in fdPerDomain:
      if fd != osInvalidSocket and i != domainToKeep:
        fd.close()

  var success = false
  var lastError: OSErrorCode
  var it = aiList
  var domain: Domain
  var lastFd = osInvalidSocket
  while it != nil:
    let domainOpt = it.ai_family.toKnownDomain()
    if domainOpt.isNone:
      it = it.ai_next
      continue
    domain = domainOpt.unsafeGet()
    lastFd = fdPerDomain[ord(domain)]
    if lastFd == osInvalidSocket:
      lastFd = createNativeSocket(domain, sockType, protocol)
      if lastFd == osInvalidSocket:
        # we always raise if socket creation failed, because it means a
        # network system problem (e.g. not enough FDs), and not an unreachable
        # address.
        let err = osLastError()
        freeAddrInfo(aiList)
        closeUnusedFds()
        raiseOSError(err)
      fdPerDomain[ord(domain)] = lastFd
    if connect(lastFd, it.ai_addr, it.ai_addrlen.SockLen) == 0'i32:
      success = true
      break
    lastError = osLastError()
    it = it.ai_next
  freeAddrInfo(aiList)

  if success:
    closeUnusedFds(ord(domain))
    result = newSocket(lastFd, domain, sockType, protocol, buffered)
  else:
    # Close only sockets that were actually created. When no attempt was
    # made, `lastFd` must not be closed, because it is not a socket we own.
    closeUnusedFds()
    if lastError != 0.OSErrorCode:
      raiseOSError(lastError)
    else:
      raise newException(IOError, "Couldn't resolve address: " & address)
proc connect*(socket: Socket, address: string,
              port = Port(0)) {.tags: [ReadIOEffect, RootEffect].} =
  ## Connects socket to `address`:`port`. `Address` can be an IP address or a
  ## host name. If `address` is a host name, this function will try each IP
  ## of that host name. `htons` is already performed on `port` so you must
  ## not do it.
  ##
  ## If `socket` is an SSL socket a handshake will be automatically performed.
  ##
  ## Raises `OSError` with the last connect error when no resolved address
  ## accepts the connection.
  let aiList = getAddrInfo(address, port, socket.domain)
  var
    connected = false
    lastError: OSErrorCode
    candidate = aiList
  # Walk the resolved addresses until one accepts the connection.
  while not connected and candidate != nil:
    if connect(socket.fd, candidate.ai_addr,
               candidate.ai_addrlen.SockLen) == 0'i32:
      connected = true
    else:
      lastError = osLastError()
      candidate = candidate.ai_next
  freeAddrInfo(aiList)
  if not connected:
    raiseOSError(lastError)

  when defineSsl:
    if socket.isSsl:
      let hostIsIp = isIpAddress(address)
      # RFC3546 for SNI specifies that IP addresses are not allowed.
      if not hostIsIp:
        # Discard result in case OpenSSL version doesn't support SNI, or we're
        # not using TLSv1+
        discard SSL_set_tlsext_host_name(socket.sslHandle, address)
      ErrClearError()
      let ret = SSL_connect(socket.sslHandle)
      socketError(socket, ret)
      when not defined(nimDisableCertificateValidation) and not defined(windows):
        if not hostIsIp:
          socket.checkCertName(address)
proc connectAsync(socket: Socket, name: string, port = Port(0),
                  af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
  ## A variant of `connect` for non-blocking sockets.
  ##
  ## This procedure will immediately return, it will not block until a connection
  ## is made. It is up to the caller to make sure the connection has been established
  ## by checking (using `select`) whether the socket is writeable.
  ##
  ## **Note**: For SSL sockets, the `handshake` procedure must be called
  ## whenever the socket successfully connects to a server.
  ##
  ## An attempt counts as started when `connect` returns 0 or reports an
  ## "in progress" style error (`WSAEWOULDBLOCK` on Windows, `EINTR` or
  ## `EINPROGRESS` elsewhere). Raises `OSError` with the last error if no
  ## resolved address gets that far.
  let aiList = getAddrInfo(name, port, af)
  var
    started = false
    lastError: OSErrorCode
    candidate = aiList
  while candidate != nil:
    if connect(socket.fd, candidate.ai_addr,
               candidate.ai_addrlen.SockLen) == 0'i32:
      started = true
      break
    lastError = osLastError()
    var pending: bool
    when useWinVersion:
      # Windows EINTR doesn't behave same as POSIX.
      pending = lastError.int32 == WSAEWOULDBLOCK
    else:
      pending = lastError.int32 == EINTR or lastError.int32 == EINPROGRESS
    if pending:
      started = true
      break
    candidate = candidate.ai_next
  freeAddrInfo(aiList)
  if not started:
    raiseOSError(lastError)
proc connect*(socket: Socket, address: string, port = Port(0),
              timeout: int) {.tags: [ReadIOEffect, WriteIOEffect, RootEffect].} =
  ## Connects to server as specified by `address` on port specified by `port`.
  ##
  ## The `timeout` parameter specifies the time in milliseconds to allow for
  ## the connection to the server to be made.
  ##
  ## The socket is switched to non-blocking mode for the connection attempt
  ## and is always put back into blocking mode before this proc returns or
  ## raises.
  ##
  ## Raises `TimeoutError` when `timeoutWrite` does not report the socket
  ## ready within `timeout`. Raises `OSError` when the socket's `SO_ERROR`
  ## reports a failed connection.
  socket.fd.setBlocking(false)
  try:
    socket.connectAsync(address, port, socket.domain)
    if timeoutWrite(socket.fd, timeout) != 1:
      raise newException(TimeoutError, "Call to 'connect' timed out.")
    # The socket became writeable: SO_ERROR tells whether connecting worked.
    let res = getSockOptInt(socket.fd, SOL_SOCKET, SO_ERROR)
    if res != 0:
      raiseOSError(OSErrorCode(res))
    when defineSsl and not defined(nimdoc):
      if socket.isSsl:
        # The TLS handshake below runs with the socket in blocking mode.
        socket.fd.setBlocking(true)
        # RFC3546 for SNI specifies that IP addresses are not allowed.
        if not isIpAddress(address):
          # Discard result in case OpenSSL version doesn't support SNI, or we're
          # not using TLSv1+
          discard SSL_set_tlsext_host_name(socket.sslHandle, address)
        ErrClearError()
        let ret = SSL_connect(socket.sslHandle)
        socketError(socket, ret)
        when not defined(nimDisableCertificateValidation):
          if not isIpAddress(address):
            socket.checkCertName(address)
  finally:
    # Restore blocking mode on every exit path, including a timeout or a
    # connection error, so callers never get a socket stuck in non-blocking mode.
    socket.fd.setBlocking(true)
proc getPrimaryIPAddr*(dest = parseIpAddress("8.8.8.8")): IpAddress =
  ## Finds the local IP address, usually assigned to eth0 on LAN or wlan0 on WiFi,
  ## used to reach an external address. Useful to run local services.
  ##
  ## No traffic is sent.
  ##
  ## Supports IPv4 and v6.
  ## Raises OSError if external networking is not set up.
  runnableExamples("-r:off"):
    echo getPrimaryIPAddr() # "192.168.1.2"
  # A UDP "connect" only selects a route and a local address; no packet
  # leaves the host.
  let family = if dest.family == IpAddressFamily.IPv4: AF_INET else: AF_INET6
  let socket = newSocket(family, SOCK_DGRAM, IPPROTO_UDP)
  defer: socket.close()
  socket.connect($dest, 80.Port)
  result = socket.getLocalAddr()[0].parseIpAddress()