fixes branches interacting with break, raise etc. in strictdefs (#22627)

```nim
{.experimental: "strictdefs".}

type Test = object
  id: int

proc test(): Test =
  if true:
    return Test()
  else:
    return
echo test()
```

I will tackle https://github.com/nim-lang/Nim/issues/16735 and #21615 in
the following PR.


The old code just premises that in branches ended with returns, raise
statements etc. , all variables including the result variable are
initialized for that branch. It's true for noreturn statements. But it
is false for the result variable in a branch tailing with a return
statement, in which the result variable is not initialized. The solution
is not perfect for usages below branch statements with the result
variable uninitialized, but it should suffice for now, which gives a
proper warning.

It also fixes

```nim

{.experimental: "strictdefs".}

type Test = object
  id: int

proc foo {.noreturn.} = discard

proc test9(x: bool): Test =
  if x:
    foo()
  else:
    foo()
```
which gives a warning, but shouldn't
This commit is contained in:
ringabout
2023-09-04 20:36:45 +08:00
committed by GitHub
parent c5495f40d5
commit d13aab50cf
4 changed files with 226 additions and 20 deletions

View File

@@ -596,7 +596,7 @@ proc lookUp*(c: PContext, n: PNode): PSym =
if result == nil: result = errorUndeclaredIdentifierHint(c, n, ident)
else:
internalError(c.config, n.info, "lookUp")
return
return nil
if amb:
#contains(c.ambiguousSymbols, result.id):
result = errorUseQualifier(c, n.info, result, amb)

View File

@@ -352,20 +352,25 @@ proc useVar(a: PEffects, n: PNode) =
a.init.add s.id
useVarNoInitCheck(a, n, s)
type
BreakState = enum
bsNone
bsBreakOrReturn
bsNoReturn
type
TIntersection = seq[tuple[id, count: int]] # a simple count table
proc addToIntersection(inter: var TIntersection, s: int, initOnly: bool) =
proc addToIntersection(inter: var TIntersection, s: int, state: BreakState) =
for j in 0..<inter.len:
if s == inter[j].id:
if not initOnly:
if state == bsNone:
inc inter[j].count
return
if initOnly:
inter.add((id: s, count: 0))
else:
if state == bsNone:
inter.add((id: s, count: 1))
else:
inter.add((id: s, count: 0))
proc throws(tracked, n, orig: PNode) =
if n.typ == nil or n.typ.kind != tyError:
@@ -469,7 +474,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
track(tracked, n[0])
dec tracked.inTryStmt
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], false)
addToIntersection(inter, tracked.init[i], bsNone)
var branches = 1
var hasFinally = false
@@ -504,7 +509,7 @@ proc trackTryStmt(tracked: PEffects, n: PNode) =
tracked.init.add b[j][2].sym.id
track(tracked, b[^1])
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], false)
addToIntersection(inter, tracked.init[i], bsNone)
else:
setLen(tracked.init, oldState)
track(tracked, b[^1])
@@ -673,15 +678,50 @@ proc trackOperandForIndirectCall(tracked: PEffects, n: PNode, formals: PType; ar
localError(tracked.config, n.info, $n & " is not GC safe")
notNilCheck(tracked, n, paramType)
proc breaksBlock(n: PNode): bool =
proc breaksBlock(n: PNode): BreakState =
# semantic check doesn't allow statements after raise, break, return or
# call to noreturn proc, so it is safe to check just the last statements
var it = n
while it.kind in {nkStmtList, nkStmtListExpr} and it.len > 0:
it = it.lastSon
result = it.kind in {nkBreakStmt, nkReturnStmt, nkRaiseStmt} or
it.kind in nkCallKinds and it[0].kind == nkSym and sfNoReturn in it[0].sym.flags
case it.kind
of nkBreakStmt, nkReturnStmt:
result = bsBreakOrReturn
of nkRaiseStmt:
result = bsNoReturn
of nkCallKinds:
if it[0].kind == nkSym and sfNoReturn in it[0].sym.flags:
result = bsNoReturn
else:
result = bsNone
else:
result = bsNone
proc addIdToIntersection(tracked: PEffects, inter: var TIntersection, resCounter: var int,
hasBreaksBlock: BreakState, oldState: int, resSym: PSym, hasResult: bool) =
if hasResult:
var alreadySatisfy = false
if hasBreaksBlock == bsNoReturn:
alreadySatisfy = true
inc resCounter
for i in oldState..<tracked.init.len:
if tracked.init[i] == resSym.id:
if not alreadySatisfy:
inc resCounter
alreadySatisfy = true
else:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
else:
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
template hasResultSym(s: PSym): bool =
s != nil and s.kind in {skProc, skFunc, skConverter, skMethod} and
not isEmptyType(s.typ[0])
proc trackCase(tracked: PEffects, n: PNode) =
track(tracked, n[0])
@@ -694,6 +734,10 @@ proc trackCase(tracked: PEffects, n: PNode) =
(tracked.config.hasWarn(warnProveField) or strictCaseObjects in tracked.c.features)
var inter: TIntersection = @[]
var toCover = 0
let hasResult = hasResultSym(tracked.owner)
let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
var resCounter = 0
for i in 1..<n.len:
let branch = n[i]
setLen(tracked.init, oldState)
@@ -703,13 +747,14 @@ proc trackCase(tracked: PEffects, n: PNode) =
for i in 0..<branch.len:
track(tracked, branch[i])
let hasBreaksBlock = breaksBlock(branch.lastSon)
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
setLen(tracked.init, oldState)
if not stringCase or lastSon(n).kind == nkElse:
if hasResult and resCounter == n.len-1:
tracked.init.add resSym.id
for id, count in items(inter):
if count >= toCover: tracked.init.add id
# else we can't merge
@@ -723,14 +768,17 @@ proc trackIf(tracked: PEffects, n: PNode) =
addFact(tracked.guards, n[0][0])
let oldState = tracked.init.len
let hasResult = hasResultSym(tracked.owner)
let resSym = if hasResult: tracked.owner.ast[resultPos].sym else: nil
var resCounter = 0
var inter: TIntersection = @[]
var toCover = 0
track(tracked, n[0][1])
let hasBreaksBlock = breaksBlock(n[0][1])
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
for i in 1..<n.len:
let branch = n[i]
@@ -743,13 +791,14 @@ proc trackIf(tracked: PEffects, n: PNode) =
for i in 0..<branch.len:
track(tracked, branch[i])
let hasBreaksBlock = breaksBlock(branch.lastSon)
if not hasBreaksBlock:
if hasBreaksBlock == bsNone:
inc toCover
for i in oldState..<tracked.init.len:
addToIntersection(inter, tracked.init[i], hasBreaksBlock)
addIdToIntersection(tracked, inter, resCounter, hasBreaksBlock, oldState, resSym, hasResult)
setLen(tracked.init, oldState)
if lastSon(n).len == 1:
if hasResult and resCounter == n.len:
tracked.init.add resSym.id
for id, count in items(inter):
if count >= toCover: tracked.init.add id
# else we can't merge as it is not exhaustive

64
tests/init/tcompiles.nim Normal file
View File

@@ -0,0 +1,64 @@
discard """
matrix: "--warningAsError:ProveInit --warningAsError:Uninit"
"""
{.experimental: "strictdefs".}
type Test = object
id: int
proc foo {.noreturn.} = discard
block:
proc test(x: bool): Test =
if x:
foo()
else:
foo()
block:
proc test(x: bool): Test =
if x:
result = Test()
else:
foo()
discard test(true)
block:
proc test(x: bool): Test =
if x:
result = Test()
else:
return Test()
discard test(true)
block:
proc test(x: bool): Test =
if x:
return Test()
else:
return Test()
discard test(true)
block:
proc test(x: bool): Test =
if x:
result = Test()
else:
result = Test()
return
discard test(true)
block:
proc test(x: bool): Test =
if x:
result = Test()
return
else:
raise newException(ValueError, "unreachable")
discard test(true)

93
tests/init/treturns.nim Normal file
View File

@@ -0,0 +1,93 @@
{.experimental: "strictdefs".}
type Test = object
id: int
proc foo {.noreturn.} = discard
proc test1(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return Test()
else:
return
proc test0(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
foo()
proc test2(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
return
proc test3(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
return Test()
proc test4(): Test =
if true: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return
else:
result = Test()
return
proc test5(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
return Test()
proc test6(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
return
proc test7(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
return
else:
discard
proc test8(x: bool): Test =
case x: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
of true:
discard
else:
raise
proc hasImportStmt(): bool =
if false: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return true
else:
discard
discard hasImportStmt()
block:
proc hasImportStmt(): bool =
if false: #[tt.Warning
^ Cannot prove that 'result' is initialized. This will become a compile time error in the future. [ProveInit]]#
return true
else:
return
discard hasImportStmt()