Async SSL support.

This commit is contained in:
Dominik Picheta
2014-08-31 12:54:43 +01:00
parent d26d42b88e
commit bb1e87ce4d
5 changed files with 239 additions and 63 deletions

View File

@@ -47,6 +47,7 @@
import asyncdispatch
import rawsockets
import net
import os
when defined(ssl):
import openssl
@@ -54,7 +55,22 @@ when defined(ssl):
type
# TODO: I would prefer to just do:
# PAsyncSocket* {.borrow: `.`.} = distinct PSocket. But that doesn't work.
AsyncSocketDesc {.borrow: `.`.} = distinct TSocketImpl
AsyncSocketDesc = object
fd*: SocketHandle
case isBuffered*: bool # determines whether this socket is buffered.
of true:
buffer*: array[0..BufferSize, char]
currPos*: int # current index in buffer
bufLen*: int # current length of buffer
of false: nil
case isSsl: bool
of true:
when defined(ssl):
sslHandle: SslPtr
sslContext: SslContext
bioIn: BIO
bioOut: BIO
of false: nil
AsyncSocket* = ref AsyncSocketDesc
{.deprecated: [PAsyncSocket: AsyncSocket].}
@@ -63,7 +79,7 @@ type
proc newSocket(fd: TAsyncFD, isBuff: bool): PAsyncSocket =
assert fd != osInvalidSocket.TAsyncFD
new(result.PSocket)
new(result)
result.fd = fd.SocketHandle
result.isBuffered = isBuff
if isBuff:
@@ -74,22 +90,94 @@ proc newAsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM,
## Creates a new asynchronous socket.
result = newSocket(newAsyncRawSocket(domain, typ, protocol), buffered)
when defined(ssl):
proc getSslError(handle: SslPtr, err: cint): cint =
assert err < 0
var ret = SSLGetError(handle, err.cint)
case ret
of SSL_ERROR_ZERO_RETURN:
raiseSSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
return ret
of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_READ:
return ret
of SSL_ERROR_WANT_X509_LOOKUP:
raiseSSLError("Function for x509 lookup has been called.")
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
raiseSSLError()
else: raiseSSLError("Unknown Error")
proc sendPendingSslData(socket: AsyncSocket,
flags: set[TSocketFlags]) {.async.} =
let len = bioCtrlPending(socket.bioOut)
if len > 0:
var data = newStringOfCap(len)
let read = bioRead(socket.bioOut, addr data[0], len)
assert read != 0
if read < 0:
raiseSslError()
data.setLen(read)
await socket.fd.TAsyncFd.send(data, flags)
proc appeaseSsl(socket: AsyncSocket, flags: set[TSocketFlags],
sslError: cint) {.async.} =
case sslError
of SSL_ERROR_WANT_WRITE:
await sendPendingSslData(socket, flags)
of SSL_ERROR_WANT_READ:
var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
let ret = bioWrite(socket.bioIn, addr data[0], data.len.cint)
if ret < 0:
raiseSSLError()
else:
raiseSSLError("Cannot appease SSL.")
template sslLoop(socket: AsyncSocket, flags: set[TSocketFlags],
op: expr) =
var opResult {.inject.} = -1.cint
while opResult < 0:
opResult = op
# Bit hackish here.
# TODO: Introduce an async template transformation pragma?
yield sendPendingSslData(socket, flags)
if opResult < 0:
let err = getSslError(socket.sslHandle, opResult.cint)
yield appeaseSsl(socket, flags, err.cint)
proc connect*(socket: PAsyncSocket, address: string, port: TPort,
af = AF_INET): Future[void] =
af = AF_INET) {.async.} =
## Connects ``socket`` to server at ``address:port``.
##
## Returns a ``Future`` which will complete when the connection succeeds
## or an error occurs.
result = connect(socket.fd.TAsyncFD, address, port, af)
await connect(socket.fd.TAsyncFD, address, port, af)
let flags = {TSocketFlags.SafeDisconn}
if socket.isSsl:
when defined(ssl):
sslSetConnectState(socket.sslHandle)
sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
proc readIntoBuf(socket: PAsyncSocket,
flags: set[TSocketFlags]): Future[int] {.async.} =
var data = await recv(socket.fd.TAsyncFD, BufferSize, flags)
if data.len != 0:
copyMem(addr socket.buffer[0], addr data[0], data.len)
socket.bufLen = data.len
socket.currPos = 0
result = data.len
if socket.isSsl:
when defined(ssl):
# SSL mode.
let ret = bioWrite(socket.bioIn, addr socket.buffer[0], data.len.cint)
if ret < 0:
raiseSSLError()
sslLoop(socket, flags,
sslRead(socket.sslHandle, addr socket.buffer[0], BufferSize.cint))
socket.currPos = 0
socket.bufLen = opResult # Injected from sslLoop template.
result = opResult
else:
# Not in SSL mode.
socket.bufLen = data.len
socket.currPos = 0
result = data.len
proc recv*(socket: PAsyncSocket, size: int,
flags = {TSocketFlags.SafeDisconn}): Future[string] {.async.} =
@@ -131,11 +219,18 @@ proc recv*(socket: PAsyncSocket, size: int,
result = await recv(socket.fd.TAsyncFD, size, flags)
proc send*(socket: PAsyncSocket, data: string,
flags = {TSocketFlags.SafeDisconn}): Future[void] =
flags = {TSocketFlags.SafeDisconn}) {.async.} =
## Sends ``data`` to ``socket``. The returned future will complete once all
## data has been sent.
assert socket != nil
result = send(socket.fd.TAsyncFD, data, flags)
if socket.isSsl:
when defined(ssl):
var copy = data
sslLoop(socket, flags,
sslWrite(socket.sslHandle, addr copy[0], copy.len.cint))
await sendPendingSslData(socket, flags)
else:
await send(socket.fd.TAsyncFD, data, flags)
proc acceptAddr*(socket: PAsyncSocket, flags = {TSocketFlags.SafeDisconn}):
Future[tuple[address: string, client: PAsyncSocket]] =
@@ -240,24 +335,67 @@ proc recvLine*(socket: PAsyncSocket,
return
add(result.string, c)
proc bindAddr*(socket: PAsyncSocket, port = TPort(0), address = "") =
## Binds ``address``:``port`` to the socket.
##
## If ``address`` is "" then ADDR_ANY will be bound.
socket.PSocket.bindAddr(port, address)
proc listen*(socket: PAsyncSocket, backlog = SOMAXCONN) =
proc listen*(socket: Socket, backlog = SOMAXCONN) {.tags: [ReadIOEffect].} =
## Marks ``socket`` as accepting connections.
## ``Backlog`` specifies the maximum length of the
## queue of pending connections.
##
## Raises an EOS error upon failure.
socket.PSocket.listen(backlog)
if listen(socket.fd, backlog) < 0'i32: raiseOSError(osLastError())
proc bindAddr*(socket: Socket, port = Port(0), address = "") {.
tags: [ReadIOEffect].} =
## Binds ``address``:``port`` to the socket.
##
## If ``address`` is "" then ADDR_ANY will be bound.
if address == "":
var name: Sockaddr_in
when defined(Windows) or defined(nimdoc):
name.sin_family = toInt(AF_INET).int16
else:
name.sin_family = toInt(AF_INET)
name.sin_port = htons(int16(port))
name.sin_addr.s_addr = htonl(INADDR_ANY)
if bindAddr(socket.fd, cast[ptr SockAddr](addr(name)),
sizeof(name).Socklen) < 0'i32:
raiseOSError(osLastError())
else:
var aiList = getAddrInfo(address, port, AF_INET)
if bindAddr(socket.fd, aiList.ai_addr, aiList.ai_addrlen.Socklen) < 0'i32:
dealloc(aiList)
raiseOSError(osLastError())
dealloc(aiList)
proc close*(socket: PAsyncSocket) =
## Closes the socket.
socket.fd.TAsyncFD.closeSocket()
# TODO SSL
when defined(ssl):
if socket.isSSL:
let res = SslShutdown(socket.sslHandle)
if res == 0:
if SslShutdown(socket.sslHandle) != 1:
raiseSslError()
elif res != 1:
raiseSslError()
when defined(ssl):
proc wrapSocket*(ctx: SslContext, socket: AsyncSocket) =
## Wraps a socket in an SSL context. This function effectively turns
## ``socket`` into an SSL socket.
##
## **Disclaimer**: This code is not well tested, may be very unsafe and
## prone to security vulnerabilities.
socket.isSsl = true
socket.sslContext = ctx
socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext))
if socket.sslHandle == nil:
raiseSslError()
socket.bioIn = bioNew(bio_s_mem())
socket.bioOut = bioNew(bio_s_mem())
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
when isMainModule:
type

View File

@@ -78,6 +78,7 @@
import sockets, strutils, parseurl, parseutils, strtabs, base64, os
import asyncnet, asyncdispatch
import rawsockets
from net import nil
type
Response* = tuple[
@@ -164,16 +165,17 @@ proc parseBody(s: TSocket, headers: PStringTable, timeout: int): string =
var contentLengthHeader = headers["Content-Length"]
if contentLengthHeader != "":
var length = contentLengthHeader.parseint()
result = newString(length)
var received = 0
while true:
if received >= length: break
let r = s.recv(addr(result[received]), length-received, timeout)
if r == 0: break
received += r
if received != length:
httpError("Got invalid content length. Expected: " & $length &
" got: " & $received)
if length > 0:
result = newString(length)
var received = 0
while true:
if received >= length: break
let r = s.recv(addr(result[received]), length-received, timeout)
if r == 0: break
received += r
if received != length:
httpError("Got invalid content length. Expected: " & $length &
" got: " & $received)
else:
# (http://tools.ietf.org/html/rfc2616#section-4.4) NR.4 TODO
@@ -444,11 +446,13 @@ type
headers: StringTableRef
maxRedirects: int
userAgent: string
when defined(ssl):
sslContext: net.SslContext
{.deprecated: [PAsyncHttpClient: AsyncHttpClient].}
proc newAsyncHttpClient*(userAgent = defUserAgent,
maxRedirects = 5): AsyncHttpClient =
maxRedirects = 5, sslContext = defaultSslContext): AsyncHttpClient =
## Creates a new PAsyncHttpClient instance.
##
## ``userAgent`` specifies the user agent that will be used when making
@@ -456,10 +460,13 @@ proc newAsyncHttpClient*(userAgent = defUserAgent,
##
## ``maxRedirects`` specifies the maximum amount of redirects to follow,
## default is 5.
##
## ``sslContext`` specifies the SSL context to use for HTTPS requests.
new result
result.headers = newStringTable(modeCaseInsensitive)
result.userAgent = defUserAgent
result.maxRedirects = maxRedirects
result.sslContext = net.SslContext(sslContext)
proc close*(client: AsyncHttpClient) =
## Closes any connections held by the HTTP client.
@@ -519,12 +526,13 @@ proc parseBody(client: PAsyncHttpClient,
var contentLengthHeader = headers["Content-Length"]
if contentLengthHeader != "":
var length = contentLengthHeader.parseint()
result = await client.socket.recvFull(length)
if result == "":
httpError("Got disconnected while trying to read body.")
if result.len != length:
httpError("Received length doesn't match expected length. Wanted " &
$length & " got " & $result.len)
if length > 0:
result = await client.socket.recvFull(length)
if result == "":
httpError("Got disconnected while trying to read body.")
if result.len != length:
httpError("Received length doesn't match expected length. Wanted " &
$length & " got " & $result.len)
else:
# (http://tools.ietf.org/html/rfc2616#section-4.4) NR.4 TODO
@@ -590,14 +598,23 @@ proc newConnection(client: PAsyncHttpClient, url: TURL) {.async.} =
client.currentURL.scheme != url.scheme:
if client.connected: client.close()
client.socket = newAsyncSocket()
if url.scheme == "https":
assert false, "TODO SSL"
# TODO: I should be able to write 'net.TPort' here...
let port =
if url.port == "": rawsockets.TPort(80)
if url.port == "":
if url.scheme.toLower() == "https":
rawsockets.TPort(443)
else:
rawsockets.TPort(80)
else: rawsockets.TPort(url.port.parseInt)
if url.scheme.toLower() == "https":
when defined(ssl):
client.sslContext.wrapSocket(client.socket)
else:
raise newException(EHttpRequestErr,
"SSL support is not available. Cannot connect over SSL.")
await client.socket.connect(url.hostname, port)
client.currentURL = url
client.connected = true

View File

@@ -22,17 +22,17 @@ when defined(ssl):
when defined(ssl):
type
SSLError* = object of Exception
SslError* = object of Exception
SSLCVerifyMode* = enum
SslCVerifyMode* = enum
CVerifyNone, CVerifyPeer
SSLProtVersion* = enum
SslProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23
SSLContext* = distinct PSSLCTX
SslContext* = distinct SslCtx
SSLAcceptResult* = enum
SslAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
{.deprecated: [ESSL: SSLError, TSSLCVerifyMode: SSLCVerifyMode,
@@ -125,7 +125,8 @@ when defined(ssl):
ErrLoadBioStrings()
OpenSSL_add_all_algorithms()
proc raiseSSLError(s = "") =
proc raiseSSLError*(s = "") =
## Raises a new SSL error.
if s != "":
raise newException(SSLError, s)
let err = ErrPeekLastError()
@@ -161,7 +162,7 @@ when defined(ssl):
certFile = "", keyFile = ""): PSSLContext =
## Creates an SSL context.
##
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 are
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1
## are available with the addition of ``protSSLv23`` which allows for
## compatibility with all of them.
##

View File

@@ -44,8 +44,11 @@ type
ReplyError* = object of IOError
AsyncSmtp* = ref object
sock: AsyncSocket
{.deprecated: [EInvalidReply: ReplyError, TMessage: Message, TSMTP: Smtp].}
proc debugSend(smtp: TSMTP, cmd: string) =
if smtp.debug:
echo("C:" & cmd)

View File

@@ -63,14 +63,13 @@ type
SslStruct {.final, pure.} = object
SslPtr* = ptr SslStruct
PSslPtr* = ptr SslPtr
PSSL_CTX* = SslPtr
PSSL* = SslPtr
SslCtx* = SslPtr
PSSL_METHOD* = SslPtr
PX509* = SslPtr
PX509_NAME* = SslPtr
PEVP_MD* = SslPtr
PBIO_METHOD* = SslPtr
PBIO* = SslPtr
BIO* = SslPtr
EVP_PKEY* = SslPtr
PRSA* = SslPtr
PASN1_UTCTIME* = SslPtr
@@ -85,6 +84,8 @@ type
des_key_schedule* = array[1..16, des_ks_struct]
{.deprecated: [PSSL: SslPtr, PSSL_CTX: SslCtx, PBIO: BIO].}
const
EVP_MAX_MD_SIZE* = 16 + 20
SSL_ERROR_NONE* = 0
@@ -282,6 +283,34 @@ proc SSL_CTX_ctrl*(ctx: PSSL_CTX, cmd: cInt, larg: int, parg: pointer): int{.
proc SSLCTXSetMode*(ctx: PSSL_CTX, mode: int): int =
result = SSL_CTX_ctrl(ctx, SSL_CTRL_MODE, mode, nil)
proc bioNew*(b: PBIO_METHOD): PBIO{.cdecl, dynlib: DLLUtilName, importc: "BIO_new".}
proc bioFreeAll*(b: PBIO){.cdecl, dynlib: DLLUtilName, importc: "BIO_free_all".}
proc bioSMem*(): PBIO_METHOD{.cdecl, dynlib: DLLUtilName, importc: "BIO_s_mem".}
proc bioCtrlPending*(b: PBIO): cInt{.cdecl, dynlib: DLLUtilName, importc: "BIO_ctrl_pending".}
proc bioRead*(b: PBIO, Buf: cstring, length: cInt): cInt{.cdecl,
dynlib: DLLUtilName, importc: "BIO_read".}
proc bioWrite*(b: PBIO, Buf: cstring, length: cInt): cInt{.cdecl,
dynlib: DLLUtilName, importc: "BIO_write".}
proc sslSetConnectState*(s: SslPtr) {.cdecl,
dynlib: DLLSSLName, importc: "SSL_set_connect_state".}
proc sslSetAcceptState*(s: SslPtr) {.cdecl,
dynlib: DLLSSLName, importc: "SSL_set_accept_state".}
proc sslRead*(ssl: SslPtr, buf: cstring, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc: "SSL_read".}
proc sslPeek*(ssl: SslPtr, buf: cstring, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc: "SSL_peek".}
proc sslWrite*(ssl: SslPtr, buf: cstring, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc: "SSL_write".}
proc sslSetBio*(ssl: SslPtr, rbio, wbio: BIO) {.cdecl,
dynlib: DLLSSLName, importc: "SSL_set_bio".}
proc sslDoHandshake*(ssl: SslPtr): cint {.cdecl,
dynlib: DLLSSLName, importc: "SSL_do_handshake".}
when true:
discard
else:
@@ -328,12 +357,7 @@ else:
proc SslConnect*(ssl: PSSL): cInt{.cdecl, dynlib: DLLSSLName, importc.}
proc SslRead*(ssl: PSSL, buf: SslPtr, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc.}
proc SslPeek*(ssl: PSSL, buf: SslPtr, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc.}
proc SslWrite*(ssl: PSSL, buf: SslPtr, num: cInt): cInt{.cdecl,
dynlib: DLLSSLName, importc.}
proc SslGetVersion*(ssl: PSSL): cstring{.cdecl, dynlib: DLLSSLName, importc.}
proc SslGetPeerCertificate*(ssl: PSSL): PX509{.cdecl, dynlib: DLLSSLName,
importc.}
@@ -393,14 +417,7 @@ else:
proc OPENSSLaddallalgorithms*(){.cdecl, dynlib: DLLUtilName, importc.}
proc CRYPTOcleanupAllExData*(){.cdecl, dynlib: DLLUtilName, importc.}
proc RandScreen*(){.cdecl, dynlib: DLLUtilName, importc.}
proc BioNew*(b: PBIO_METHOD): PBIO{.cdecl, dynlib: DLLUtilName, importc.}
proc BioFreeAll*(b: PBIO){.cdecl, dynlib: DLLUtilName, importc.}
proc BioSMem*(): PBIO_METHOD{.cdecl, dynlib: DLLUtilName, importc.}
proc BioCtrlPending*(b: PBIO): cInt{.cdecl, dynlib: DLLUtilName, importc.}
proc BioRead*(b: PBIO, Buf: cstring, length: cInt): cInt{.cdecl,
dynlib: DLLUtilName, importc.}
proc BioWrite*(b: PBIO, Buf: cstring, length: cInt): cInt{.cdecl,
dynlib: DLLUtilName, importc.}
proc d2iPKCS12bio*(b: PBIO, Pkcs12: SslPtr): SslPtr{.cdecl, dynlib: DLLUtilName,
importc.}
proc PKCS12parse*(p12: SslPtr, pass: cstring, pkey, cert, ca: var SslPtr): cint{.