mirror of
https://github.com/nim-lang/Nim.git
synced 2026-02-16 16:14:20 +00:00
Merge branch 'starttls' of https://github.com/wiml/Nim into wiml-starttls
Conflicts: lib/pure/net.nim
This commit is contained in:
@@ -472,6 +472,15 @@ when defined(ssl):
|
||||
socket.bioOut = bioNew(bio_s_mem())
|
||||
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
|
||||
|
||||
proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType) =
|
||||
wrapSocket(ctx, socket)
|
||||
|
||||
case handshake
|
||||
of handshakeAsClient:
|
||||
sslSetConnectState(socket.sslHandle)
|
||||
of handshakeAsServer:
|
||||
sslSetAcceptState(socket.sslHandle)
|
||||
|
||||
proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
|
||||
tags: [ReadIOEffect].} =
|
||||
## Retrieves option ``opt`` as a boolean value.
|
||||
|
||||
133
lib/pure/net.nim
133
lib/pure/net.nim
@@ -26,15 +26,18 @@ when defined(ssl):
|
||||
|
||||
SslCVerifyMode* = enum
|
||||
CVerifyNone, CVerifyPeer
|
||||
|
||||
|
||||
SslProtVersion* = enum
|
||||
protSSLv2, protSSLv3, protTLSv1, protSSLv23
|
||||
|
||||
|
||||
SslContext* = distinct SslCtx
|
||||
|
||||
SslAcceptResult* = enum
|
||||
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
|
||||
|
||||
SslHandshakeType* = enum
|
||||
handshakeAsClient, handshakeAsServer
|
||||
|
||||
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
|
||||
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
|
||||
TSSLAcceptResult: SSLAcceptResult].}
|
||||
@@ -86,7 +89,7 @@ type
|
||||
IPv6, ## IPv6 address
|
||||
IPv4 ## IPv4 address
|
||||
|
||||
IpAddress* = object ## stores an arbitrary IP address
|
||||
IpAddress* = object ## stores an arbitrary IP address
|
||||
case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
|
||||
of IpAddressFamily.IPv6:
|
||||
address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
|
||||
@@ -98,6 +101,8 @@ type
|
||||
|
||||
proc isIpAddress*(address_str: string): bool {.tags: [].}
|
||||
proc parseIpAddress*(address_str: string): IpAddress
|
||||
proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
lastError = (-1).OSErrorCode): void
|
||||
|
||||
proc isDisconnectionError*(flags: set[SocketFlag],
|
||||
lastError: OSErrorCode): bool =
|
||||
@@ -109,7 +114,7 @@ proc isDisconnectionError*(flags: set[SocketFlag],
|
||||
WSAEDISCON, ERROR_NETNAME_DELETED}
|
||||
else:
|
||||
SocketFlag.SafeDisconn in flags and
|
||||
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
|
||||
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
|
||||
|
||||
proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
|
||||
## Converts the flags into the underlying OS representation.
|
||||
@@ -172,27 +177,27 @@ when defined(ssl):
|
||||
raise newException(system.IOError, "Certificate file could not be found: " & certFile)
|
||||
if keyFile != "" and not existsFile(keyFile):
|
||||
raise newException(system.IOError, "Key file could not be found: " & keyFile)
|
||||
|
||||
|
||||
if certFile != "":
|
||||
var ret = SSLCTXUseCertificateChainFile(ctx, certFile)
|
||||
if ret != 1:
|
||||
raiseSSLError()
|
||||
|
||||
|
||||
# TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
|
||||
if keyFile != "":
|
||||
if SSL_CTX_use_PrivateKey_file(ctx, keyFile,
|
||||
SSL_FILETYPE_PEM) != 1:
|
||||
raiseSSLError()
|
||||
|
||||
|
||||
if SSL_CTX_check_private_key(ctx) != 1:
|
||||
raiseSSLError("Verification of private key file failed.")
|
||||
|
||||
proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
|
||||
certFile = "", keyFile = ""): SSLContext =
|
||||
## Creates an SSL context.
|
||||
##
|
||||
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
|
||||
## are available with the addition of ``protSSLv23`` which allows for
|
||||
##
|
||||
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
|
||||
## are available with the addition of ``protSSLv23`` which allows for
|
||||
## compatibility with all of them.
|
||||
##
|
||||
## There are currently only two options for verify mode;
|
||||
@@ -217,7 +222,7 @@ when defined(ssl):
|
||||
newCTX = SSL_CTX_new(SSLv3_method())
|
||||
of protTLSv1:
|
||||
newCTX = SSL_CTX_new(TLSv1_method())
|
||||
|
||||
|
||||
if newCTX.SSLCTXSetCipherList("ALL") != 1:
|
||||
raiseSSLError()
|
||||
case verifyMode
|
||||
@@ -236,9 +241,13 @@ when defined(ssl):
|
||||
## Wraps a socket in an SSL context. This function effectively turns
|
||||
## ``socket`` into an SSL socket.
|
||||
##
|
||||
## This must be called on an unconnected socket; an SSL session will
|
||||
## be started when the socket is connected.
|
||||
##
|
||||
## **Disclaimer**: This code is not well tested, may be very unsafe and
|
||||
## prone to security vulnerabilities.
|
||||
|
||||
|
||||
assert (not socket.isSSL)
|
||||
socket.isSSL = true
|
||||
socket.sslContext = ctx
|
||||
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
|
||||
@@ -246,10 +255,28 @@ when defined(ssl):
|
||||
socket.sslHasPeekChar = false
|
||||
if socket.sslHandle == nil:
|
||||
raiseSSLError()
|
||||
|
||||
|
||||
if SSLSetFd(socket.sslHandle, socket.fd) != 1:
|
||||
raiseSSLError()
|
||||
|
||||
proc wrapSocket*(ctx: SSLContext, socket: Socket, handshake: SslHandshakeType) =
|
||||
## Wraps a socket in an SSL context. This function effectively turns
|
||||
## ``socket`` into an SSL socket.
|
||||
##
|
||||
## This should be called on a connected socket, and will perform
|
||||
## an SSL handshake immediately.
|
||||
##
|
||||
## **Disclaimer**: This code is not well tested, may be very unsafe and
|
||||
## prone to security vulnerabilities.
|
||||
wrapSocket(ctx, socket)
|
||||
case handshake
|
||||
of handshakeAsClient:
|
||||
let ret = SSLConnect(socket.sslHandle)
|
||||
socketError(socket, ret)
|
||||
of handshakeAsServer:
|
||||
let ret = SSLAccept(socket.sslHandle)
|
||||
socketError(socket, ret)
|
||||
|
||||
proc getSocketError*(socket: Socket): OSErrorCode =
|
||||
## Checks ``osLastError`` for a valid error. If it has been reset it uses
|
||||
## the last error stored in the socket object.
|
||||
@@ -302,7 +329,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
of SSL_ERROR_SSL:
|
||||
raiseSSLError()
|
||||
else: raiseSSLError("Unknown Error")
|
||||
|
||||
|
||||
if err == -1 and not (when defined(ssl): socket.isSSL else: false):
|
||||
var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
|
||||
if async:
|
||||
@@ -317,8 +344,8 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
|
||||
else: raiseOSError(lastE)
|
||||
|
||||
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
|
||||
## Marks ``socket`` as accepting connections.
|
||||
## ``Backlog`` specifies the maximum length of the
|
||||
## Marks ``socket`` as accepting connections.
|
||||
## ``Backlog`` specifies the maximum length of the
|
||||
## queue of pending connections.
|
||||
##
|
||||
## Raises an EOS error upon failure.
|
||||
@@ -360,7 +387,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
|
||||
## The resulting client will inherit any properties of the server socket. For
|
||||
## example: whether the socket is buffered or not.
|
||||
##
|
||||
## **Note**: ``client`` must be initialised (with ``new``), this function
|
||||
## **Note**: ``client`` must be initialised (with ``new``), this function
|
||||
## makes no effort to initialise the ``client`` variable.
|
||||
##
|
||||
## The ``accept`` call may result in an error if the connecting socket
|
||||
@@ -372,7 +399,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
|
||||
var addrLen = sizeof(sockAddress).SockLen
|
||||
var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)),
|
||||
addr(addrLen))
|
||||
|
||||
|
||||
if sock == osInvalidSocket:
|
||||
let err = osLastError()
|
||||
if flags.isDisconnectionError(err):
|
||||
@@ -386,11 +413,11 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
|
||||
when defined(ssl):
|
||||
if server.isSSL:
|
||||
# We must wrap the client sock in a ssl context.
|
||||
|
||||
|
||||
server.sslContext.wrapSocket(client)
|
||||
let ret = SSLAccept(client.sslHandle)
|
||||
socketError(client, ret, false)
|
||||
|
||||
|
||||
# Client socket is set above.
|
||||
address = $inet_ntoa(sockAddress.sin_addr)
|
||||
|
||||
@@ -398,9 +425,9 @@ when false: #defined(ssl):
|
||||
proc acceptAddrSSL*(server: Socket, client: var Socket,
|
||||
address: var string): SSLAcceptResult {.
|
||||
tags: [ReadIOEffect].} =
|
||||
## This procedure should only be used for non-blocking **SSL** sockets.
|
||||
## This procedure should only be used for non-blocking **SSL** sockets.
|
||||
## It will immediately return with one of the following values:
|
||||
##
|
||||
##
|
||||
## ``AcceptSuccess`` will be returned when a client has been successfully
|
||||
## accepted and the handshake has been successfully performed between
|
||||
## ``server`` and the newly connected client.
|
||||
@@ -417,7 +444,7 @@ when false: #defined(ssl):
|
||||
if server.isSSL:
|
||||
client.setBlocking(false)
|
||||
# We must wrap the client sock in a ssl context.
|
||||
|
||||
|
||||
if not client.isSSL or client.sslHandle == nil:
|
||||
server.sslContext.wrapSocket(client)
|
||||
let ret = SSLAccept(client.sslHandle)
|
||||
@@ -450,7 +477,7 @@ proc accept*(server: Socket, client: var Socket,
|
||||
flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} =
|
||||
## Equivalent to ``acceptAddr`` but doesn't return the address, only the
|
||||
## socket.
|
||||
##
|
||||
##
|
||||
## **Note**: ``client`` must be initialised (with ``new``), this function
|
||||
## makes no effort to initialise the ``client`` variable.
|
||||
##
|
||||
@@ -504,7 +531,7 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {
|
||||
var valuei = cint(if value: 1 else: 0)
|
||||
setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei)
|
||||
|
||||
proc connect*(socket: Socket, address: string, port = Port(0),
|
||||
proc connect*(socket: Socket, address: string, port = Port(0),
|
||||
af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
|
||||
## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a
|
||||
## host name. If ``address`` is a host name, this function will try each IP
|
||||
@@ -526,7 +553,7 @@ proc connect*(socket: Socket, address: string, port = Port(0),
|
||||
|
||||
dealloc(aiList)
|
||||
if not success: raiseOSError(lastError)
|
||||
|
||||
|
||||
when defined(ssl):
|
||||
if socket.isSSL:
|
||||
# RFC3546 for SNI specifies that IP addresses are not allowed.
|
||||
@@ -634,12 +661,12 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
|
||||
if socket.isBuffered:
|
||||
if socket.bufLen == 0:
|
||||
retRead(0'i32, 0)
|
||||
|
||||
|
||||
var read = 0
|
||||
while read < size:
|
||||
if socket.currPos >= socket.bufLen:
|
||||
retRead(0'i32, read)
|
||||
|
||||
|
||||
let chunk = min(socket.bufLen-socket.currPos, size-read)
|
||||
var d = cast[cstring](data)
|
||||
assert size-read >= chunk
|
||||
@@ -686,7 +713,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
|
||||
else:
|
||||
if timeout - int(waited * 1000.0) < 1:
|
||||
raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
|
||||
|
||||
|
||||
when defined(ssl):
|
||||
if socket.isSSL:
|
||||
if socket.hasDataBuffered:
|
||||
@@ -695,7 +722,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
|
||||
let sslPending = SSLPending(socket.sslHandle)
|
||||
if sslPending != 0:
|
||||
return sslPending
|
||||
|
||||
|
||||
var startTime = epochTime()
|
||||
let selRet = select(socket, timeout - int(waited * 1000.0))
|
||||
if selRet < 0: raiseOSError(osLastError())
|
||||
@@ -706,8 +733,8 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
|
||||
proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
|
||||
tags: [ReadIOEffect, TimeEffect].} =
|
||||
## overload with a ``timeout`` parameter in milliseconds.
|
||||
var waited = 0.0 # number of seconds already waited
|
||||
|
||||
var waited = 0.0 # number of seconds already waited
|
||||
|
||||
var read = 0
|
||||
while read < size:
|
||||
let avail = waitFor(socket, waited, timeout, size-read, "recv")
|
||||
@@ -718,7 +745,7 @@ proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
|
||||
if result < 0:
|
||||
return result
|
||||
inc(read, result)
|
||||
|
||||
|
||||
result = read
|
||||
|
||||
proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
|
||||
@@ -752,7 +779,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
|
||||
var res = socket.readIntoBuf(0'i32)
|
||||
if res <= 0:
|
||||
result = res
|
||||
|
||||
|
||||
c = socket.buffer[socket.currPos]
|
||||
else:
|
||||
when defined(ssl):
|
||||
@@ -760,7 +787,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
|
||||
if not socket.sslHasPeekChar:
|
||||
result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
|
||||
socket.sslHasPeekChar = true
|
||||
|
||||
|
||||
c = socket.sslPeekChar
|
||||
return
|
||||
result = recv(socket.fd, addr(c), 1, MSG_PEEK)
|
||||
@@ -773,7 +800,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
|
||||
## If a full line is read ``\r\L`` is not
|
||||
## added to ``line``, however if solely ``\r\L`` is read then ``line``
|
||||
## will be set to it.
|
||||
##
|
||||
##
|
||||
## If the socket is disconnected, ``line`` will be set to ``""``.
|
||||
##
|
||||
## An EOS exception will be raised in the case of a socket error.
|
||||
@@ -782,7 +809,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
|
||||
## the specified time an ETimeout exception will be raised.
|
||||
##
|
||||
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
|
||||
|
||||
|
||||
template addNLIfEmpty(): stmt =
|
||||
if line.len == 0:
|
||||
line.add("\c\L")
|
||||
@@ -809,7 +836,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
|
||||
elif n <= 0: raiseSockError()
|
||||
addNLIfEmpty()
|
||||
return
|
||||
elif c == '\L':
|
||||
elif c == '\L':
|
||||
addNLIfEmpty()
|
||||
return
|
||||
add(line.string, c)
|
||||
@@ -827,7 +854,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int,
|
||||
## so when ``socket`` is buffered the non-buffered implementation will be
|
||||
## used. Therefore if ``socket`` contains something in its buffer this
|
||||
## function will make no effort to return it.
|
||||
|
||||
|
||||
# TODO: Buffered sockets
|
||||
data.setLen(length)
|
||||
var sockAddress: Sockaddr_in
|
||||
@@ -861,16 +888,16 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
|
||||
tags: [WriteIOEffect].} =
|
||||
## Sends data to a socket.
|
||||
##
|
||||
## **Note**: This is a low-level version of ``send``. You likely should use
|
||||
## **Note**: This is a low-level version of ``send``. You likely should use
|
||||
## the version below.
|
||||
when defined(ssl):
|
||||
if socket.isSSL:
|
||||
return SSLWrite(socket.sslHandle, cast[cstring](data), size)
|
||||
|
||||
|
||||
when useWinVersion or defined(macosx):
|
||||
result = send(socket.fd, data, size.cint, 0'i32)
|
||||
else:
|
||||
when defined(solaris):
|
||||
when defined(solaris):
|
||||
const MSG_NOSIGNAL = 0
|
||||
result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
|
||||
|
||||
@@ -895,7 +922,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
|
||||
size: int, af: Domain = AF_INET, flags = 0'i32): int {.
|
||||
tags: [WriteIOEffect].} =
|
||||
## This proc sends ``data`` to the specified ``address``,
|
||||
## which may be an IP address or a hostname, if a hostname is specified
|
||||
## which may be an IP address or a hostname, if a hostname is specified
|
||||
## this function will try each IP of that hostname.
|
||||
##
|
||||
##
|
||||
@@ -904,7 +931,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
|
||||
##
|
||||
## **Note:** This proc is not available for SSL sockets.
|
||||
var aiList = getAddrInfo(address, port, af)
|
||||
|
||||
|
||||
# try all possibilities:
|
||||
var success = false
|
||||
var it = aiList
|
||||
@@ -918,10 +945,10 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
|
||||
|
||||
dealloc(aiList)
|
||||
|
||||
proc sendTo*(socket: Socket, address: string, port: Port,
|
||||
proc sendTo*(socket: Socket, address: string, port: Port,
|
||||
data: string): int {.tags: [WriteIOEffect].} =
|
||||
## This proc sends ``data`` to the specified ``address``,
|
||||
## which may be an IP address or a hostname, if a hostname is specified
|
||||
## which may be an IP address or a hostname, if a hostname is specified
|
||||
## this function will try each IP of that hostname.
|
||||
##
|
||||
## This is the high-level version of the above ``sendTo`` function.
|
||||
@@ -958,7 +985,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0),
|
||||
if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
|
||||
success = true
|
||||
break
|
||||
|
||||
|
||||
it = it.ai_next
|
||||
|
||||
dealloc(aiList)
|
||||
@@ -971,7 +998,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
|
||||
## The ``timeout`` paremeter specifies the time in milliseconds to allow for
|
||||
## the connection to the server to be made.
|
||||
socket.fd.setBlocking(false)
|
||||
|
||||
|
||||
socket.connectAsync(address, port, af)
|
||||
var s = @[socket.fd]
|
||||
if selectWrite(s, timeout) != 1:
|
||||
@@ -983,7 +1010,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
|
||||
doAssert socket.handshake()
|
||||
socket.fd.setBlocking(true)
|
||||
|
||||
proc isSsl*(socket: Socket): bool =
|
||||
proc isSsl*(socket: Socket): bool =
|
||||
## Determines whether ``socket`` is a SSL socket.
|
||||
when defined(ssl):
|
||||
result = socket.isSSL
|
||||
@@ -1014,7 +1041,7 @@ proc IPv4_broadcast*(): IpAddress =
|
||||
|
||||
proc IPv6_any*(): IpAddress =
|
||||
## Returns the IPv6 any address (::0), which can be used
|
||||
## to listen on all available network adapters
|
||||
## to listen on all available network adapters
|
||||
result = IpAddress(
|
||||
family: IpAddressFamily.IPv6,
|
||||
address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
|
||||
@@ -1152,7 +1179,7 @@ proc parseIPv6Address(address_str: string): IpAddress =
|
||||
if not seperatorValid:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid seperator")
|
||||
if lastWasColon:
|
||||
if lastWasColon:
|
||||
if dualColonGroup != -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains more than one \"::\" seperator")
|
||||
@@ -1165,14 +1192,14 @@ proc parseIPv6Address(address_str: string): IpAddress =
|
||||
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
|
||||
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
|
||||
currentShort = 0
|
||||
groupCount.inc()
|
||||
groupCount.inc()
|
||||
if dualColonGroup != -1: seperatorValid = false
|
||||
elif i == 0: # only valid if address starts with ::
|
||||
if address_str[1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not start with \":\"")
|
||||
else: # i == high(address_str) - only valid if address ends with ::
|
||||
if address_str[high(address_str)-1] != ':':
|
||||
if address_str[high(address_str)-1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not end with \":\"")
|
||||
lastWasColon = true
|
||||
|
||||
Reference in New Issue
Block a user