Merge branch 'master' of github.com:Araq/Nimrod

This commit is contained in:
Araq
2012-07-24 01:14:53 +02:00
11 changed files with 507 additions and 145 deletions

View File

@@ -232,7 +232,9 @@ when isMainModule:
# check for new new connection & handle it
var list: seq[TSocket] = @[server.socket]
if select(list, 10) > 0:
var client = accept(server.socket)
var client: TSocket
new(client)
accept(server.socket, client)
try:
acceptRequest(server, client)
except:

View File

@@ -237,7 +237,7 @@ template `=~` *(s: string, pattern: TRegEx): expr =
## echo("syntax error")
##
when not definedInScope(matches):
var matches: array[0..maxSubPatterns-1, string]
var matches: array[0..re.maxSubPatterns-1, string]
match(s, pattern, matches)
# ------------------------- more string handling ------------------------------

View File

@@ -415,7 +415,7 @@ when hasSpawnH:
header: "<spawn.h>", final, pure.} = object
type
TSocklen* {.importc: "socklen_t", header: "<sys/socket.h>".} = cint
TSocklen* {.importc: "socklen_t", header: "<sys/socket.h>".} = cuint
TSa_Family* {.importc: "sa_family_t", header: "<sys/socket.h>".} = cint
TSockAddr* {.importc: "struct sockaddr", header: "<sys/socket.h>",

View File

@@ -11,8 +11,9 @@ import sockets, os
## This module implements an asynchronous event loop for sockets.
## It is akin to Python's asyncore module. Many modules that use sockets
## have an implementation for this module, those modules should all have a
## ``register`` function which you should use to add it to a dispatcher so
## that you can receive the events associated with that module.
## ``register`` function which you should use to add the desired objects to a
## dispatcher which you created so
## that you can receive the events associated with that module's object.
##
## Once everything is registered in a dispatcher, you need to call the ``poll``
## function in a while loop.
@@ -28,9 +29,46 @@ import sockets, os
## Most (if not all) modules that use asyncio provide a userArg which is passed
## on with the events. The type that you set userArg to must be inheriting from
## TObject!
##
## Asynchronous sockets
## ====================
##
## For most purposes you do not need to worry about the ``TDelegate`` type. The
## ``PAsyncSocket`` is what you are after. It's a reference to the ``TAsyncSocket``
## object. This object defines events which you should overwrite by your own
## procedures.
##
## For server sockets the only event you need to worry about is the ``handleAccept``
## event, in your handleAccept proc you should call ``accept`` on the server
## socket which will give you the client which is connecting. You should then
## set any events that you want to use on that client and add it to your dispatcher
## using the ``register`` procedure.
##
## An example ``handleAccept`` follows:
##
## .. code-block:: nimrod
##
## var disp: PDispatcher = newDispatcher()
## ...
## proc handleAccept(s: PAsyncSocket, arg: Pobject) {.nimcall.} =
## echo("Accepted client.")
## var client: PAsyncSocket
## new(client)
## s.accept(client)
## client.handleRead = ...
## disp.register(client)
## ...
##
## For client sockets you should only be interested in the ``handleRead`` and
## ``handleConnect`` events. The former gets called whenever the socket has
## received messages and can be read from and the latter gets called whenever
## the socket has established a connection to a server socket; from that point
## it can be safely written to.
type
TDelegate = object
TDelegate* = object
deleVal*: PObject
handleRead*: proc (h: PObject) {.nimcall.}
@@ -50,7 +88,7 @@ type
delegates: seq[PDelegate]
PAsyncSocket* = ref TAsyncSocket
TAsyncSocket = object of TObject
TAsyncSocket* = object of TObject
socket: TSocket
info: TInfo
@@ -62,6 +100,7 @@ type
handleAccept*: proc (s: PAsyncSocket, arg: PObject) {.nimcall.}
lineBuffer: TaintedString ## Temporary storage for ``recvLine``
sslNeedAccept: bool
TInfo* = enum
SockIdle, SockConnecting, SockConnected, SockListening, SockClosed
@@ -94,29 +133,60 @@ proc newAsyncSocket(userArg: PObject = nil): PAsyncSocket =
proc AsyncSocket*(domain: TDomain = AF_INET, typ: TType = SOCK_STREAM,
protocol: TProtocol = IPPROTO_TCP,
userArg: PObject = nil): PAsyncSocket =
userArg: PObject = nil, buffered = true): PAsyncSocket =
result = newAsyncSocket(userArg)
result.socket = socket(domain, typ, protocol)
result.socket = socket(domain, typ, protocol, buffered)
if result.socket == InvalidSocket: OSError()
result.socket.setBlocking(false)
proc asyncSockHandleConnect(h: PObject) =
when defined(ssl):
if PAsyncSocket(h).socket.isSSL and not
PAsyncSocket(h).socket.gotHandshake:
return
PAsyncSocket(h).info = SockConnected
PAsyncSocket(h).handleConnect(PAsyncSocket(h),
PAsyncSocket(h).userArg)
proc asyncSockHandleRead(h: PObject) =
when defined(ssl):
if PAsyncSocket(h).socket.isSSL and not
PAsyncSocket(h).socket.gotHandshake:
return
PAsyncSocket(h).handleRead(PAsyncSocket(h), PAsyncSocket(h).userArg)
when defined(ssl):
proc asyncSockDoHandshake(h: PObject) =
if PAsyncSocket(h).socket.isSSL and not
PAsyncSocket(h).socket.gotHandshake:
if PAsyncSocket(h).sslNeedAccept:
var d = ""
let ret = PAsyncSocket(h).socket.acceptAddrSSL(PAsyncSocket(h).socket, d)
assert ret != AcceptNoClient
if ret == AcceptSuccess:
PAsyncSocket(h).info = SockConnected
else:
# handshake will set socket's ``sslNoHandshake`` field.
discard PAsyncSocket(h).socket.handshake()
proc toDelegate(sock: PAsyncSocket): PDelegate =
result = newDelegate()
result.deleVal = sock
result.getSocket = (proc (h: PObject): tuple[info: TInfo, sock: TSocket] =
return (PAsyncSocket(h).info, PAsyncSocket(h).socket))
result.handleConnect = (proc (h: PObject) =
PAsyncSocket(h).info = SockConnected
PAsyncSocket(h).handleConnect(PAsyncSocket(h),
PAsyncSocket(h).userArg))
result.handleRead = (proc (h: PObject) =
PAsyncSocket(h).handleRead(PAsyncSocket(h),
PAsyncSocket(h).userArg))
result.handleConnect = asyncSockHandleConnect
result.handleRead = asyncSockHandleRead
result.handleAccept = (proc (h: PObject) =
PAsyncSocket(h).handleAccept(PAsyncSocket(h),
PAsyncSocket(h).userArg))
when defined(ssl):
result.task = asyncSockDoHandshake
proc connect*(sock: PAsyncSocket, name: string, port = TPort(0),
af: TDomain = AF_INET) =
## Begins connecting ``sock`` to ``name``:``port``.
@@ -137,23 +207,64 @@ proc listen*(sock: PAsyncSocket) =
sock.socket.listen()
sock.info = SockListening
proc acceptAddr*(server: PAsyncSocket): tuple[sock: PAsyncSocket,
address: string] =
## Equivalent to ``sockets.acceptAddr``.
var (client, a) = server.socket.acceptAddr()
if client == InvalidSocket: OSError()
client.setBlocking(false) # TODO: Needs to be tested.
var aSock: PAsyncSocket = newAsyncSocket()
aSock.socket = client
aSock.info = SockConnected
return (aSock, a)
proc acceptAddr*(server: PAsyncSocket, client: var PAsyncSocket,
address: var string) =
## Equivalent to ``sockets.acceptAddr``. This procedure should be called in
## a ``handleAccept`` event handler **only** once.
##
## **Note**: ``client`` needs to be initialised.
assert(client != nil)
var c: TSocket
new(c)
when defined(ssl):
if server.socket.isSSL:
var ret = server.socket.acceptAddrSSL(c, address)
# The following shouldn't happen because when this function is called
# it is guaranteed that there is a client waiting.
# (This should be called in handleAccept)
assert(ret != AcceptNoClient)
if ret == AcceptNoHandshake:
client.sslNeedAccept = true
else:
client.sslNeedAccept = false
client.info = SockConnected
else:
server.socket.acceptAddr(c, address)
client.sslNeedAccept = false
client.info = SockConnected
else:
server.socket.acceptAddr(c, address)
client.sslNeedAccept = false
client.info = SockConnected
proc accept*(server: PAsyncSocket): PAsyncSocket =
if c == InvalidSocket: OSError()
c.setBlocking(false) # TODO: Needs to be tested.
client.socket = c
client.lineBuffer = ""
proc accept*(server: PAsyncSocket, client: var PAsyncSocket) =
## Equivalent to ``sockets.accept``.
var (client, a) = server.acceptAddr()
return client
var dummyAddr = ""
server.acceptAddr(client, dummyAddr)
proc acceptAddr*(server: PAsyncSocket): tuple[sock: PAsyncSocket,
address: string] {.deprecated.} =
## Equivalent to ``sockets.acceptAddr``.
##
## **Warning**: This is deprecated in favour of the above.
var client = newAsyncSocket()
var address: string = ""
acceptAddr(server, client, address)
return (client, address)
proc accept*(server: PAsyncSocket): PAsyncSocket {.deprecated.} =
## Equivalent to ``sockets.accept``.
##
## **Warning**: This is deprecated.
new(result)
var address = ""
server.acceptAddr(result, address)
proc newDispatcher*(): PDispatcher =
new(result)
@@ -210,8 +321,9 @@ proc recvLine*(s: PAsyncSocket, line: var TaintedString): bool =
if s.lineBuffer.len > 0:
string(line).add(s.lineBuffer.string)
setLen(s.lineBuffer.string, 0)
string(line).add(dataReceived.string)
if string(line) == "":
line = "\c\L".TaintedString
result = true
of RecvPartialLine:
string(s.lineBuffer).add(dataReceived.string)
@@ -263,7 +375,7 @@ proc poll*(d: PDispatcher, timeout: int = 500): bool =
if readSocks.len() == 0 and writeSocks.len() == 0:
return False
if select(readSocks, writeSocks, timeout) != 0:
for i in 0..len(d.delegates)-1:
if i > len(d.delegates)-1: break # One delegate might've been removed.
@@ -294,7 +406,11 @@ proc poll*(d: PDispatcher, timeout: int = 500): bool =
# Execute tasks
for i in items(d.delegates):
i.task(i.deleVal)
proc len*(disp: PDispatcher): int =
## Retrieves the amount of delegates in ``disp``.
return disp.delegates.len
when isMainModule:
type
PIntType = ref TIntType
@@ -320,7 +436,10 @@ when isMainModule:
proc testAccept(s: PAsyncSocket, arg: PObject) =
echo("Accepting client! " & $PMyArg(arg).val)
var (client, address) = s.acceptAddr()
var client: PAsyncSocket
new(client)
var address = ""
s.acceptAddr(client, address)
echo("Accepted ", address)
client.handleRead = testRead
var userArg: PIntType

View File

@@ -241,7 +241,10 @@ proc port*(s: var TServer): TPort =
proc next*(s: var TServer) =
## proceed to the first/next request.
let (client, ip) = acceptAddr(s.socket)
var client: TSocket
new(client)
var ip: string
acceptAddr(s.socket, client, ip)
s.client = client
s.ip = ip
s.headers = newStringTable(modeCaseInsensitive)

View File

@@ -865,7 +865,7 @@ template `=~`*(s: string, pattern: TPeg): bool =
## echo("syntax error")
##
when not definedInScope(matches):
var matches: array[0..maxSubpatterns-1, string]
var matches: array[0..pegs.maxSubpatterns-1, string]
match(s, pattern, matches)
# ------------------------- more string handling ------------------------------

View File

@@ -86,6 +86,7 @@ proc open*(s: var TScgiState, port = TPort(4000), address = "127.0.0.1") =
s.input = newString(s.buflen) # will be reused
s.server = socket()
new(s.client) # Initialise s.client for `next`
if s.server == InvalidSocket: scgiError("could not open socket")
#s.server.connect(connectionName, port)
bindAddr(s.server, port, address)
@@ -101,7 +102,7 @@ proc next*(s: var TScgistate, timeout: int = -1): bool =
## Returns `True` if a new request has been processed.
var rsocks = @[s.server]
if select(rsocks, timeout) == 1 and rsocks.len == 0:
s.client = accept(s.server)
accept(s.server, s.client)
var L = 0
while true:
var d = s.client.recvChar()
@@ -159,7 +160,7 @@ proc getSocket(h: PObject): tuple[info: TInfo, sock: TSocket] =
proc handleAccept(h: PObject) =
var s = PAsyncScgiState(h)
s.client = accept(s.server)
accept(s.server, s.client)
var L = 0
while true:
var d = s.client.recvChar()

View File

@@ -37,10 +37,10 @@ when defined(ssl):
TSSLProtVersion* = enum
protSSLv2, protSSLv3, protTLSv1, protSSLv23
TSSLOptions* = object
verifyMode*: TSSLCVerifyMode
certFile*, keyFile*: string
protVer*: TSSLprotVersion
PSSLContext* = distinct PSSLCTX
TSSLAcceptResult* = enum
AcceptNoClient = 0, AcceptNoHandshake, AcceptSuccess
type
TSocketImpl = object ## socket type
@@ -55,8 +55,8 @@ type
case isSsl: bool
of true:
sslHandle: PSSL
sslContext: PSSLCTX
wrapOptions: TSSLOptions
sslContext: PSSLContext
sslNoHandshake: bool # True if needs handshake.
of false: nil
TSocket* = ref TSocketImpl
@@ -211,22 +211,49 @@ when defined(ssl):
raise newException(ESSL, $errStr)
# http://simplestcodings.blogspot.co.uk/2010/08/secure-server-client-using-openssl-in-c.html
proc loadCertificates(socket: var TSocket, certFile, keyFile: string) =
proc loadCertificates(ctx: PSSL_CTX, certFile, keyFile: string) =
if certFile != "":
if SSLCTXUseCertificateFile(socket.sslContext, certFile,
SSL_FILETYPE_PEM) != 1:
var ret = SSLCTXUseCertificateFile(ctx, certFile,
SSL_FILETYPE_PEM)
if ret != 1:
SSLError()
# TODO: Password? www.rtfm.com/openssl-examples/part1.pdf
if keyFile != "":
if SSL_CTX_use_PrivateKey_file(socket.sslContext, keyFile,
if SSL_CTX_use_PrivateKey_file(ctx, keyFile,
SSL_FILETYPE_PEM) != 1:
SSLError()
if SSL_CTX_check_private_key(socket.sslContext) != 1:
if SSL_CTX_check_private_key(ctx) != 1:
SSLError("Verification of private key file failed.")
proc wrapSocket*(socket: var TSocket, protVersion = ProtSSLv23,
verifyMode = CVerifyPeer,
certFile = "", keyFile = "") =
proc newContext*(protVersion = ProtSSLv23, verifyMode = CVerifyPeer,
certFile = "", keyFile = ""): PSSLContext =
var newCTX: PSSL_CTX
case protVersion
of protSSLv23:
newCTX = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support.
of protSSLv2:
newCTX = SSL_CTX_new(SSLv2_method())
of protSSLv3:
newCTX = SSL_CTX_new(SSLv3_method())
of protTLSv1:
newCTX = SSL_CTX_new(TLSv1_method())
if newCTX.SSLCTXSetCipherList("ALL") != 1:
SSLError()
case verifyMode
of CVerifyPeer:
newCTX.SSLCTXSetVerify(SSLVerifyPeer, nil)
of CVerifyNone:
newCTX.SSLCTXSetVerify(SSLVerifyNone, nil)
if newCTX == nil:
SSLError()
newCTX.loadCertificates(certFile, keyFile)
return PSSLContext(newCTX)
proc wrapSocket*(ctx: PSSLContext, socket: TSocket) =
## Creates a SSL context for ``socket`` and wraps the socket in it.
##
## Protocol version specifies the protocol to use. SSLv2, SSLv3, TLSv1 are
@@ -247,44 +274,15 @@ when defined(ssl):
## most likely very prone to security vulnerabilities.
socket.isSSL = true
socket.wrapOptions.verifyMode = verifyMode
socket.wrapOptions.certFile = certFile
socket.wrapOptions.keyFile = keyFile
socket.wrapOptions.protVer = protVersion
case protVersion
of protSSLv23:
socket.sslContext = SSL_CTX_new(SSLv23_method()) # SSlv2,3 and TLS1 support.
of protSSLv2:
socket.sslContext = SSL_CTX_new(SSLv2_method())
of protSSLv3:
socket.sslContext = SSL_CTX_new(SSLv3_method())
of protTLSv1:
socket.sslContext = SSL_CTX_new(TLSv1_method())
if socket.sslContext.SSLCTXSetCipherList("ALL") != 1:
SSLError()
case verifyMode
of CVerifyPeer:
socket.sslContext.SSLCTXSetVerify(SSLVerifyPeer, nil)
of CVerifyNone:
socket.sslContext.SSLCTXSetVerify(SSLVerifyNone, nil)
if socket.sslContext == nil:
SSLError()
socket.loadCertificates(certFile, keyFile)
socket.sslHandle = SSLNew(socket.sslContext)
socket.sslContext = ctx
socket.sslHandle = SSLNew(PSSLCTX(socket.sslContext))
socket.sslNoHandshake = false
if socket.sslHandle == nil:
SSLError()
if SSLSetFd(socket.sslHandle, socket.fd) != 1:
SSLError()
proc wrapSocket*(socket: var TSocket, wo: TSSLOptions) =
## A variant of the above with a options object.
wrapSocket(socket, wo.protVer, wo.verifyMode, wo.certFile, wo.keyFile)
proc listen*(socket: TSocket, backlog = SOMAXCONN) =
## Marks ``socket`` as accepting connections.
## ``Backlog`` specifies the maximum length of the
@@ -352,7 +350,7 @@ proc bindAddr*(socket: TSocket, port = TPort(0), address = "") =
hints.ai_socktype = toInt(SOCK_STREAM)
hints.ai_protocol = toInt(IPPROTO_TCP)
gaiNim(address, port, hints, aiList)
if bindSocket(socket.fd, aiList.ai_addr, aiList.ai_addrLen.cint) < 0'i32:
if bindSocket(socket.fd, aiList.ai_addr, aiList.ai_addrLen.cuint) < 0'i32:
OSError()
when false:
@@ -386,17 +384,8 @@ proc getSockName*(socket: TSocket): TPort =
proc selectWrite*(writefds: var seq[TSocket], timeout = 500): int
proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] =
## Blocks until a connection is being made from a client. When a connection
## is made sets ``client`` to the client socket and ``address`` to the address
## of the connecting client.
## If ``server`` is non-blocking then this function returns immediately, and
## if there are no connections queued the returned socket will be
## ``InvalidSocket``.
## This function will raise EOS if an error occurs.
##
## **Warning:** This function might block even if socket is non-blocking
## when using SSL.
template acceptAddrPlain(noClientRet, successRet: expr, sslImplementation: stmt): stmt =
assert(client != nil)
var sockAddress: Tsockaddr_in
var addrLen = sizeof(sockAddress).TSockLen
var sock = accept(server.fd, cast[ptr TSockAddr](addr(sockAddress)),
@@ -407,20 +396,56 @@ proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] =
when defined(windows):
var err = WSAGetLastError()
if err == WSAEINPROGRESS:
return (InvalidSocket, "")
client = InvalidSocket
address = ""
when noClientRet.int == -1:
return
else:
return noClientRet
else: OSError()
else:
if errno == EAGAIN or errno == EWOULDBLOCK:
return (InvalidSocket, "")
client = InvalidSocket
address = ""
when noClientRet.int == -1:
return
else:
return noClientRet
else: OSError()
else:
else:
client.fd = sock
client.isBuffered = server.isBuffered
sslImplementation
# Client socket is set above.
address = $inet_ntoa(sockAddress.sin_addr)
when successRet.int == -1:
return
else:
return successRet
proc acceptAddr*(server: TSocket, client: var TSocket, address: var string) =
## Blocks until a connection is being made from a client. When a connection
## is made sets ``client`` to the client socket and ``address`` to the address
## of the connecting client.
## If ``server`` is non-blocking then this function returns immediately, and
## if there are no connections queued the returned socket will be
## ``InvalidSocket``.
## This function will raise EOS if an error occurs.
##
## The resulting client will inherit any properties of the server socket. For
## example: whether the socket is buffered or not.
##
## **Note**: ``client`` must be initialised, this function makes no effort to
## initialise the ``client`` variable.
##
## **Warning:** When using SSL with non-blocking sockets, it is best to use
## the acceptAddrAsync procedure as this procedure will most likely block.
acceptAddrPlain(-1, -1):
when defined(ssl):
if server.isSSL:
# We must wrap the client sock in a ssl context.
var client = newTSocket(sock, server.isBuffered)
let wo = server.wrapOptions
wrapSocket(client, wo.protVer, wo.verifyMode,
wo.certFile, wo.keyFile)
server.sslContext.wrapSocket(client)
let ret = SSLAccept(client.sslHandle)
while ret <= 0:
let err = SSLGetError(client.sslHandle, ret)
@@ -428,26 +453,93 @@ proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] =
case err
of SSL_ERROR_ZERO_RETURN:
SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT:
SSLError("The operation did not complete. Perhaps you should use connectAsync?")
of SSL_ERROR_WANT_ACCEPT:
var sss: seq[TSocket] = @[client]
discard selectWrite(sss, 1500)
continue
of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
SSLError("Please use acceptAsync instead of accept.")
of SSL_ERROR_WANT_X509_LOOKUP:
SSLError("Function for x509 lookup has been called.")
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
SSLError()
else:
SSLError("Unknown error")
return (client, $inet_ntoa(sockAddress.sin_addr))
return (newTSocket(sock, server.isBuffered), $inet_ntoa(sockAddress.sin_addr))
proc accept*(server: TSocket): TSocket =
proc setBlocking*(s: TSocket, blocking: bool)
when defined(ssl):
proc acceptAddrSSL*(server: TSocket, client: var TSocket,
address: var string): TSSLAcceptResult =
## This procedure should only be used for non-blocking **SSL** sockets.
## It will immediatelly return with one of the following values:
##
## ``AcceptSuccess`` will be returned when a client has been successfully
## accepted and the handshake has been successfully performed between
## ``server`` and the newly connected client.
##
## ``AcceptNoHandshake`` will be returned when a client has been accepted
## but no handshake could be performed. This can happen when the client
## connects but does not yet initiate a handshake. In this case
## ``acceptAddrSSL`` should be called again with the same parameters.
##
## ``AcceptNoClient`` will be returned when no client is currently attempting
## to connect.
template doHandshake(): stmt =
when defined(ssl):
if server.isSSL:
client.setBlocking(false)
# We must wrap the client sock in a ssl context.
if not client.isSSL or client.sslHandle == nil:
server.sslContext.wrapSocket(client)
let ret = SSLAccept(client.sslHandle)
while ret <= 0:
let err = SSLGetError(client.sslHandle, ret)
if err != SSL_ERROR_WANT_ACCEPT:
case err
of SSL_ERROR_ZERO_RETURN:
SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
client.sslNoHandshake = true
return AcceptNoHandshake
of SSL_ERROR_WANT_X509_LOOKUP:
SSLError("Function for x509 lookup has been called.")
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
SSLError()
else:
SSLError("Unknown error")
client.sslNoHandshake = false
if client.isSSL and client.sslNoHandshake:
doHandshake()
return AcceptSuccess
else:
acceptAddrPlain(AcceptNoClient, AcceptSuccess):
doHandshake()
proc accept*(server: TSocket, client: var TSocket) =
## Equivalent to ``acceptAddr`` but doesn't return the address, only the
## socket.
let (client, a) = acceptAddr(server)
return client
##
## **Note**: ``client`` must be initialised, this function makes no effort to
## initialise the ``client`` variable.
var addrDummy = ""
acceptAddr(server, client, addrDummy)
proc acceptAddr*(server: TSocket): tuple[client: TSocket, address: string] {.deprecated.} =
## Slightly different version of ``acceptAddr``.
##
## **Warning**: This function is now deprecated, you shouldn't use it!
var client: TSocket
new(client)
var address = ""
acceptAddr(server, client, address)
return (client, address)
proc accept*(server: TSocket): TSocket {.deprecated.} =
## **Warning**: This function is now deprecated, you shouldn't use it!
new(result)
var address = ""
acceptAddr(server, result, address)
proc close*(socket: TSocket) =
## closes a socket.
@@ -459,8 +551,6 @@ proc close*(socket: TSocket) =
when defined(ssl):
if socket.isSSL:
discard SSLShutdown(socket.sslHandle)
SSLCTXFree(socket.sslContext)
proc getServByName*(name, proto: string): TServent =
## well-known getservbyname proc.
@@ -492,11 +582,11 @@ proc getHostByAddr*(ip: string): THostEnt =
myaddr.s_addr = inet_addr(ip)
when defined(windows):
var s = winlean.gethostbyaddr(addr(myaddr), sizeof(myaddr).cint,
var s = winlean.gethostbyaddr(addr(myaddr), sizeof(myaddr).cuint,
cint(sockets.AF_INET))
if s == nil: OSError()
else:
var s = posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).cint,
var s = posix.gethostbyaddr(addr(myaddr), sizeof(myaddr).cuint,
cint(posix.AF_INET))
if s == nil:
raise newException(EOS, $hStrError(h_errno))
@@ -539,7 +629,7 @@ proc getHostByName*(name: string): THostEnt =
proc getSockOptInt*(socket: TSocket, level, optname: int): int =
## getsockopt for integer options.
var res: cint
var size = sizeof(res).cint
var size = sizeof(res).cuint
if getsockopt(socket.fd, cint(level), cint(optname),
addr(res), addr(size)) < 0'i32:
OSError()
@@ -549,7 +639,7 @@ proc setSockOptInt*(socket: TSocket, level, optname, optval: int) =
## setsockopt for integer options.
var value = cint(optval)
if setsockopt(socket.fd, cint(level), cint(optname), addr(value),
sizeof(value).cint) < 0'i32:
sizeof(value).cuint) < 0'i32:
OSError()
proc connect*(socket: TSocket, name: string, port = TPort(0),
@@ -558,7 +648,8 @@ proc connect*(socket: TSocket, name: string, port = TPort(0),
## host name. If ``name`` is a host name, this function will try each IP
## of that host name. ``htons`` is already performed on ``port`` so you must
## not do it.
##
## If ``socket`` is an SSL socket a handshake will be automatically performed.
var hints: TAddrInfo
var aiList: ptr TAddrInfo = nil
hints.ai_family = toInt(af)
@@ -570,7 +661,7 @@ proc connect*(socket: TSocket, name: string, port = TPort(0),
var success = false
var it = aiList
while it != nil:
if connect(socket.fd, it.ai_addr, it.ai_addrlen.cint) == 0'i32:
if connect(socket.fd, it.ai_addr, it.ai_addrlen.cuint) == 0'i32:
success = true
break
it = it.ai_next
@@ -614,6 +705,13 @@ proc connect*(socket: TSocket, name: string, port = TPort(0),
proc connectAsync*(socket: TSocket, name: string, port = TPort(0),
af: TDomain = AF_INET) =
## A variant of ``connect`` for non-blocking sockets.
##
## This procedure will immediatelly return, it will not block until a connection
## is made. It is up to the caller to make sure the connections has been established
## by checking (using ``select``) whether the socket is writeable.
##
## **Note**: For SSL sockets, the ``handshake`` procedure must be called
## whenever the socket successfully connects to a server.
var hints: TAddrInfo
var aiList: ptr TAddrInfo = nil
hints.ai_family = toInt(af)
@@ -624,7 +722,7 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0),
var success = false
var it = aiList
while it != nil:
var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen.cint)
var ret = connect(socket.fd, it.ai_addr, it.ai_addrlen.cuint)
if ret == 0'i32:
success = true
break
@@ -634,19 +732,32 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0),
var err = WSAGetLastError()
# Windows EINTR doesn't behave same as POSIX.
if err == WSAEWOULDBLOCK:
freeaddrinfo(aiList)
return
success = true
break
else:
if errno == EINTR or errno == EINPROGRESS:
freeaddrinfo(aiList)
return
success = true
break
it = it.ai_next
freeaddrinfo(aiList)
if not success: OSError()
when defined(ssl):
if socket.isSSL:
socket.sslNoHandshake = true
when defined(ssl):
proc handshake*(socket: TSocket): bool =
## This proc needs to be called on a socket after it connects. This is
## only applicable when using ``connectAsync``.
## This proc performs the SSL handshake.
##
## Returns ``False`` whenever the socket is not yet ready for a handshake,
## ``True`` whenever handshake completed successfully.
##
## A ESSL error is raised on any other errors.
result = true
if socket.isSSL:
var ret = SSLConnect(socket.sslHandle)
if ret <= 0:
@@ -654,17 +765,28 @@ proc connectAsync*(socket: TSocket, name: string, port = TPort(0),
case errret
of SSL_ERROR_ZERO_RETURN:
SSLError("TLS/SSL connection failed to initiate, socket closed prematurely.")
of SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE,
SSL_ERROR_WANT_ACCEPT:
SSLError("Unexpected error occured.") # This should just not happen.
of SSL_ERROR_WANT_CONNECT:
return
of SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT,
SSL_ERROR_WANT_READ, SSL_ERROR_WANT_WRITE:
return false
of SSL_ERROR_WANT_X509_LOOKUP:
SSLError("Function for x509 lookup has been called.")
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
SSLError()
else:
SSLError("Unknown Error")
socket.sslNoHandshake = false
else:
SSLError("Socket is not an SSL socket.")
proc gotHandshake*(socket: TSocket): bool =
## Determines whether a handshake has occurred between a client - ``socket``
## and the server that ``socket`` is connected to.
##
## Throws ESSL if ``socket`` is not an SSL socket.
if socket.isSSL:
return not socket.sslNoHandshake
else:
SSLError("Socket is not an SSL socket.")
proc timeValFromMilliseconds(timeout = 500): TTimeVal =
if timeout != -1:
@@ -694,6 +816,21 @@ proc pruneSocketSet(s: var seq[TSocket], fd: var TFdSet) =
inc(i)
setLen(s, L)
proc checkBuffer(readfds: var seq[TSocket]): int =
## Checks the buffer of each socket in ``readfds`` to see whether there is data.
## Removes the sockets from ``readfds`` and returns the count of removed sockets.
var res: seq[TSocket] = @[]
result = 0
for s in readfds:
if s.isBuffered:
if s.bufLen <= 0 or s.currPos == s.bufLen:
res.add(s)
else:
inc(result)
else:
res.add(s)
readfds = res
proc select*(readfds, writefds, exceptfds: var seq[TSocket],
timeout = 500): int =
## Traditional select function. This function will return the number of
@@ -702,6 +839,9 @@ proc select*(readfds, writefds, exceptfds: var seq[TSocket],
##
## You can determine whether a socket is ready by checking if it's still
## in one of the TSocket sequences.
let buffersFilled = checkBuffer(readfds)
if buffersFilled > 0:
return buffersFilled
var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout)
@@ -722,6 +862,9 @@ proc select*(readfds, writefds, exceptfds: var seq[TSocket],
proc select*(readfds, writefds: var seq[TSocket],
timeout = 500): int =
let buffersFilled = checkBuffer(readfds)
if buffersFilled > 0:
return buffersFilled
var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout)
var rd, wr: TFdSet
@@ -753,6 +896,9 @@ proc selectWrite*(writefds: var seq[TSocket],
pruneSocketSet(writefds, (wr))
proc select*(readfds: var seq[TSocket], timeout = 500): int =
let buffersFilled = checkBuffer(readfds)
if buffersFilled > 0:
return buffersFilled
var tv {.noInit.}: TTimeVal = timeValFromMilliseconds(timeout)
var rd: TFdSet
@@ -878,7 +1024,7 @@ proc peekChar(socket: TSocket, c: var char): int =
proc recvLine*(socket: TSocket, line: var TaintedString): bool =
## retrieves a line from ``socket``. If a full line is received ``\r\L`` is not
## added to ``line``, however if solely ``\r\L`` is received then ``data``
## added to ``line``, however if solely ``\r\L`` is received then ``line``
## will be set to it.
##
## ``True`` is returned if data is available. ``False`` usually suggests an
@@ -945,7 +1091,7 @@ proc recvLineAsync*(socket: TSocket, line: var TaintedString): TRecvLineResult =
## The values of the returned enum should be pretty self explanatory:
## If a full line has been retrieved; ``RecvFullLine`` is returned.
## If some data has been retrieved; ``RecvPartialLine`` is returned.
## If the socket has been disconnected; ``RecvDisconncted`` is returned.
## If the socket has been disconnected; ``RecvDisconnected`` is returned.
## If call to ``recv`` failed; ``RecvFail`` is returned.
setLen(line.string, 0)
while true:
@@ -1162,6 +1308,8 @@ proc connect*(socket: TSocket, timeout: int, name: string, port = TPort(0),
if selectWrite(s, timeout) != 1:
raise newException(ETimeout, "Call to connect() timed out.")
proc isSSL*(socket: TSocket): bool = return socket.isSSL
when defined(Windows):
var wsa: TWSADATA
if WSAStartup(0x0101'i16, wsa) != 0: OSError()

View File

@@ -407,7 +407,7 @@ type
ai_addr*: ptr TSockAddr ## Socket address of socket.
ai_next*: ptr TAddrInfo ## Pointer to next in list.
Tsocklen* = cint
Tsocklen* = cuint
var
SOMAXCONN* {.importc, header: "Winsock2.h".}: cint
@@ -418,7 +418,7 @@ proc getservbyname*(name, proto: cstring): ptr TServent {.
proc getservbyport*(port: cint, proto: cstring): ptr TServent {.
stdcall, importc: "getservbyport", dynlib: ws2dll.}
proc gethostbyaddr*(ip: ptr TInAddr, len: cint, theType: cint): ptr THostEnt {.
proc gethostbyaddr*(ip: ptr TInAddr, len: cuint, theType: cint): ptr THostEnt {.
stdcall, importc: "gethostbyaddr", dynlib: ws2dll.}
proc gethostbyname*(name: cstring): ptr THostEnt {.
@@ -430,20 +430,20 @@ proc socket*(af, typ, protocol: cint): TWinSocket {.
proc closesocket*(s: TWinSocket): cint {.
stdcall, importc: "closesocket", dynlib: ws2dll.}
proc accept*(s: TWinSocket, a: ptr TSockAddr, addrlen: ptr cint): TWinSocket {.
proc accept*(s: TWinSocket, a: ptr TSockAddr, addrlen: ptr cuint): TWinSocket {.
stdcall, importc: "accept", dynlib: ws2dll.}
proc bindSocket*(s: TWinSocket, name: ptr TSockAddr, namelen: cint): cint {.
proc bindSocket*(s: TWinSocket, name: ptr TSockAddr, namelen: cuint): cint {.
stdcall, importc: "bind", dynlib: ws2dll.}
proc connect*(s: TWinSocket, name: ptr TSockAddr, namelen: cint): cint {.
proc connect*(s: TWinSocket, name: ptr TSockAddr, namelen: cuint): cint {.
stdcall, importc: "connect", dynlib: ws2dll.}
proc getsockname*(s: TWinSocket, name: ptr TSockAddr,
namelen: ptr cint): cint {.
namelen: ptr cuint): cint {.
stdcall, importc: "getsockname", dynlib: ws2dll.}
proc getsockopt*(s: TWinSocket, level, optname: cint, optval: pointer,
optlen: ptr cint): cint {.
optlen: ptr cuint): cint {.
stdcall, importc: "getsockopt", dynlib: ws2dll.}
proc setsockopt*(s: TWinSocket, level, optname: cint, optval: pointer,
optlen: cint): cint {.
optlen: cuint): cint {.
stdcall, importc: "setsockopt", dynlib: ws2dll.}
proc listen*(s: TWinSocket, backlog: cint): cint {.

View File

@@ -215,7 +215,7 @@ proc SSL_get_verify_result*(ssl: PSSL): int{.cdecl,
proc SSL_CTX_set_cipher_list*(s: PSSLCTX, ciphers: cstring): cint{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_CTX_use_certificate_file*(ctx: PSSL_CTX, filename: cstring, typ: cInt): cInt{.
cdecl, dynlib: DLLSSLName, importc.}
stdcall, dynlib: DLLSSLName, importc.}
proc SSL_CTX_use_PrivateKey_file*(ctx: PSSL_CTX,
filename: cstring, typ: cInt): cInt{.cdecl, dynlib: DLLSSLName, importc.}
proc SSL_CTX_check_private_key*(ctx: PSSL_CTX): cInt{.cdecl, dynlib: DLLSSLName,

89
tests/run/tasynciossl.nim Normal file
View File

@@ -0,0 +1,89 @@
discard """
file: "tasynciossl.nim"
cmd: "nimrod cc --hints:on --define:ssl $# $#"
output: "20000"
"""
import sockets, asyncio, strutils, times
var disp = newDispatcher()
var msgCount = 0
when defined(ssl):
var ctx = newContext(verifyMode = CVerifyNone,
certFile = "mycert.pem", keyFile = "mycert.pem")
var ctx1 = newContext(verifyMode = CVerifyNone)
const
swarmSize = 50
messagesToSend = 100
proc swarmConnect(s: PAsyncSocket, arg: PObject) {.nimcall.} =
#echo("Connected")
for i in 1..messagesToSend:
s.send("Message " & $i & "\c\L")
s.close()
proc serverRead(s: PAsyncSocket, arg: PObject) {.nimcall.} =
var line = ""
assert s.recvLine(line)
if line != "":
#echo(line)
if line.startsWith("Message "):
msgCount.inc()
else:
assert(false)
else:
s.close()
proc serverAccept(s: PAsyncSocket, arg: Pobject) {.nimcall.} =
var client: PAsyncSocket
new(client)
s.accept(client)
client.handleRead = serverRead
disp.register(client)
proc launchSwarm(disp: var PDispatcher, port: TPort, count: int,
buffered = true, useSSL = false) =
for i in 1..count:
var client = AsyncSocket()
when defined(ssl):
if useSSL:
ctx1.wrapSocket(client)
client.handleConnect = swarmConnect
disp.register(client)
client.connect("localhost", port)
proc createSwarm(port: TPort, buffered = true, useSSL = false) =
var server = AsyncSocket()
when defined(ssl):
if useSSL:
ctx.wrapSocket(server)
server.handleAccept = serverAccept
disp.register(server)
server.bindAddr(port)
server.listen()
disp.launchSwarm(port, swarmSize, buffered, useSSL)
when defined(ssl):
const serverCount = 4
else:
const serverCount = 2
createSwarm(TPort(10235))
createSwarm(TPort(10236), false)
when defined(ssl):
createSwarm(TPort(10237), true, true)
createSwarm(TPort(10238), false, true)
var startTime = epochTime()
while true:
if epochTime() - startTime >= 300.0:
break
if not disp.poll(): break
if disp.len == serverCount:
break
assert msgCount == (swarmSize * messagesToSend) * serverCount
echo(msgCount)