fixes #22210; transform return future in try/finally properly (#22249)

* wip; fixes #22210; transform return future in try/finally properly

* add a test case for #22210

* minor

* inserts a needsCompletion flag

* uses copyNimNode
This commit is contained in:
ringabout
2023-07-21 11:40:11 +08:00
committed by GitHub
parent 2f817ee5b4
commit 91987f8eb5
3 changed files with 164 additions and 11 deletions

View File

@@ -11,6 +11,11 @@
import macros, strutils, asyncfutures
type
Context = ref object
inTry: int
hasRet: bool
# TODO: Ref https://github.com/nim-lang/Nim/issues/5617
# TODO: Add more line infos
proc newCallWithLineInfo(fromNode: NimNode; theProc: NimNode, args: varargs[NimNode]): NimNode =
@@ -63,7 +68,7 @@ proc createFutureVarCompletions(futureVarIdents: seq[NimNode], fromNode: NimNode
)
)
proc processBody(node, retFutureSym: NimNode, futureVarIdents: seq[NimNode]): NimNode =
proc processBody(ctx: Context; node, needsCompletionSym, retFutureSym: NimNode, futureVarIdents: seq[NimNode]): NimNode =
result = node
case node.kind
of nnkReturnStmt:
@@ -72,23 +77,53 @@ proc processBody(node, retFutureSym: NimNode, futureVarIdents: seq[NimNode]): Ni
# As I've painfully found out, the order here really DOES matter.
result.add createFutureVarCompletions(futureVarIdents, node)
ctx.hasRet = true
if node[0].kind == nnkEmpty:
result.add newCall(newIdentNode("complete"), retFutureSym, newIdentNode("result"))
else:
let x = node[0].processBody(retFutureSym, futureVarIdents)
if x.kind == nnkYieldStmt: result.add x
if ctx.inTry == 0:
result.add newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym, newIdentNode("result"))
else:
result.add newCall(newIdentNode("complete"), retFutureSym, x)
result.add newAssignment(needsCompletionSym, newLit(true))
else:
let x = processBody(ctx, node[0], needsCompletionSym, retFutureSym, futureVarIdents)
if x.kind == nnkYieldStmt: result.add x
elif ctx.inTry == 0:
result.add newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym, x)
else:
result.add newAssignment(newIdentNode("result"), x)
result.add newAssignment(needsCompletionSym, newLit(true))
result.add newNimNode(nnkReturnStmt, node).add(newNilLit())
return # Don't process the children of this return stmt
of RoutineNodes-{nnkTemplateDef}:
# skip all the nested procedure definitions
return
else: discard
for i in 0 ..< result.len:
result[i] = processBody(result[i], retFutureSym, futureVarIdents)
of nnkTryStmt:
if result[^1].kind == nnkFinally:
inc ctx.inTry
result[0] = processBody(ctx, result[0], needsCompletionSym, retFutureSym, futureVarIdents)
dec ctx.inTry
for i in 1 ..< result.len:
result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
if ctx.inTry == 0 and ctx.hasRet:
let finallyNode = copyNimNode(result[^1])
let stmtNode = newNimNode(nnkStmtList)
for child in result[^1]:
stmtNode.add child
stmtNode.add newIfStmt(
( needsCompletionSym,
newCallWithLineInfo(node, newIdentNode("complete"), retFutureSym,
newIdentNode("result")
)
)
)
finallyNode.add stmtNode
result[^1] = finallyNode
else:
for i in 0 ..< result.len:
result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
else:
for i in 0 ..< result.len:
result[i] = processBody(ctx, result[i], needsCompletionSym, retFutureSym, futureVarIdents)
# echo result.repr
@@ -213,7 +248,9 @@ proc asyncSingleProc(prc: NimNode): NimNode =
# -> <proc_body>
# -> complete(retFuture, result)
var iteratorNameSym = genSym(nskIterator, $prcName & " (Async)")
var procBody = prc.body.processBody(retFutureSym, futureVarIdents)
var needsCompletionSym = genSym(nskVar, "needsCompletion")
var ctx = Context()
var procBody = processBody(ctx, prc.body, needsCompletionSym, retFutureSym, futureVarIdents)
# don't do anything with forward bodies (empty)
if procBody.kind != nnkEmpty:
# fix #13899, defer should not escape its original scope
@@ -234,6 +271,8 @@ proc asyncSingleProc(prc: NimNode): NimNode =
else:
var `resultIdent`: Future[void]
{.pop.}
var `needsCompletionSym` = false
procBody.add quote do:
complete(`retFutureSym`, `resultIdent`)

41
tests/async/t22210.nim Normal file
View File

@@ -0,0 +1,41 @@
discard """
output: '''
stage 1
stage 2
stage 3
(status: 200, data: "SOMEDATA")
'''
"""
import std/asyncdispatch
# bug #22210
type
ClientResponse = object
status*: int
data*: string
proc subFoo1(): Future[int] {.async.} =
await sleepAsync(100)
return 200
proc subFoo2(): Future[string] {.async.} =
await sleepAsync(100)
return "SOMEDATA"
proc testFoo(): Future[ClientResponse] {.async.} =
try:
let status = await subFoo1()
doAssert(status == 200)
let data = await subFoo2()
return ClientResponse(status: status, data: data)
finally:
echo "stage 1"
await sleepAsync(100)
echo "stage 2"
await sleepAsync(200)
echo "stage 3"
when isMainModule:
echo waitFor testFoo()

73
tests/async/t22210_2.nim Normal file
View File

@@ -0,0 +1,73 @@
import std/asyncdispatch
# bug #22210
type
ClientResponse = object
status*: int
data*: string
proc subFoo1(): Future[int] {.async.} =
await sleepAsync(100)
return 200
proc subFoo2(): Future[string] {.async.} =
await sleepAsync(100)
return "SOMEDATA"
proc testFoo2(): Future[ClientResponse] {.async.} =
var flag = 0
try:
let status = await subFoo1()
doAssert(status == 200)
let data = await subFoo2()
result = ClientResponse(status: status, data: data)
finally:
inc flag
await sleepAsync(100)
inc flag
await sleepAsync(200)
inc flag
doAssert flag == 3
discard waitFor testFoo2()
proc testFoo3(): Future[ClientResponse] {.async.} =
var flag = 0
try:
let status = await subFoo1()
doAssert(status == 200)
let data = await subFoo2()
if false:
return ClientResponse(status: status, data: data)
finally:
inc flag
await sleepAsync(100)
inc flag
await sleepAsync(200)
inc flag
doAssert flag == 3
discard waitFor testFoo3()
proc testFoo4(): Future[ClientResponse] {.async.} =
var flag = 0
try:
let status = await subFoo1()
doAssert(status == 200)
let data = await subFoo2()
if status == 200:
return ClientResponse(status: status, data: data)
else:
return ClientResponse()
finally:
inc flag
await sleepAsync(100)
inc flag
await sleepAsync(200)
inc flag
doAssert flag == 3
discard waitFor testFoo4()