allow referencing other parameters in default parameter values

fix #7756
fix #1201
fix #7000
fix #3002
fix #1046
This commit is contained in:
Zahary Karadjov
2018-06-16 03:51:31 +03:00
parent e719f211c6
commit 31651ecd61
10 changed files with 244 additions and 31 deletions

View File

@@ -293,6 +293,10 @@ const
# the compiler will avoid printing such names
# in user messages.
sfHoisted* = sfForward
# an expression was hoised to an anonymous variable.
# the flag is applied to the var/let symbol
sfNoForward* = sfRegister
# forward declarations are not required (per module)
sfReorder* = sfForward
@@ -455,6 +459,8 @@ type
nfBlockArg # this a stmtlist appearing in a call (e.g. a do block)
nfFromTemplate # a top-level node returned from a template
nfDefaultParam # an automatically inserter default parameter
nfDefaultRefsParam # a default param value references another parameter
# the flag is applied to proc default values and to calls
TNodeFlags* = set[TNodeFlag]
TTypeFlag* = enum # keep below 32 for efficiency reasons (now: beyond that)
@@ -972,7 +978,7 @@ const
PersistentNodeFlags*: TNodeFlags = {nfBase2, nfBase8, nfBase16,
nfDotSetter, nfDotField,
nfIsRef, nfPreventCg, nfLL,
nfFromTemplate}
nfFromTemplate, nfDefaultRefsParam}
namePos* = 0
patternPos* = 1 # empty except for term rewriting macros
genericParamsPos* = 2

View File

@@ -34,32 +34,36 @@ when declared(echo):
template debug*(x: PSym|PType|PNode) {.deprecated.} =
when compiles(c.config):
debug(c.config, x)
elif compiles(c.graph.config):
debug(c.graph.config, x)
else:
error()
template debug*(x: auto) {.deprecated.} =
echo x
template mdbg*: bool {.dirty.} =
when compiles(c.module):
c.module.fileIdx == c.config.projectMainIdx
elif compiles(c.c.module):
c.c.module.fileIdx == c.c.config.projectMainIdx
elif compiles(m.c.module):
m.c.module.fileIdx == m.c.config.projectMainIdx
elif compiles(cl.c.module):
cl.c.module.fileIdx == cl.c.config.projectMainIdx
elif compiles(p):
when compiles(p.lex):
p.lex.fileIdx == p.lex.config.projectMainIdx
template mdbg*: bool {.deprecated.} =
when compiles(c.graph):
c.module.fileIdx == c.graph.config.projectMainIdx
elif compiles(c.module):
c.module.fileIdx == c.config.projectMainIdx
elif compiles(c.c.module):
c.c.module.fileIdx == c.c.config.projectMainIdx
elif compiles(m.c.module):
m.c.module.fileIdx == m.c.config.projectMainIdx
elif compiles(cl.c.module):
cl.c.module.fileIdx == cl.c.config.projectMainIdx
elif compiles(p):
when compiles(p.lex):
p.lex.fileIdx == p.lex.config.projectMainIdx
else:
p.module.module.fileIdx == p.config.projectMainIdx
elif compiles(m.module.fileIdx):
m.module.fileIdx == m.config.projectMainIdx
elif compiles(L.fileIdx):
L.fileIdx == L.config.projectMainIdx
else:
p.module.module.fileIdx == p.config.projectMainIdx
elif compiles(m.module.fileIdx):
m.module.fileIdx == m.config.projectMainIdx
elif compiles(L.fileIdx):
L.fileIdx == L.config.projectMainIdx
else:
error()
error()
# --------------------------- ident tables ----------------------------------
proc idTableGet*(t: TIdTable, key: PIdObj): RootRef

View File

@@ -336,6 +336,18 @@ proc typeNeedsNoDeepCopy(t: PType): bool =
if t.kind in {tyVar, tyLent, tySequence}: t = t.lastSon
result = not containsGarbageCollectedRef(t)
proc hoistExpr*(varSection, expr: PNode, name: PIdent, owner: PSym): PSym =
result = newSym(skLet, name, owner, varSection.info, owner.options)
result.flags.incl sfHoisted
result.typ = expr.typ
var varDef = newNodeI(nkIdentDefs, varSection.info, 3)
varDef.sons[0] = newSymNode(result)
varDef.sons[1] = newNodeI(nkEmpty, varSection.info)
varDef.sons[2] = expr
varSection.add varDef
proc addLocalVar(g: ModuleGraph; varSection, varInit: PNode; owner: PSym; typ: PType;
v: PNode; useShallowCopy=false): PSym =
result = newSym(skTemp, getIdent(g.cache, genPrefix), owner, varSection.info,

View File

@@ -73,6 +73,16 @@ template semIdeForTemplateOrGeneric(c: PContext; n: PNode;
# echo "passing to safeSemExpr: ", renderTree(n)
discard safeSemExpr(c, n)
proc fitNodePostMatch(c: PContext, formal: PType, arg: PNode): PNode =
result = arg
let x = result.skipConv
if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr:
changeType(c, x, formal, check=true)
else:
result = skipHiddenSubConv(result)
#result.typ = takeType(formal, arg.typ)
#echo arg.info, " picked ", result.typ.typeToString
proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode =
if arg.typ.isNil:
localError(c.config, arg.info, "expression has no type: " &
@@ -88,13 +98,7 @@ proc fitNode(c: PContext, formal: PType, arg: PNode; info: TLineInfo): PNode =
result = copyTree(arg)
result.typ = formal
else:
let x = result.skipConv
if x.kind in {nkPar, nkTupleConstr} and formal.kind != tyExpr:
changeType(c, x, formal, check=true)
else:
result = skipHiddenSubConv(result)
#result.typ = takeType(formal, arg.typ)
#echo arg.info, " picked ", result.typ.typeToString
result = fitNodePostMatch(c, formal, result)
proc inferWithMetatype(c: PContext, formal: PType,
arg: PNode, coerceDistincts = false): PNode

View File

@@ -402,7 +402,9 @@ proc updateDefaultParams(call: PNode) =
for i in countdown(call.len - 1, 1):
if nfDefaultParam notin call[i].flags:
return
call[i] = calleeParams[i].sym.ast
let def = calleeParams[i].sym.ast
if nfDefaultRefsParam in def.flags: call.flags.incl nfDefaultRefsParam
call[i] = def
proc semResolvedCall(c: PContext, x: TCandidate,
n: PNode, flags: TExprFlags): PNode =

View File

@@ -220,6 +220,14 @@ proc instGenericContainer(c: PContext, info: TLineInfo, header: PType,
result = replaceTypeVarsT(cl, header)
closeScope(c)
proc referencesAnotherParam(n: PNode, p: PSym): bool =
if n.kind == nkSym:
return n.sym.kind == skParam and n.sym.owner == p
else:
for i in 0..<n.safeLen:
if referencesAnotherParam(n[i], p): return true
return false
proc instantiateProcType(c: PContext, pt: TIdTable,
prc: PSym, info: TLineInfo) =
# XXX: Instantiates a generic proc signature, while at the same
@@ -276,8 +284,22 @@ proc instantiateProcType(c: PContext, pt: TIdTable,
if def.kind == nkCall:
for i in 1 ..< def.len:
def[i] = replaceTypeVarsN(cl, def[i])
def = semExprWithType(c, def)
param.ast = fitNode(c, typeToFit, def, def.info)
def = semExprWithType(c, def)
if def.referencesAnotherParam(getCurrOwner(c)):
def.flags.incl nfDefaultRefsParam
var converted = indexTypesMatch(c, typeToFit, def.typ, def)
if converted == nil:
# The default value doesn't match the final instantiated type.
# As an example of this, see:
# https://github.com/nim-lang/Nim/issues/1201
# We are replacing the default value with an error node in case
# the user calls an explicit instantiation of the proc (this is
# the only way the default value might be inserted).
param.ast = errorNode(c, def)
else:
param.ast = fitNodePostMatch(c, typeToFit, converted)
param.typ = result[i]
result.n[i] = newSymNode(param)

View File

@@ -1042,6 +1042,8 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode,
break determineType
def = semExprWithType(c, def, {efDetermineType})
if def.referencesAnotherParam(getCurrOwner(c)):
def.flags.incl nfDefaultRefsParam
if typ == nil:
typ = def.typ

View File

@@ -2366,7 +2366,8 @@ proc matches*(c: PContext, n, nOrig: PNode, m: var TCandidate) =
m.firstMismatch = f
break
else:
# use default value:
if nfDefaultRefsParam in formal.ast.flags:
m.call.flags.incl nfDefaultRefsParam
var def = copyTree(formal.ast)
if def.kind == nkNilLit:
def = implicitConv(nkHiddenStdConv, formal.typ, def, m, c)

View File

@@ -780,6 +780,43 @@ proc commonOptimizations*(g: ModuleGraph; c: PSym, n: PNode): PNode =
else:
result = n
proc hoistParamsUsedInDefault(c: PTransf, call, letSection, defExpr: PNode): PNode =
# This takes care of complicated signatures such as:
# proc foo(a: int, b = a)
# proc bar(a: int, b: int, c = a + b)
#
# The recursion may confuse you. It performs two duties:
#
# 1) extracting all referenced params from default expressions
# into a let section preceeding the call
#
# 2) replacing the "references" within the default expression
# with these extracted skLet symbols.
#
# The first duty is carried out directly in the code here, while the second
# duty is activated by returning a non-nil value. The caller is responsible
# for replacing the input to the function with the returned non-nil value.
# (which is the hoisted symbol)
if defExpr.kind == nkSym:
if defExpr.sym.kind == skParam and defExpr.sym.owner == call[0].sym:
let paramPos = defExpr.sym.position + 1
if call[paramPos].kind == nkSym and sfHoisted in call[paramPos].sym.flags:
# Already hoisted, we still need to return it in order to replace the
# placeholder expression in the default value.
return call[paramPos]
let hoistedVarSym = hoistExpr(letSection,
call[paramPos],
getIdent(c.graph.cache, genPrefix),
c.transCon.owner).newSymNode
call[paramPos] = hoistedVarSym
return hoistedVarSym
else:
for i in 0..<defExpr.safeLen:
let hoisted = hoistParamsUsedInDefault(c, call, letSection, defExpr[i])
if hoisted != nil: defExpr[i] = hoisted
proc transform(c: PTransf, n: PNode): PTransNode =
when false:
var oldDeferAnchor: PNode
@@ -849,6 +886,15 @@ proc transform(c: PTransf, n: PNode): PTransNode =
of nkBreakStmt: result = transformBreak(c, n)
of nkCallKinds:
result = transformCall(c, n)
var call = result.PNode
if nfDefaultRefsParam in call.flags:
# We've found a default value that references another param.
# See the notes in `hoistParamsUsedInDefault` for more details.
var hoistedParams = newNodeI(nkLetSection, call.info, 0)
for i in 1 ..< call.len:
let hoisted = hoistParamsUsedInDefault(c, call, hoistedParams, call[i])
if hoisted != nil: call[i] = hoisted
result = newTree(nkStmtListExpr, hoistedParams, call).PTransNode
of nkAddr, nkHiddenAddr:
result = transformAddrDeref(c, n, nkDerefExpr, nkHiddenDeref)
of nkDerefExpr, nkHiddenDeref:

View File

@@ -0,0 +1,114 @@
discard """
output: '''
@[1, 2, 3]@[1, 2, 3]
a
a
1
3 is an int
2 is an int
miau is a string
f1 1 1 1
f1 2 3 3
f1 10 20 30
f2 100 100 100
f2 200 300 300
f2 300 400 400
f3 10 10 20
f3 10 15 25
true true
false true
world
'''
"""
template reject(x) =
assert(not compiles(x))
block:
# https://github.com/nim-lang/Nim/issues/7756
proc foo[T](x: seq[T], y: seq[T] = x) =
echo x, y
let a = @[1, 2, 3]
foo(a)
block:
# https://github.com/nim-lang/Nim/issues/1201
proc issue1201(x: char|int = 'a') = echo x
issue1201()
issue1201('a')
issue1201(1)
# https://github.com/nim-lang/Nim/issues/7000
proc test(a: int|string = 2) =
when a is int:
echo a, " is an int"
elif a is string:
echo a, " is a string"
test(3) # works
test() # works
test("miau")
block:
# https://github.com/nim-lang/Nim/issues/3002 and similar
proc f1(a: int, b = a, c = b) =
echo "f1 ", a, " ", b, " ", c
proc f2(a: int, b = a, c: int = b) =
echo "f2 ", a, " ", b, " ", c
proc f3(a: int, b = a, c = a + b) =
echo "f3 ", a, " ", b, " ", c
f1 1
f1(2, 3)
f1 10, 20, 30
100.f2
200.f2 300
300.f2(400)
10.f3()
10.f3(15)
reject:
# This is a type mismatch error:
proc f4(a: int, b = a, c: float = b) = discard
reject:
# undeclared identifier
proc f5(a: int, b = c, c = 10) = discard
reject:
# undeclared identifier
proc f6(a: int, b = b) = discard
reject:
# undeclared identifier
proc f7(a = a) = discard
block:
proc f(a: var int, b: ptr int, c = addr(a)) =
echo addr(a) == b, " ", b == c
var x = 10
f(x, addr(x))
f(x, nil, nil)
block:
# https://github.com/nim-lang/Nim/issues/1046
proc pySubstr(s: string, start: int, endd = s.len()): string =
var
revStart = start
revEnd = endd
if start < 0:
revStart = s.len() + start
if endd < 0:
revEnd = s.len() + endd
return s[revStart .. revEnd-1]
echo pySubstr("Hello world", -5)