Merge branch 'starttls' of https://github.com/wiml/Nim into wiml-starttls

Conflicts:
	lib/pure/net.nim
This commit is contained in:
Dominik Picheta
2015-06-22 21:34:21 +01:00
2 changed files with 89 additions and 53 deletions

View File

@@ -472,6 +472,15 @@ when defined(ssl):
socket.bioOut = bioNew(bio_s_mem())
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
proc wrapSocket*(ctx: SslContext, socket: AsyncSocket, handshake: SslHandshakeType) =
wrapSocket(ctx, socket)
case handshake
of handshakeAsClient:
sslSetConnectState(socket.sslHandle)
of handshakeAsServer:
sslSetAcceptState(socket.sslHandle)
proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
tags: [ReadIOEffect].} =
## Retrieves option ``opt`` as a boolean value.

View File

@@ -26,15 +26,18 @@ when defined(ssl):
SslCVerifyMode* = enum
CVerifyNone, CVerifyPeer
SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23
SslContext* = distinct SslCtx
SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
SslHandshakeType* = enum
handshakeAsClient, handshakeAsServer
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
TSSLProtVersion: SSLProtVersion, PSSLContext: SSLContext,
TSSLAcceptResult: SSLAcceptResult].}
@@ -86,7 +89,7 @@ type
IPv6, ## IPv6 address
IPv4 ## IPv4 address
IpAddress* = object ## stores an arbitrary IP address
IpAddress* = object ## stores an arbitrary IP address
case family*: IpAddressFamily ## the type of the IP address (IPv4 or IPv6)
of IpAddressFamily.IPv6:
address_v6*: array[0..15, uint8] ## Contains the IP address in bytes in
@@ -98,6 +101,8 @@ type
proc isIpAddress*(address_str: string): bool {.tags: [].}
proc parseIpAddress*(address_str: string): IpAddress
proc socketError*(socket: Socket, err: int = -1, async = false,
lastError = (-1).OSErrorCode): void
proc isDisconnectionError*(flags: set[SocketFlag],
lastError: OSErrorCode): bool =
@@ -109,7 +114,7 @@ proc isDisconnectionError*(flags: set[SocketFlag],
WSAEDISCON, ERROR_NETNAME_DELETED}
else:
SocketFlag.SafeDisconn in flags and
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
lastError.int32 in {ECONNRESET, EPIPE, ENETRESET}
proc toOSFlags*(socketFlags: set[SocketFlag]): cint =
## Converts the flags into the underlying OS representation.
@@ -172,27 +177,27 @@ when defined(ssl):
raise newException(system.IOError, "Certificate file could not be found: " & certFile)
if keyFile != "" and not existsFile(keyFile):
raise newException(system.IOError, "Key file could not be found: " & keyFile)
if certFile != "":
var ret = SSLCTXUseCertificateChainFile(ctx, certFile)
if ret != 1:
raiseSSLError()
# TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
if keyFile != "":
if SSL_CTX_use_PrivateKey_file(ctx, keyFile,
SSL_FILETYPE_PEM) != 1:
raiseSSLError()
if SSL_CTX_check_private_key(ctx) != 1:
raiseSSLError("Verification of private key file failed.")
proc newContext*(protVersion = protSSLv23, verifyMode = CVerifyPeer,
certFile = "", keyFile = ""): SSLContext =
## Creates an SSL context.
##
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
## are available with the addition of ``protSSLv23`` which allows for
##
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
## are available with the addition of ``protSSLv23`` which allows for
## compatibility with all of them.
##
## There are currently only two options for verify mode;
@@ -217,7 +222,7 @@ when defined(ssl):
newCTX = SSL_CTX_new(SSLv3_method())
of protTLSv1:
newCTX = SSL_CTX_new(TLSv1_method())
if newCTX.SSLCTXSetCipherList("ALL") != 1:
raiseSSLError()
case verifyMode
@@ -236,9 +241,13 @@ when defined(ssl):
## Wraps a socket in an SSL context. This function effectively turns
## ``socket`` into an SSL socket.
##
## This must be called on an unconnected socket; an SSL session will
## be started when the socket is connected.
##
## **Disclaimer**: This code is not well tested, may be very unsafe and
## prone to security vulnerabilities.
assert (not socket.isSSL)
socket.isSSL = true
socket.sslContext = ctx
socket.sslHandle = SSLNew(SSLCTX(socket.sslContext))
@@ -246,10 +255,28 @@ when defined(ssl):
socket.sslHasPeekChar = false
if socket.sslHandle == nil:
raiseSSLError()
if SSLSetFd(socket.sslHandle, socket.fd) != 1:
raiseSSLError()
proc wrapSocket*(ctx: SSLContext, socket: Socket, handshake: SslHandshakeType) =
## Wraps a socket in an SSL context. This function effectively turns
## ``socket`` into an SSL socket.
##
## This should be called on a connected socket, and will perform
## an SSL handshake immediately.
##
## **Disclaimer**: This code is not well tested, may be very unsafe and
## prone to security vulnerabilities.
wrapSocket(ctx, socket)
case handshake
of handshakeAsClient:
let ret = SSLConnect(socket.sslHandle)
socketError(socket, ret)
of handshakeAsServer:
let ret = SSLAccept(socket.sslHandle)
socketError(socket, ret)
proc getSocketError*(socket: Socket): OSErrorCode =
## Checks ``osLastError`` for a valid error. If it has been reset it uses
## the last error stored in the socket object.
@@ -302,7 +329,7 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
of SSL_ERROR_SSL:
raiseSSLError()
else: raiseSSLError("Unknown Error")
if err == -1 and not (when defined(ssl): socket.isSSL else: false):
var lastE = if lastError.int == -1: getSocketError(socket) else: lastError
if async:
@@ -317,8 +344,8 @@ proc socketError*(socket: Socket, err: int = -1, async = false,
else: raiseOSError(lastE)
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
## Marks ``socket`` as accepting connections.
## ``Backlog`` specifies the maximum length of the
## Marks ``socket`` as accepting connections.
## ``Backlog`` specifies the maximum length of the
## queue of pending connections.
##
## Raises an EOS error upon failure.
@@ -360,7 +387,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
## The resulting client will inherit any properties of the server socket. For
## example: whether the socket is buffered or not.
##
## **Note**: ``client`` must be initialised (with ``new``), this function
## **Note**: ``client`` must be initialised (with ``new``), this function
## makes no effort to initialise the ``client`` variable.
##
## The ``accept`` call may result in an error if the connecting socket
@@ -372,7 +399,7 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
var addrLen = sizeof(sockAddress).SockLen
var sock = accept(server.fd, cast[ptr SockAddr](addr(sockAddress)),
addr(addrLen))
if sock == osInvalidSocket:
let err = osLastError()
if flags.isDisconnectionError(err):
@@ -386,11 +413,11 @@ proc acceptAddr*(server: Socket, client: var Socket, address: var string,
when defined(ssl):
if server.isSSL:
# We must wrap the client sock in a ssl context.
server.sslContext.wrapSocket(client)
let ret = SSLAccept(client.sslHandle)
socketError(client, ret, false)
# Client socket is set above.
address = $inet_ntoa(sockAddress.sin_addr)
@@ -398,9 +425,9 @@ when false: #defined(ssl):
proc acceptAddrSSL*(server: Socket, client: var Socket,
address: var string): SSLAcceptResult {.
tags: [ReadIOEffect].} =
## This procedure should only be used for non-blocking **SSL** sockets.
## This procedure should only be used for non-blocking **SSL** sockets.
## It will immediately return with one of the following values:
##
##
## ``AcceptSuccess`` will be returned when a client has been successfully
## accepted and the handshake has been successfully performed between
## ``server`` and the newly connected client.
@@ -417,7 +444,7 @@ when false: #defined(ssl):
if server.isSSL:
client.setBlocking(false)
# We must wrap the client sock in a ssl context.
if not client.isSSL or client.sslHandle == nil:
server.sslContext.wrapSocket(client)
let ret = SSLAccept(client.sslHandle)
@@ -450,7 +477,7 @@ proc accept*(server: Socket, client: var Socket,
flags = {SocketFlag.SafeDisconn}) {.tags: [ReadIOEffect].} =
## Equivalent to ``acceptAddr`` but doesn't return the address, only the
## socket.
##
##
## **Note**: ``client`` must be initialised (with ``new``), this function
## makes no effort to initialise the ``client`` variable.
##
@@ -504,7 +531,7 @@ proc setSockOpt*(socket: Socket, opt: SOBool, value: bool, level = SOL_SOCKET) {
var valuei = cint(if value: 1 else: 0)
setSockOptInt(socket.fd, cint(level), toCInt(opt), valuei)
proc connect*(socket: Socket, address: string, port = Port(0),
proc connect*(socket: Socket, address: string, port = Port(0),
af: Domain = AF_INET) {.tags: [ReadIOEffect].} =
## Connects socket to ``address``:``port``. ``Address`` can be an IP address or a
## host name. If ``address`` is a host name, this function will try each IP
@@ -526,7 +553,7 @@ proc connect*(socket: Socket, address: string, port = Port(0),
dealloc(aiList)
if not success: raiseOSError(lastError)
when defined(ssl):
if socket.isSSL:
# RFC3546 for SNI specifies that IP addresses are not allowed.
@@ -634,12 +661,12 @@ proc recv*(socket: Socket, data: pointer, size: int): int {.tags: [ReadIOEffect]
if socket.isBuffered:
if socket.bufLen == 0:
retRead(0'i32, 0)
var read = 0
while read < size:
if socket.currPos >= socket.bufLen:
retRead(0'i32, read)
let chunk = min(socket.bufLen-socket.currPos, size-read)
var d = cast[cstring](data)
assert size-read >= chunk
@@ -686,7 +713,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
else:
if timeout - int(waited * 1000.0) < 1:
raise newException(TimeoutError, "Call to '" & funcName & "' timed out.")
when defined(ssl):
if socket.isSSL:
if socket.hasDataBuffered:
@@ -695,7 +722,7 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
let sslPending = SSLPending(socket.sslHandle)
if sslPending != 0:
return sslPending
var startTime = epochTime()
let selRet = select(socket, timeout - int(waited * 1000.0))
if selRet < 0: raiseOSError(osLastError())
@@ -706,8 +733,8 @@ proc waitFor(socket: Socket, waited: var float, timeout, size: int,
proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
tags: [ReadIOEffect, TimeEffect].} =
## overload with a ``timeout`` parameter in milliseconds.
var waited = 0.0 # number of seconds already waited
var waited = 0.0 # number of seconds already waited
var read = 0
while read < size:
let avail = waitFor(socket, waited, timeout, size-read, "recv")
@@ -718,7 +745,7 @@ proc recv*(socket: Socket, data: pointer, size: int, timeout: int): int {.
if result < 0:
return result
inc(read, result)
result = read
proc recv*(socket: Socket, data: var string, size: int, timeout = -1,
@@ -752,7 +779,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
var res = socket.readIntoBuf(0'i32)
if res <= 0:
result = res
c = socket.buffer[socket.currPos]
else:
when defined(ssl):
@@ -760,7 +787,7 @@ proc peekChar(socket: Socket, c: var char): int {.tags: [ReadIOEffect].} =
if not socket.sslHasPeekChar:
result = SSLRead(socket.sslHandle, addr(socket.sslPeekChar), 1)
socket.sslHasPeekChar = true
c = socket.sslPeekChar
return
result = recv(socket.fd, addr(c), 1, MSG_PEEK)
@@ -773,7 +800,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
## If a full line is read ``\r\L`` is not
## added to ``line``, however if solely ``\r\L`` is read then ``line``
## will be set to it.
##
##
## If the socket is disconnected, ``line`` will be set to ``""``.
##
## An EOS exception will be raised in the case of a socket error.
@@ -782,7 +809,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
## the specified time an ETimeout exception will be raised.
##
## **Warning**: Only the ``SafeDisconn`` flag is currently supported.
template addNLIfEmpty(): stmt =
if line.len == 0:
line.add("\c\L")
@@ -809,7 +836,7 @@ proc readLine*(socket: Socket, line: var TaintedString, timeout = -1,
elif n <= 0: raiseSockError()
addNLIfEmpty()
return
elif c == '\L':
elif c == '\L':
addNLIfEmpty()
return
add(line.string, c)
@@ -827,7 +854,7 @@ proc recvFrom*(socket: Socket, data: var string, length: int,
## so when ``socket`` is buffered the non-buffered implementation will be
## used. Therefore if ``socket`` contains something in its buffer this
## function will make no effort to return it.
# TODO: Buffered sockets
data.setLen(length)
var sockAddress: Sockaddr_in
@@ -861,16 +888,16 @@ proc send*(socket: Socket, data: pointer, size: int): int {.
tags: [WriteIOEffect].} =
## Sends data to a socket.
##
## **Note**: This is a low-level version of ``send``. You likely should use
## **Note**: This is a low-level version of ``send``. You likely should use
## the version below.
when defined(ssl):
if socket.isSSL:
return SSLWrite(socket.sslHandle, cast[cstring](data), size)
when useWinVersion or defined(macosx):
result = send(socket.fd, data, size.cint, 0'i32)
else:
when defined(solaris):
when defined(solaris):
const MSG_NOSIGNAL = 0
result = send(socket.fd, data, size, int32(MSG_NOSIGNAL))
@@ -895,7 +922,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
size: int, af: Domain = AF_INET, flags = 0'i32): int {.
tags: [WriteIOEffect].} =
## This proc sends ``data`` to the specified ``address``,
## which may be an IP address or a hostname, if a hostname is specified
## which may be an IP address or a hostname, if a hostname is specified
## this function will try each IP of that hostname.
##
##
@@ -904,7 +931,7 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
##
## **Note:** This proc is not available for SSL sockets.
var aiList = getAddrInfo(address, port, af)
# try all possibilities:
var success = false
var it = aiList
@@ -918,10 +945,10 @@ proc sendTo*(socket: Socket, address: string, port: Port, data: pointer,
dealloc(aiList)
proc sendTo*(socket: Socket, address: string, port: Port,
proc sendTo*(socket: Socket, address: string, port: Port,
data: string): int {.tags: [WriteIOEffect].} =
## This proc sends ``data`` to the specified ``address``,
## which may be an IP address or a hostname, if a hostname is specified
## which may be an IP address or a hostname, if a hostname is specified
## this function will try each IP of that hostname.
##
## This is the high-level version of the above ``sendTo`` function.
@@ -958,7 +985,7 @@ proc connectAsync(socket: Socket, name: string, port = Port(0),
if lastError.int32 == EINTR or lastError.int32 == EINPROGRESS:
success = true
break
it = it.ai_next
dealloc(aiList)
@@ -971,7 +998,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
## The ``timeout`` paremeter specifies the time in milliseconds to allow for
## the connection to the server to be made.
socket.fd.setBlocking(false)
socket.connectAsync(address, port, af)
var s = @[socket.fd]
if selectWrite(s, timeout) != 1:
@@ -983,7 +1010,7 @@ proc connect*(socket: Socket, address: string, port = Port(0), timeout: int,
doAssert socket.handshake()
socket.fd.setBlocking(true)
proc isSsl*(socket: Socket): bool =
proc isSsl*(socket: Socket): bool =
## Determines whether ``socket`` is a SSL socket.
when defined(ssl):
result = socket.isSSL
@@ -1014,7 +1041,7 @@ proc IPv4_broadcast*(): IpAddress =
proc IPv6_any*(): IpAddress =
## Returns the IPv6 any address (::0), which can be used
## to listen on all available network adapters
## to listen on all available network adapters
result = IpAddress(
family: IpAddressFamily.IPv6,
address_v6: [0'u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
@@ -1152,7 +1179,7 @@ proc parseIPv6Address(address_str: string): IpAddress =
if not seperatorValid:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid seperator")
if lastWasColon:
if lastWasColon:
if dualColonGroup != -1:
raise newException(ValueError,
"Invalid IP Address. Address contains more than one \"::\" seperator")
@@ -1165,14 +1192,14 @@ proc parseIPv6Address(address_str: string): IpAddress =
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
currentShort = 0
groupCount.inc()
groupCount.inc()
if dualColonGroup != -1: seperatorValid = false
elif i == 0: # only valid if address starts with ::
if address_str[1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not start with \":\"")
else: # i == high(address_str) - only valid if address ends with ::
if address_str[high(address_str)-1] != ':':
if address_str[high(address_str)-1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not end with \":\"")
lastWasColon = true