Add 'hostname' param to wrapConnectedSocket

This commit is contained in:
Ruslan Mustakov
2017-05-04 16:27:08 +07:00
parent 27b571dd95
commit e0059287bb
3 changed files with 195 additions and 180 deletions

View File

@@ -647,9 +647,12 @@ when defineSsl:
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket,
handshake: SslHandshakeType) =
handshake: SslHandshakeType,
hostname: string = nil) =
## Wraps a connected socket in an SSL context. This function effectively
## turns ``socket`` into an SSL socket.
## ``hostname`` should be specified so that the client knows which hostname
## the server certificate should be validated against.
##
## This should be called on a connected socket, and will perform
## an SSL handshake immediately.
@@ -660,6 +663,10 @@ when defineSsl:
case handshake
of handshakeAsClient:
if not hostname.isNil and not isIpAddress(hostname):
# Set the SNI address for this connection. This call can fail if
# we're not using TLSv1+.
discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
sslSetConnectState(socket.sslHandle)
of handshakeAsServer:
sslSetAcceptState(socket.sslHandle)

View File

@@ -512,7 +512,7 @@ proc request*(url: string, httpMethod: string, extraHeaders = "",
raise newException(HttpRequestError,
"The proxy server rejected a CONNECT request, " &
"so a secure connection could not be established.")
sslContext.wrapConnectedSocket(s, handshakeAsClient)
sslContext.wrapConnectedSocket(s, handshakeAsClient, hostUrl.hostname)
else:
raise newException(HttpRequestError, "SSL support not available. Cannot connect via proxy over SSL")
else:
@@ -1060,7 +1060,8 @@ proc newConnection(client: HttpClient | AsyncHttpClient,
when defined(ssl):
if isSsl:
try:
client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient)
client.sslContext.wrapConnectedSocket(
client.socket, handshakeAsClient, url.hostname)
except:
client.socket.close()
raise getCurrentException()
@@ -1102,7 +1103,8 @@ proc requestAux(client: HttpClient | AsyncHttpClient, url: string,
raise newException(HttpRequestError,
"The proxy server rejected a CONNECT request, " &
"so a secure connection could not be established.")
client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient)
client.sslContext.wrapConnectedSocket(
client.socket, handshakeAsClient, requestUrl.hostname)
client.proxy = nil
else:
raise newException(HttpRequestError,

View File

@@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
raiseOSError(osLastError())
result = newSocket(fd, domain, sockType, protocol, buffered)
proc parseIPv4Address(address_str: string): IpAddress =
## Parses IPv4 adresses
## Raises EInvalidValue on errors
var
byteCount = 0
currentByte:uint16 = 0
seperatorValid = false
result.family = IpAddressFamily.IPv4
for i in 0 .. high(address_str):
if address_str[i] in strutils.Digits: # Character is a number
currentByte = currentByte * 10 +
cast[uint16](ord(address_str[i]) - ord('0'))
if currentByte > 255'u16:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
seperatorValid = true
elif address_str[i] == '.': # IPv4 address separator
if not seperatorValid or byteCount >= 3:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v4[byteCount] = cast[uint8](currentByte)
currentByte = 0
byteCount.inc
seperatorValid = false
else:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if byteCount != 3 or not seperatorValid:
raise newException(ValueError, "Invalid IP Address")
result.address_v4[byteCount] = cast[uint8](currentByte)
proc parseIPv6Address(address_str: string): IpAddress =
## Parses IPv6 adresses
## Raises EInvalidValue on errors
result.family = IpAddressFamily.IPv6
if address_str.len < 2:
raise newException(ValueError, "Invalid IP Address")
var
groupCount = 0
currentGroupStart = 0
currentShort:uint32 = 0
seperatorValid = true
dualColonGroup = -1
lastWasColon = false
v4StartPos = -1
byteCount = 0
for i,c in address_str:
if c == ':':
if not seperatorValid:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid seperator")
if lastWasColon:
if dualColonGroup != -1:
raise newException(ValueError,
"Invalid IP Address. Address contains more than one \"::\" seperator")
dualColonGroup = groupCount
seperatorValid = false
elif i != 0 and i != high(address_str):
if groupCount >= 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
currentShort = 0
groupCount.inc()
if dualColonGroup != -1: seperatorValid = false
elif i == 0: # only valid if address starts with ::
if address_str[1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not start with \":\"")
else: # i == high(address_str) - only valid if address ends with ::
if address_str[high(address_str)-1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not end with \":\"")
lastWasColon = true
currentGroupStart = i + 1
elif c == '.': # Switch to parse IPv4 mode
if i < 3 or not seperatorValid or groupCount >= 7:
raise newException(ValueError, "Invalid IP Address")
v4StartPos = currentGroupStart
currentShort = 0
seperatorValid = false
break
elif c in strutils.HexDigits:
if c in strutils.Digits: # Normal digit
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
elif c >= 'a' and c <= 'f': # Lower case hex
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
else: # Upper case hex
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
if currentShort > 65535'u32:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
lastWasColon = false
seperatorValid = true
else:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
if seperatorValid: # Copy remaining data
if groupCount >= 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
groupCount.inc()
else: # Must parse IPv4 address
for i,c in address_str[v4StartPos..high(address_str)]:
if c in strutils.Digits: # Character is a number
currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
if currentShort > 255'u32:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
seperatorValid = true
elif c == '.': # IPv4 address separator
if not seperatorValid or byteCount >= 3:
raise newException(ValueError, "Invalid IP Address")
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
currentShort = 0
byteCount.inc()
seperatorValid = false
else: # Invalid character
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if byteCount != 3 or not seperatorValid:
raise newException(ValueError, "Invalid IP Address")
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
groupCount += 2
# Shift and fill zeros in case of ::
if groupCount > 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
elif groupCount < 8: # must fill
if dualColonGroup == -1:
raise newException(ValueError,
"Invalid IP Address. The address consists of too few groups")
var toFill = 8 - groupCount # The number of groups to fill
var toShift = groupCount - dualColonGroup # Nr of known groups after ::
for i in 0..2*toShift-1: # shift
result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
for i in 0..2*toFill-1: # fill with 0s
result.address_v6[dualColonGroup*2+i] = 0
elif dualColonGroup != -1:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
proc parseIpAddress*(address_str: string): IpAddress =
## Parses an IP address
## Raises EInvalidValue on error
if address_str == nil:
raise newException(ValueError, "IP Address string is nil")
if address_str.contains(':'):
return parseIPv6Address(address_str)
else:
return parseIPv4Address(address_str)
proc isIpAddress*(address_str: string): bool {.tags: [].} =
## Checks if a string is an IP address
## Returns true if it is, false otherwise
try:
discard parseIpAddress(address_str)
except ValueError:
return false
return true
when defineSsl:
CRYPTO_malloc_init()
SslLibraryInit()
@@ -438,9 +612,12 @@ when defineSsl:
raiseSSLError()
proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket,
handshake: SslHandshakeType) =
handshake: SslHandshakeType,
hostname: string = nil) =
## Wraps a connected socket in an SSL context. This function effectively
## turns ``socket`` into an SSL socket.
## ``hostname`` should be specified so that the client knows which hostname
## the server certificate should be validated against.
##
## This should be called on a connected socket, and will perform
## an SSL handshake immediately.
@@ -450,6 +627,10 @@ when defineSsl:
wrapSocket(ctx, socket)
case handshake
of handshakeAsClient:
if not hostname.isNil and not isIpAddress(hostname):
# Discard result in case OpenSSL version doesn't support SNI, or we're
# not using TLSv1+
discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
let ret = SSLConnect(socket.sslHandle)
socketError(socket, ret)
of handshakeAsServer:
@@ -1302,181 +1483,6 @@ proc `$`*(address: IpAddress): string =
mask = mask shr 4
printedLastGroup = true
proc parseIPv4Address(address_str: string): IpAddress =
## Parses IPv4 adresses
## Raises EInvalidValue on errors
var
byteCount = 0
currentByte:uint16 = 0
seperatorValid = false
result.family = IpAddressFamily.IPv4
for i in 0 .. high(address_str):
if address_str[i] in strutils.Digits: # Character is a number
currentByte = currentByte * 10 +
cast[uint16](ord(address_str[i]) - ord('0'))
if currentByte > 255'u16:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
seperatorValid = true
elif address_str[i] == '.': # IPv4 address separator
if not seperatorValid or byteCount >= 3:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v4[byteCount] = cast[uint8](currentByte)
currentByte = 0
byteCount.inc
seperatorValid = false
else:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if byteCount != 3 or not seperatorValid:
raise newException(ValueError, "Invalid IP Address")
result.address_v4[byteCount] = cast[uint8](currentByte)
proc parseIPv6Address(address_str: string): IpAddress =
## Parses IPv6 adresses
## Raises EInvalidValue on errors
result.family = IpAddressFamily.IPv6
if address_str.len < 2:
raise newException(ValueError, "Invalid IP Address")
var
groupCount = 0
currentGroupStart = 0
currentShort:uint32 = 0
seperatorValid = true
dualColonGroup = -1
lastWasColon = false
v4StartPos = -1
byteCount = 0
for i,c in address_str:
if c == ':':
if not seperatorValid:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid seperator")
if lastWasColon:
if dualColonGroup != -1:
raise newException(ValueError,
"Invalid IP Address. Address contains more than one \"::\" seperator")
dualColonGroup = groupCount
seperatorValid = false
elif i != 0 and i != high(address_str):
if groupCount >= 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
currentShort = 0
groupCount.inc()
if dualColonGroup != -1: seperatorValid = false
elif i == 0: # only valid if address starts with ::
if address_str[1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not start with \":\"")
else: # i == high(address_str) - only valid if address ends with ::
if address_str[high(address_str)-1] != ':':
raise newException(ValueError,
"Invalid IP Address. Address may not end with \":\"")
lastWasColon = true
currentGroupStart = i + 1
elif c == '.': # Switch to parse IPv4 mode
if i < 3 or not seperatorValid or groupCount >= 7:
raise newException(ValueError, "Invalid IP Address")
v4StartPos = currentGroupStart
currentShort = 0
seperatorValid = false
break
elif c in strutils.HexDigits:
if c in strutils.Digits: # Normal digit
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
elif c >= 'a' and c <= 'f': # Lower case hex
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
else: # Upper case hex
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
if currentShort > 65535'u32:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
lastWasColon = false
seperatorValid = true
else:
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
if seperatorValid: # Copy remaining data
if groupCount >= 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
groupCount.inc()
else: # Must parse IPv4 address
for i,c in address_str[v4StartPos..high(address_str)]:
if c in strutils.Digits: # Character is a number
currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
if currentShort > 255'u32:
raise newException(ValueError,
"Invalid IP Address. Value is out of range")
seperatorValid = true
elif c == '.': # IPv4 address separator
if not seperatorValid or byteCount >= 3:
raise newException(ValueError, "Invalid IP Address")
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
currentShort = 0
byteCount.inc()
seperatorValid = false
else: # Invalid character
raise newException(ValueError,
"Invalid IP Address. Address contains an invalid character")
if byteCount != 3 or not seperatorValid:
raise newException(ValueError, "Invalid IP Address")
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
groupCount += 2
# Shift and fill zeros in case of ::
if groupCount > 8:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
elif groupCount < 8: # must fill
if dualColonGroup == -1:
raise newException(ValueError,
"Invalid IP Address. The address consists of too few groups")
var toFill = 8 - groupCount # The number of groups to fill
var toShift = groupCount - dualColonGroup # Nr of known groups after ::
for i in 0..2*toShift-1: # shift
result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
for i in 0..2*toFill-1: # fill with 0s
result.address_v6[dualColonGroup*2+i] = 0
elif dualColonGroup != -1:
raise newException(ValueError,
"Invalid IP Address. The address consists of too many groups")
proc parseIpAddress*(address_str: string): IpAddress =
## Parses an IP address
## Raises EInvalidValue on error
if address_str == nil:
raise newException(ValueError, "IP Address string is nil")
if address_str.contains(':'):
return parseIPv6Address(address_str)
else:
return parseIPv4Address(address_str)
proc isIpAddress*(address_str: string): bool {.tags: [].} =
## Checks if a string is an IP address
## Returns true if it is, false otherwise
try:
discard parseIpAddress(address_str)
except ValueError:
return false
return true
proc dial*(address: string, port: Port,
protocol = IPPROTO_TCP, buffered = true): Socket
{.tags: [ReadIOEffect, WriteIOEffect].} =