From 06d50bfd4cac6611c5fc775c8130cad87e4dcee1 Mon Sep 17 00:00:00 2001 From: Dylan Modesitt Date: Mon, 31 May 2021 16:51:32 -0400 Subject: [PATCH] Fixes #5034 illformed AST from getImpl with proc returning value (#17976) * Fixes 5034 * address comments --- compiler/semstmts.nim | 25 +++++++++++++++++++++---- tests/macros/tmacrogetimpl.nim | 31 +++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 tests/macros/tmacrogetimpl.nim diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim index b4630294ba..085146f7ff 100644 --- a/compiler/semstmts.nim +++ b/compiler/semstmts.nim @@ -1439,16 +1439,33 @@ proc semBorrow(c: PContext, n: PNode, s: PSym) = else: localError(c.config, n.info, errNoSymbolToBorrowFromFound) +proc swapResult(n: PNode, sRes: PSym, dNode: PNode) = + ## Swap nodes that are (skResult) symbols to d(estination)Node. + for i in 0.. resultPos and n[resultPos] != nil: - if n[resultPos].sym.kind != skResult or n[resultPos].sym.owner != getCurrOwner(c): + if n[resultPos].sym.kind != skResult: localError(c.config, n.info, "incorrect result proc symbol") + if n[resultPos].sym.owner != getCurrOwner(c): + # re-write result with new ownership, and re-write the proc accordingly + let sResSym = n[resultPos].sym + genResSym(s) + n[resultPos] = newSymNode(s) + swapResult(n, sResSym, n[resultPos]) c.p.resultSym = n[resultPos].sym else: - var s = newSym(skResult, getIdent(c.cache, "result"), nextSymId c.idgen, getCurrOwner(c), n.info) - s.typ = t - incl(s.flags, sfUsed) + genResSym(s) c.p.resultSym = s n.add newSymNode(c.p.resultSym) addParamOrResult(c, c.p.resultSym, owner) diff --git a/tests/macros/tmacrogetimpl.nim b/tests/macros/tmacrogetimpl.nim new file mode 100644 index 0000000000..1d996ff295 --- /dev/null +++ b/tests/macros/tmacrogetimpl.nim @@ -0,0 +1,31 @@ +import macros + +# bug #5034 + +macro copyImpl(srsProc: typed, toSym: untyped) = + result = copyNimTree(getImplTransformed(srsProc)) + result[0] = ident $toSym.toStrLit() + +proc foo1(x: float, one: bool = true): float = + if one: + return 1'f + result = x + +proc bar1(what: string): string = + ## be a little more adversarial with `skResult` + proc buzz: string = + result = "lightyear" + if what == "buzz": + result = "buzz " & buzz() + else: + result = what + return result + +copyImpl(foo1, foo2) +doAssert foo1(1'f) == 1.0 +doAssert foo2(10.0, false) == 10.0 +doAssert foo2(10.0) == 1.0 + +copyImpl(bar1, bar2) +doAssert bar1("buzz") == "buzz lightyear" +doAssert bar1("macros") == "macros"