From fee2a7ecfa883d100e1177b1cc7d738c6cfeaa83 Mon Sep 17 00:00:00 2001 From: Zahary Karadjov Date: Thu, 22 Aug 2013 19:22:28 +0300 Subject: [PATCH] Experimental support for delayed instantiation of generics This postpones the semantic pass over the generic's body until the generic is instantiated. There are several pros and cons for this method and the capabilities that it enables may still be possible in the old framework if we teach it a few new trick. Such an attempt will follow in the next commits. pros: 1) It allows macros to be expanded during generic instantiation that will provide the body of the generic. See ``tmacrogenerics``. 2) The instantiation code is dramatically simplified. Dealing with unknown types in the generic's body pre-pass requires a lot of hacky code and error silencing in semTypeNode. See ``tgenericshardcases``. cons: 1) There is a performance penalty of roughly 5% when bootstrapping. 2) Certain errors that used to be detected in the previous pre-pass won't be detected with the new scheme until instantiation. --- compiler/ast.nim | 7 ++-- compiler/evals.nim | 20 +++++------- compiler/options.nim | 2 ++ compiler/semfold.nim | 16 ++++++--- compiler/seminst.nim | 49 ++++++++++++++++++++++++---- compiler/semmagic.nim | 2 +- compiler/semstmts.nim | 16 +++++---- compiler/semtypes.nim | 11 ++++--- compiler/semtypinst.nim | 4 +-- compiler/sigmatch.nim | 3 +- tests/compile/tgenericshardcases.nim | 30 +++++++++++++++++ tests/run/tmacrogenerics.nim | 39 ++++++++++++++++++++++ 12 files changed, 161 insertions(+), 38 deletions(-) create mode 100644 tests/compile/tgenericshardcases.nim create mode 100644 tests/run/tmacrogenerics.nim diff --git a/compiler/ast.nim b/compiler/ast.nim index bb015ea274..bb06e7163d 100644 --- a/compiler/ast.nim +++ b/compiler/ast.nim @@ -626,6 +626,7 @@ type case kind*: TSymKind of skType: typeInstCache*: seq[PType] + typScope*: PScope of routineKinds: procInstCache*: seq[PInstantiation] scope*: PScope # the scope where the proc was defined @@ -799,9 +800,9 @@ const # imported via 'importc: "fullname"' and no format string. # creator procs: -proc NewSym*(symKind: TSymKind, Name: PIdent, owner: PSym, +proc newSym*(symKind: TSymKind, Name: PIdent, owner: PSym, info: TLineInfo): PSym -proc NewType*(kind: TTypeKind, owner: PSym): PType +proc newType*(kind: TTypeKind, owner: PSym): PType proc newNode*(kind: TNodeKind): PNode proc newIntNode*(kind: TNodeKind, intVal: BiggestInt): PNode proc newIntTypeNode*(kind: TNodeKind, intVal: BiggestInt, typ: PType): PNode @@ -1111,7 +1112,7 @@ proc copySym(s: PSym, keepId: bool = false): PSym = result.loc = s.loc result.annex = s.annex # BUGFIX -proc NewSym(symKind: TSymKind, Name: PIdent, owner: PSym, +proc newSym(symKind: TSymKind, Name: PIdent, owner: PSym, info: TLineInfo): PSym = # generates a symbol and initializes the hash field too new(result) diff --git a/compiler/evals.nim b/compiler/evals.nim index 3f09664a79..35954cceca 100644 --- a/compiler/evals.nim +++ b/compiler/evals.nim @@ -904,18 +904,16 @@ proc evalParseStmt(c: PEvalContext, n: PNode): PNode = code.info.line.int) #result.typ = newType(tyStmt, c.module) -proc evalTypeTrait*(n: PNode, context: PSym): PNode = - ## XXX: This should be pretty much guaranteed to be true - # by the type traits procs' signatures, but until the - # code is more mature it doesn't hurt to be extra safe - internalAssert n.sons.len >= 2 and n.sons[1].kind == nkSym +proc evalTypeTrait*(trait, operand: PNode, context: PSym): PNode = + InternalAssert operand.kind == nkSym and + operand.sym.typ.kind == tyTypeDesc - let typ = n.sons[1].sym.typ.skipTypes({tyTypeDesc}) - case n.sons[0].sym.name.s.normalize + let typ = operand.sym.typ.skipTypes({tyTypeDesc}) + case trait.sym.name.s.normalize of "name": - result = newStrNode(nkStrLit, typ.typeToString(preferExported)) + result = newStrNode(nkStrLit, typ.typeToString(preferName)) result.typ = newType(tyString, context) - result.info = n.info + result.info = trait.info else: internalAssert false @@ -1037,8 +1035,8 @@ proc evalMagicOrCall(c: PEvalContext, n: PNode): PNode = of mParseStmtToAst: result = evalParseStmt(c, n) of mExpandToAst: result = evalExpandToAst(c, n) of mTypeTrait: - n.sons[1] = evalAux(c, n.sons[1], {}) - result = evalTypeTrait(n, c.module) + let operand = evalAux(c, n.sons[1], {}) + result = evalTypeTrait(n[0], operand, c.module) of mIs: n.sons[1] = evalAux(c, n.sons[1], {}) result = evalIsOp(n) diff --git a/compiler/options.nim b/compiler/options.nim index 5f173d240c..cfda15f6a1 100644 --- a/compiler/options.nim +++ b/compiler/options.nim @@ -156,6 +156,8 @@ var const oKeepVariableNames* = true +const oUseLateInstantiation* = true + proc mainCommandArg*: string = ## This is intended for commands like check or parse ## which will work on the main project file unless diff --git a/compiler/semfold.nim b/compiler/semfold.nim index 77ecf2e979..ed7b146849 100644 --- a/compiler/semfold.nim +++ b/compiler/semfold.nim @@ -558,9 +558,10 @@ proc getConstExpr(m: PSym, n: PNode): PNode = case n.kind of nkSym: var s = n.sym - if s.kind == skEnumField: + case s.kind + of skEnumField: result = newIntNodeT(s.position, n) - elif s.kind == skConst: + of skConst: case s.magic of mIsMainModule: result = newIntNodeT(ord(sfMainModule in m.flags), n) of mCompileDate: result = newStrNodeT(times.getDateStr(), n) @@ -578,10 +579,17 @@ proc getConstExpr(m: PSym, n: PNode): PNode = of mNegInf: result = newFloatNodeT(NegInf, n) else: if sfFakeConst notin s.flags: result = copyTree(s.ast) - elif s.kind in {skProc, skMethod}: # BUGFIX + of {skProc, skMethod}: result = n - elif s.kind in {skType, skGenericParam}: + of skType: result = newSymNodeTypeDesc(s, n.info) + of skGenericParam: + if s.typ.kind == tyExpr: + result = s.typ.n + result.typ = s.typ.sons[0] + else: + result = newSymNodeTypeDesc(s, n.info) + else: nil of nkCharLit..nkNilLit: result = copyNode(n) of nkIfExpr: diff --git a/compiler/seminst.nim b/compiler/seminst.nim index 601072833d..0cf5086a8d 100644 --- a/compiler/seminst.nim +++ b/compiler/seminst.nim @@ -130,13 +130,50 @@ proc sideEffectsCheck(c: PContext, s: PSym) = s.ast.sons[genericParamsPos].kind == nkEmpty: c.threadEntries.add(s) +proc lateInstantiateGeneric(c: PContext, invocation: PType, info: TLineInfo): PType = + InternalAssert invocation.kind == tyGenericInvokation + + let cacheHit = searchInstTypes(invocation) + if cacheHit != nil: + result = cacheHit + else: + let s = invocation.sons[0].sym + let oldScope = c.currentScope + c.currentScope = s.typScope + openScope(c) + pushInfoContext(info) + for i in 0 .. 0): # This is either a type known to sem or a typedesc # param to a regular proc (again, known at instantiation) - result = evalTypeTrait(n, GetCurrOwner()) + result = evalTypeTrait(n[0], n[1], GetCurrOwner()) else: # a typedesc variable, pass unmodified to evals result = n diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim index 33c0adac1d..a15b3e10a6 100644 --- a/compiler/semstmts.nim +++ b/compiler/semstmts.nim @@ -729,12 +729,16 @@ proc typeSectionRightSidePass(c: PContext, n: PNode) = # like: mydata.seq rawAddSon(s.typ, newTypeS(tyEmpty, c)) s.ast = a - inc c.InGenericContext - var body = semTypeNode(c, a.sons[2], nil) - dec c.InGenericContext - if body != nil: - body.sym = s - body.size = -1 # could not be computed properly + when oUseLateInstantiation: + var body: PType = nil + s.typScope = c.currentScope.parent + else: + inc c.InGenericContext + var body = semTypeNode(c, a.sons[2], nil) + dec c.InGenericContext + if body != nil: + body.sym = s + body.size = -1 # could not be computed properly s.typ.sons[sonsLen(s.typ) - 1] = body popOwner() closeScope(c) diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim index 8ae23f851d..1f7574bb5f 100644 --- a/compiler/semtypes.nim +++ b/compiler/semtypes.nim @@ -156,7 +156,7 @@ proc semRangeAux(c: PContext, n: PNode, prev: PType): PType = LocalError(n.Info, errRangeIsEmpty) var a = semConstExpr(c, n[1]) var b = semConstExpr(c, n[2]) - if not sameType(a.typ, b.typ): + if not sameType(a.typ, b.typ): LocalError(n.info, errPureTypeMismatch) elif a.typ.kind notin {tyInt..tyInt64,tyEnum,tyBool,tyChar, tyFloat..tyFloat128,tyUInt8..tyUInt32}: @@ -204,8 +204,8 @@ proc semArray(c: PContext, n: PNode, prev: PType): PType = indx = e.typ.skipTypes({tyTypeDesc}) addSonSkipIntLit(result, indx) if indx.kind == tyGenericInst: indx = lastSon(indx) - if indx.kind != tyGenericParam: - if not isOrdinalType(indx): + if indx.kind != tyGenericParam: + if not isOrdinalType(indx): LocalError(n.sons[1].info, errOrdinalTypeExpected) elif enumHasHoles(indx): LocalError(n.sons[1].info, errEnumXHasHoles, indx.sym.name.s) @@ -846,7 +846,10 @@ proc semGeneric(c: PContext, n: PNode, s: PSym, prev: PType): PType = LocalError(n.info, errCannotInstantiateX, s.name.s) result = newOrPrevType(tyError, prev, c) else: - result = instGenericContainer(c, n, result) + when oUseLateInstantiation: + result = lateInstantiateGeneric(c, result, n.info) + else: + result = instGenericContainer(c, n, result) proc semTypeExpr(c: PContext, n: PNode): PType = var n = semExprWithType(c, n, {efDetermineType}) diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim index 31fbc33e1e..5482d5df8f 100644 --- a/compiler/semtypinst.nim +++ b/compiler/semtypinst.nim @@ -31,7 +31,7 @@ proc checkConstructedType*(info: TLineInfo, typ: PType) = if t.sons[0].kind != tyObject or tfFinal in t.sons[0].flags: localError(info, errInheritanceOnlyWithNonFinalObjects) -proc searchInstTypes(key: PType): PType = +proc searchInstTypes*(key: PType): PType = let genericTyp = key.sons[0] InternalAssert genericTyp.kind == tyGenericBody and key.sons[0] == genericTyp and @@ -55,7 +55,7 @@ proc searchInstTypes(key: PType): PType = return inst -proc cacheTypeInst(inst: PType) = +proc cacheTypeInst*(inst: PType) = # XXX: add to module's generics # update the refcount let genericTyp = inst.sons[0] diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 626d16d64a..fa87f27a7b 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -654,7 +654,7 @@ proc typeRel(c: var TCandidate, f, a: PType): TTypeRelation = result = typeRel(c, x, a) # check if it fits of tyTypeDesc: var prev = PType(idTableGet(c.bindings, f)) - if prev == nil: + if prev == nil or true: if a.kind == tyTypeDesc: if f.sonsLen == 0: result = isGeneric @@ -768,6 +768,7 @@ proc ParamTypesMatchAux(c: PContext, m: var TCandidate, f, a: PType, if evaluated != nil: r = isGeneric arg.typ = newTypeS(tyExpr, c) + arg.typ.sons = @[evaluated.typ] arg.typ.n = evaluated if r == isGeneric: diff --git a/tests/compile/tgenericshardcases.nim b/tests/compile/tgenericshardcases.nim new file mode 100644 index 0000000000..90981c7014 --- /dev/null +++ b/tests/compile/tgenericshardcases.nim @@ -0,0 +1,30 @@ +discard """ + file: "tgenericshardcases.nim" + output: "int\nfloat\nint\nstring" +""" + +import typetraits + +proc typeNameLen(x: typedesc): int {.compileTime.} = + result = x.name.len + +macro selectType(a, b: typedesc): typedesc = + result = a + +type + Foo[T] = object + data1: array[high(T), int] + data2: array[1..typeNameLen(T), selectType(float, string)] + + MyEnum = enum A, B, C,D + +var f1: Foo[MyEnum] +var f2: Foo[int8] + +static: + assert high(f1.data1) == D + assert high(f1.data2) == 6 # length of MyEnum + + assert high(f2.data1) == 127 + assert high(f2.data2) == 4 # length of int8 + diff --git a/tests/run/tmacrogenerics.nim b/tests/run/tmacrogenerics.nim new file mode 100644 index 0000000000..5ae59e0da4 --- /dev/null +++ b/tests/run/tmacrogenerics.nim @@ -0,0 +1,39 @@ +discard """ + file: "tmacrogenerics.nim" + msg: ''' +instantiation 1 with int and float +instantiation 2 with float and string +instantiation 3 with string and string +counter: 3 +''' + output: "int\nfloat\nint\nstring" +""" + +import typetraits, macros + +var counter {.compileTime.} = 0 + +macro makeBar(A, B: typedesc): typedesc = + inc counter + echo "instantiation ", counter, " with ", A.name, " and ", B.name + result = A + +type + Bar[T, U] = makeBar(T, U) + +var bb1: Bar[int, float] +var bb2: Bar[float, string] +var bb3: Bar[int, float] +var bb4: Bar[string, string] + +proc match(a: int) = echo "int" +proc match(a: string) = echo "string" +proc match(a: float) = echo "float" + +match(bb1) +match(bb2) +match(bb3) +match(bb4) + +static: + echo "counter: ", counter