diff --git a/compiler/parampatterns.nim b/compiler/parampatterns.nim index e8ec22fe1c..66b54a74ae 100644 --- a/compiler/parampatterns.nim +++ b/compiler/parampatterns.nim @@ -13,7 +13,7 @@ import ast, types, msgs, idents, renderer, wordrecg, trees, options -import std/strutils +import std/[strutils, assertions] # we precompile the pattern here for efficiency into some internal # stack based VM :-) Why? Because it's fun; I did no benchmarks to see if that @@ -216,6 +216,11 @@ proc exprRoot*(n: PNode; allowCalls = true): PSym = else: break +proc isAssignable*(owner: PSym, n: PNode): TAssignableResult + +proc isLentableBranch(owner: PSym, n: PNode): bool = + result = isAssignable(owner, n) in {arLentValue, arAddressableConst, arLentValue} + proc isAssignable*(owner: PSym, n: PNode): TAssignableResult = ## 'owner' can be nil! result = arNone @@ -308,6 +313,35 @@ proc isAssignable*(owner: PSym, n: PNode): TAssignableResult = # nkVarTy denotes an lvalue, but the example above is the only # possible code which will get us here result = arLValue + of nkIfExpr, nkIfStmt: + # allow 'if' expressions to be lent if all branches are lentable + for branch in n: + if branch.len == 2: + if not isLentableBranch(owner, branch[1]): + return + elif branch.len == 1: + if not isLentableBranch(owner, branch[0]): + return + else: + raiseAssert "Malformed `if` statement in isAssignable" + result = arLentValue + of nkCaseStmt: + # allow 'case' expressions to be lent if all branches are lentable + for i in 1 ..< n.len: + let branch = n[i] + case branch.kind + of nkOfBranch: + if not isLentableBranch(owner, branch[^1]): + return + of nkElifBranch: + if not isLentableBranch(owner, branch[1]): + return + of nkElse: + if not isLentableBranch(owner, branch[0]): + return + else: + raiseAssert "Malformed `case` statement in isAssignable" + result = arLentValue else: discard diff --git a/tests/lent/tlents.nim b/tests/lent/tlents.nim new file mode 100644 index 0000000000..28fe0602ed --- /dev/null +++ b/tests/lent/tlents.nim @@ -0,0 +1,25 @@ +discard """ + targets: "c cpp" +""" + +type A = object + field: int + +proc x(a: A): lent int = + result = case true + of true: + a.field + of false: + a.field + +proc y(a: A): lent int = + result = if true: + a.field + else: + a.field + +block: + var a = A(field: 1) + doAssert x(a) == 1 + doAssert y(a) == 1 +