mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-29 17:34:43 +00:00
Add 'hostname' param to wrapConnectedSocket
This commit is contained in:
@@ -647,9 +647,12 @@ when defineSsl:
|
||||
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
|
||||
|
||||
proc wrapConnectedSocket*(ctx: SslContext, socket: AsyncSocket,
|
||||
handshake: SslHandshakeType) =
|
||||
handshake: SslHandshakeType,
|
||||
hostname: string = nil) =
|
||||
## Wraps a connected socket in an SSL context. This function effectively
|
||||
## turns ``socket`` into an SSL socket.
|
||||
## ``hostname`` should be specified so that the client knows which hostname
|
||||
## the server certificate should be validated against.
|
||||
##
|
||||
## This should be called on a connected socket, and will perform
|
||||
## an SSL handshake immediately.
|
||||
@@ -660,6 +663,10 @@ when defineSsl:
|
||||
|
||||
case handshake
|
||||
of handshakeAsClient:
|
||||
if not hostname.isNil and not isIpAddress(hostname):
|
||||
# Set the SNI address for this connection. This call can fail if
|
||||
# we're not using TLSv1+.
|
||||
discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
|
||||
sslSetConnectState(socket.sslHandle)
|
||||
of handshakeAsServer:
|
||||
sslSetAcceptState(socket.sslHandle)
|
||||
|
||||
@@ -512,7 +512,7 @@ proc request*(url: string, httpMethod: string, extraHeaders = "",
|
||||
raise newException(HttpRequestError,
|
||||
"The proxy server rejected a CONNECT request, " &
|
||||
"so a secure connection could not be established.")
|
||||
sslContext.wrapConnectedSocket(s, handshakeAsClient)
|
||||
sslContext.wrapConnectedSocket(s, handshakeAsClient, hostUrl.hostname)
|
||||
else:
|
||||
raise newException(HttpRequestError, "SSL support not available. Cannot connect via proxy over SSL")
|
||||
else:
|
||||
@@ -1060,7 +1060,8 @@ proc newConnection(client: HttpClient | AsyncHttpClient,
|
||||
when defined(ssl):
|
||||
if isSsl:
|
||||
try:
|
||||
client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient)
|
||||
client.sslContext.wrapConnectedSocket(
|
||||
client.socket, handshakeAsClient, url.hostname)
|
||||
except:
|
||||
client.socket.close()
|
||||
raise getCurrentException()
|
||||
@@ -1102,7 +1103,8 @@ proc requestAux(client: HttpClient | AsyncHttpClient, url: string,
|
||||
raise newException(HttpRequestError,
|
||||
"The proxy server rejected a CONNECT request, " &
|
||||
"so a secure connection could not be established.")
|
||||
client.sslContext.wrapConnectedSocket(client.socket, handshakeAsClient)
|
||||
client.sslContext.wrapConnectedSocket(
|
||||
client.socket, handshakeAsClient, requestUrl.hostname)
|
||||
client.proxy = nil
|
||||
else:
|
||||
raise newException(HttpRequestError,
|
||||
|
||||
358
lib/pure/net.nim
358
lib/pure/net.nim
@@ -237,6 +237,180 @@ proc newSocket*(domain: Domain = AF_INET, sockType: SockType = SOCK_STREAM,
|
||||
raiseOSError(osLastError())
|
||||
result = newSocket(fd, domain, sockType, protocol, buffered)
|
||||
|
||||
proc parseIPv4Address(address_str: string): IpAddress =
|
||||
## Parses IPv4 adresses
|
||||
## Raises EInvalidValue on errors
|
||||
var
|
||||
byteCount = 0
|
||||
currentByte:uint16 = 0
|
||||
seperatorValid = false
|
||||
|
||||
result.family = IpAddressFamily.IPv4
|
||||
|
||||
for i in 0 .. high(address_str):
|
||||
if address_str[i] in strutils.Digits: # Character is a number
|
||||
currentByte = currentByte * 10 +
|
||||
cast[uint16](ord(address_str[i]) - ord('0'))
|
||||
if currentByte > 255'u16:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
seperatorValid = true
|
||||
elif address_str[i] == '.': # IPv4 address separator
|
||||
if not seperatorValid or byteCount >= 3:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v4[byteCount] = cast[uint8](currentByte)
|
||||
currentByte = 0
|
||||
byteCount.inc
|
||||
seperatorValid = false
|
||||
else:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
if byteCount != 3 or not seperatorValid:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v4[byteCount] = cast[uint8](currentByte)
|
||||
|
||||
proc parseIPv6Address(address_str: string): IpAddress =
|
||||
## Parses IPv6 adresses
|
||||
## Raises EInvalidValue on errors
|
||||
result.family = IpAddressFamily.IPv6
|
||||
if address_str.len < 2:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
|
||||
var
|
||||
groupCount = 0
|
||||
currentGroupStart = 0
|
||||
currentShort:uint32 = 0
|
||||
seperatorValid = true
|
||||
dualColonGroup = -1
|
||||
lastWasColon = false
|
||||
v4StartPos = -1
|
||||
byteCount = 0
|
||||
|
||||
for i,c in address_str:
|
||||
if c == ':':
|
||||
if not seperatorValid:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid seperator")
|
||||
if lastWasColon:
|
||||
if dualColonGroup != -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains more than one \"::\" seperator")
|
||||
dualColonGroup = groupCount
|
||||
seperatorValid = false
|
||||
elif i != 0 and i != high(address_str):
|
||||
if groupCount >= 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
|
||||
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
|
||||
currentShort = 0
|
||||
groupCount.inc()
|
||||
if dualColonGroup != -1: seperatorValid = false
|
||||
elif i == 0: # only valid if address starts with ::
|
||||
if address_str[1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not start with \":\"")
|
||||
else: # i == high(address_str) - only valid if address ends with ::
|
||||
if address_str[high(address_str)-1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not end with \":\"")
|
||||
lastWasColon = true
|
||||
currentGroupStart = i + 1
|
||||
elif c == '.': # Switch to parse IPv4 mode
|
||||
if i < 3 or not seperatorValid or groupCount >= 7:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
v4StartPos = currentGroupStart
|
||||
currentShort = 0
|
||||
seperatorValid = false
|
||||
break
|
||||
elif c in strutils.HexDigits:
|
||||
if c in strutils.Digits: # Normal digit
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
|
||||
elif c >= 'a' and c <= 'f': # Lower case hex
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
|
||||
else: # Upper case hex
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
|
||||
if currentShort > 65535'u32:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
lastWasColon = false
|
||||
seperatorValid = true
|
||||
else:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
|
||||
if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
|
||||
if seperatorValid: # Copy remaining data
|
||||
if groupCount >= 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
|
||||
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
|
||||
groupCount.inc()
|
||||
else: # Must parse IPv4 address
|
||||
for i,c in address_str[v4StartPos..high(address_str)]:
|
||||
if c in strutils.Digits: # Character is a number
|
||||
currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
|
||||
if currentShort > 255'u32:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
seperatorValid = true
|
||||
elif c == '.': # IPv4 address separator
|
||||
if not seperatorValid or byteCount >= 3:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
|
||||
currentShort = 0
|
||||
byteCount.inc()
|
||||
seperatorValid = false
|
||||
else: # Invalid character
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
if byteCount != 3 or not seperatorValid:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
|
||||
groupCount += 2
|
||||
|
||||
# Shift and fill zeros in case of ::
|
||||
if groupCount > 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
elif groupCount < 8: # must fill
|
||||
if dualColonGroup == -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too few groups")
|
||||
var toFill = 8 - groupCount # The number of groups to fill
|
||||
var toShift = groupCount - dualColonGroup # Nr of known groups after ::
|
||||
for i in 0..2*toShift-1: # shift
|
||||
result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
|
||||
for i in 0..2*toFill-1: # fill with 0s
|
||||
result.address_v6[dualColonGroup*2+i] = 0
|
||||
elif dualColonGroup != -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
|
||||
proc parseIpAddress*(address_str: string): IpAddress =
|
||||
## Parses an IP address
|
||||
## Raises EInvalidValue on error
|
||||
if address_str == nil:
|
||||
raise newException(ValueError, "IP Address string is nil")
|
||||
if address_str.contains(':'):
|
||||
return parseIPv6Address(address_str)
|
||||
else:
|
||||
return parseIPv4Address(address_str)
|
||||
|
||||
proc isIpAddress*(address_str: string): bool {.tags: [].} =
|
||||
## Checks if a string is an IP address
|
||||
## Returns true if it is, false otherwise
|
||||
try:
|
||||
discard parseIpAddress(address_str)
|
||||
except ValueError:
|
||||
return false
|
||||
return true
|
||||
|
||||
when defineSsl:
|
||||
CRYPTO_malloc_init()
|
||||
SslLibraryInit()
|
||||
@@ -438,9 +612,12 @@ when defineSsl:
|
||||
raiseSSLError()
|
||||
|
||||
proc wrapConnectedSocket*(ctx: SSLContext, socket: Socket,
|
||||
handshake: SslHandshakeType) =
|
||||
handshake: SslHandshakeType,
|
||||
hostname: string = nil) =
|
||||
## Wraps a connected socket in an SSL context. This function effectively
|
||||
## turns ``socket`` into an SSL socket.
|
||||
## ``hostname`` should be specified so that the client knows which hostname
|
||||
## the server certificate should be validated against.
|
||||
##
|
||||
## This should be called on a connected socket, and will perform
|
||||
## an SSL handshake immediately.
|
||||
@@ -450,6 +627,10 @@ when defineSsl:
|
||||
wrapSocket(ctx, socket)
|
||||
case handshake
|
||||
of handshakeAsClient:
|
||||
if not hostname.isNil and not isIpAddress(hostname):
|
||||
# Discard result in case OpenSSL version doesn't support SNI, or we're
|
||||
# not using TLSv1+
|
||||
discard SSL_set_tlsext_host_name(socket.sslHandle, hostname)
|
||||
let ret = SSLConnect(socket.sslHandle)
|
||||
socketError(socket, ret)
|
||||
of handshakeAsServer:
|
||||
@@ -1302,181 +1483,6 @@ proc `$`*(address: IpAddress): string =
|
||||
mask = mask shr 4
|
||||
printedLastGroup = true
|
||||
|
||||
proc parseIPv4Address(address_str: string): IpAddress =
|
||||
## Parses IPv4 adresses
|
||||
## Raises EInvalidValue on errors
|
||||
var
|
||||
byteCount = 0
|
||||
currentByte:uint16 = 0
|
||||
seperatorValid = false
|
||||
|
||||
result.family = IpAddressFamily.IPv4
|
||||
|
||||
for i in 0 .. high(address_str):
|
||||
if address_str[i] in strutils.Digits: # Character is a number
|
||||
currentByte = currentByte * 10 +
|
||||
cast[uint16](ord(address_str[i]) - ord('0'))
|
||||
if currentByte > 255'u16:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
seperatorValid = true
|
||||
elif address_str[i] == '.': # IPv4 address separator
|
||||
if not seperatorValid or byteCount >= 3:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v4[byteCount] = cast[uint8](currentByte)
|
||||
currentByte = 0
|
||||
byteCount.inc
|
||||
seperatorValid = false
|
||||
else:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
if byteCount != 3 or not seperatorValid:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v4[byteCount] = cast[uint8](currentByte)
|
||||
|
||||
proc parseIPv6Address(address_str: string): IpAddress =
|
||||
## Parses IPv6 adresses
|
||||
## Raises EInvalidValue on errors
|
||||
result.family = IpAddressFamily.IPv6
|
||||
if address_str.len < 2:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
|
||||
var
|
||||
groupCount = 0
|
||||
currentGroupStart = 0
|
||||
currentShort:uint32 = 0
|
||||
seperatorValid = true
|
||||
dualColonGroup = -1
|
||||
lastWasColon = false
|
||||
v4StartPos = -1
|
||||
byteCount = 0
|
||||
|
||||
for i,c in address_str:
|
||||
if c == ':':
|
||||
if not seperatorValid:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid seperator")
|
||||
if lastWasColon:
|
||||
if dualColonGroup != -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains more than one \"::\" seperator")
|
||||
dualColonGroup = groupCount
|
||||
seperatorValid = false
|
||||
elif i != 0 and i != high(address_str):
|
||||
if groupCount >= 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
|
||||
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
|
||||
currentShort = 0
|
||||
groupCount.inc()
|
||||
if dualColonGroup != -1: seperatorValid = false
|
||||
elif i == 0: # only valid if address starts with ::
|
||||
if address_str[1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not start with \":\"")
|
||||
else: # i == high(address_str) - only valid if address ends with ::
|
||||
if address_str[high(address_str)-1] != ':':
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address may not end with \":\"")
|
||||
lastWasColon = true
|
||||
currentGroupStart = i + 1
|
||||
elif c == '.': # Switch to parse IPv4 mode
|
||||
if i < 3 or not seperatorValid or groupCount >= 7:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
v4StartPos = currentGroupStart
|
||||
currentShort = 0
|
||||
seperatorValid = false
|
||||
break
|
||||
elif c in strutils.HexDigits:
|
||||
if c in strutils.Digits: # Normal digit
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('0'))
|
||||
elif c >= 'a' and c <= 'f': # Lower case hex
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('a')) + 10
|
||||
else: # Upper case hex
|
||||
currentShort = (currentShort shl 4) + cast[uint32](ord(c) - ord('A')) + 10
|
||||
if currentShort > 65535'u32:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
lastWasColon = false
|
||||
seperatorValid = true
|
||||
else:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
|
||||
if v4StartPos == -1: # Don't parse v4. Copy the remaining v6 stuff
|
||||
if seperatorValid: # Copy remaining data
|
||||
if groupCount >= 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
result.address_v6[groupCount*2] = cast[uint8](currentShort shr 8)
|
||||
result.address_v6[groupCount*2+1] = cast[uint8](currentShort and 0xFF)
|
||||
groupCount.inc()
|
||||
else: # Must parse IPv4 address
|
||||
for i,c in address_str[v4StartPos..high(address_str)]:
|
||||
if c in strutils.Digits: # Character is a number
|
||||
currentShort = currentShort * 10 + cast[uint32](ord(c) - ord('0'))
|
||||
if currentShort > 255'u32:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Value is out of range")
|
||||
seperatorValid = true
|
||||
elif c == '.': # IPv4 address separator
|
||||
if not seperatorValid or byteCount >= 3:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
|
||||
currentShort = 0
|
||||
byteCount.inc()
|
||||
seperatorValid = false
|
||||
else: # Invalid character
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. Address contains an invalid character")
|
||||
|
||||
if byteCount != 3 or not seperatorValid:
|
||||
raise newException(ValueError, "Invalid IP Address")
|
||||
result.address_v6[groupCount*2 + byteCount] = cast[uint8](currentShort)
|
||||
groupCount += 2
|
||||
|
||||
# Shift and fill zeros in case of ::
|
||||
if groupCount > 8:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
elif groupCount < 8: # must fill
|
||||
if dualColonGroup == -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too few groups")
|
||||
var toFill = 8 - groupCount # The number of groups to fill
|
||||
var toShift = groupCount - dualColonGroup # Nr of known groups after ::
|
||||
for i in 0..2*toShift-1: # shift
|
||||
result.address_v6[15-i] = result.address_v6[groupCount*2-i-1]
|
||||
for i in 0..2*toFill-1: # fill with 0s
|
||||
result.address_v6[dualColonGroup*2+i] = 0
|
||||
elif dualColonGroup != -1:
|
||||
raise newException(ValueError,
|
||||
"Invalid IP Address. The address consists of too many groups")
|
||||
|
||||
|
||||
proc parseIpAddress*(address_str: string): IpAddress =
|
||||
## Parses an IP address
|
||||
## Raises EInvalidValue on error
|
||||
if address_str == nil:
|
||||
raise newException(ValueError, "IP Address string is nil")
|
||||
if address_str.contains(':'):
|
||||
return parseIPv6Address(address_str)
|
||||
else:
|
||||
return parseIPv4Address(address_str)
|
||||
|
||||
proc isIpAddress*(address_str: string): bool {.tags: [].} =
|
||||
## Checks if a string is an IP address
|
||||
## Returns true if it is, false otherwise
|
||||
try:
|
||||
discard parseIpAddress(address_str)
|
||||
except ValueError:
|
||||
return false
|
||||
return true
|
||||
|
||||
proc dial*(address: string, port: Port,
|
||||
protocol = IPPROTO_TCP, buffered = true): Socket
|
||||
{.tags: [ReadIOEffect, WriteIOEffect].} =
|
||||
|
||||
Reference in New Issue
Block a user