From 0f2c4be1299fc99aeea2011c57240c8cfabd83c3 Mon Sep 17 00:00:00 2001 From: Zahary Karadjov Date: Fri, 12 Aug 2016 03:25:59 +0300 Subject: [PATCH] infer static parameters even when more complicated arithmetic is involved --- compiler/ast.nim | 12 +++- compiler/msgs.nim | 2 + compiler/semdata.nim | 15 +---- compiler/semexprs.nim | 1 + compiler/semstmts.nim | 2 +- compiler/semtypes.nim | 1 + compiler/semtypinst.nim | 5 +- compiler/sigmatch.nim | 143 ++++++++++++++++++++++++++++++++-------- compiler/types.nim | 2 +- doc/manual/generics.txt | 4 +- 10 files changed, 139 insertions(+), 48 deletions(-) diff --git a/compiler/ast.nim b/compiler/ast.nim index f13691d540..5adac92dfe 100644 --- a/compiler/ast.nim +++ b/compiler/ast.nim @@ -1581,7 +1581,7 @@ proc hasPattern*(s: PSym): bool {.inline.} = result = isRoutine(s) and s.ast.sons[patternPos].kind != nkEmpty iterator items*(n: PNode): PNode = - for i in 0.. 1 and t.n[1].kind == nkStaticExpr - -proc findUnresolvedStaticInRange*(t: PType): (PType, int) = - assert t.kind == tyRange - # XXX: This really needs to become more sophisticated - let upperBound = t.n[1] - if upperBound[0].kind == nkCall: - return (upperBound[0][1].typ, 1) - else: - return (upperBound.typ, 0) + tfUnresolved in t.flags proc errorType*(c: PContext): PType = ## creates a type representing an error state diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim index 3ec2cd3912..930c843688 100644 --- a/compiler/semexprs.nim +++ b/compiler/semexprs.nim @@ -627,6 +627,7 @@ proc evalAtCompileTime(c: PContext, n: PNode): PNode = proc semStaticExpr(c: PContext, n: PNode): PNode = let a = semExpr(c, n.sons[0]) + if a.findUnresolvedStatic != nil: return a result = evalStaticExpr(c.module, c.cache, a, c.p.owner) if result.isNil: localError(n.info, errCannotInterpretNodeX, renderTree(n)) diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim index 64449dda4b..09631c7936 100644 --- a/compiler/semstmts.nim +++ b/compiler/semstmts.nim @@ -1621,7 +1621,7 @@ proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode = elif expr[2].typ.isUnresolvedStatic: inferConceptStaticParam(c, expr[2].typ, expr[1]) continue - + let verdict = semConstExpr(c, n[i]) if verdict.intVal == 0: localError(result.info, "type class predicate failed") diff --git a/compiler/semtypes.nim b/compiler/semtypes.nim index bf6c243108..eef83c2a72 100644 --- a/compiler/semtypes.nim +++ b/compiler/semtypes.nim @@ -195,6 +195,7 @@ proc semRangeAux(c: PContext, n: PNode, prev: PType): PType = for i in 0..1: if hasGenericArguments(range[i]): result.n.addSon makeStaticExpr(c, range[i]) + result.flags.incl tfUnresolved else: result.n.addSon semConstExpr(c, range[i]) diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim index 7e114afb8a..9e72e46f6a 100644 --- a/compiler/semtypinst.nim +++ b/compiler/semtypinst.nim @@ -445,7 +445,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType = elif t.sons[0].kind != tyNone: result = makeTypeDesc(cl.c, replaceTypeVarsT(cl, t.sons[0])) - of tyUserTypeClass: + of tyUserTypeClass, tyStatic: result = t of tyGenericInst, tyUserTypeClassInst: @@ -502,8 +502,9 @@ proc initTypeVars*(p: PContext, pt: TIdTable, info: TLineInfo; result.owner = owner proc replaceTypesInBody*(p: PContext, pt: TIdTable, n: PNode; - owner: PSym): PNode = + owner: PSym, allowMetaTypes = false): PNode = var cl = initTypeVars(p, pt, n.info, owner) + cl.allowMetaTypes = allowMetaTypes pushInfoContext(n.info) result = replaceTypeVarsN(cl, n) popInfoContext() diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 162385e6d8..ca9cdcaf8a 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType = else: result = t -proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode = +proc tryResolvingStaticExpr(c: var TCandidate, n: PNode, + allowUnresolved = false): PNode = # Consider this example: # type Value[N: static[int]] = object # proc foo[N](a: Value[N], r: range[0..(N-1)]) # Here, N-1 will be initially nkStaticExpr that can be evaluated only after # N is bound to a concrete value during the matching of the first param. # This proc is used to evaluate such static expressions. - let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil) + let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil, + allowMetaTypes = allowUnresolved) result = c.c.semExpr(c.c, instantiated) +proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType = + # This is a simple integer arithimetic equation solver, + # capable of deriving the value of a static parameter in + # expressions such as (N + 5) / 2 = rhs + # + # Preconditions: + # + # * The input of this proc must be semantized + # - all templates should be expanded + # - aby constant folding possible should already be performed + # + # * There must be exactly one unresolved static parameter + # + # Result: + # + # The proc will return the inferred static type with the `n` field + # populated with the inferred value. + # + # `nil` will be returned if the inference was not possible + # + if lhs.kind in nkCallKinds and lhs[0].kind == nkSym: + case lhs[0].sym.magic + of mUnaryLt: + return inferStaticParam(lhs[1], rhs + 1) + + of mAddI, mAddU, mInc, mSucc: + if lhs[1].kind == nkIntLit: + return inferStaticParam(lhs[2], rhs - lhs[1].intVal) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs - lhs[2].intVal) + + of mDec, mSubI, mSubU, mPred: + if lhs[1].kind == nkIntLit: + return inferStaticParam(lhs[2], lhs[1].intVal - rhs) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs + lhs[2].intVal) + + of mMulI, mMulU: + if lhs[1].kind == nkIntLit: + if rhs mod lhs[1].intVal == 0: + return inferStaticParam(lhs[2], rhs div lhs[1].intVal) + elif lhs[2].kind == nkIntLit: + if rhs mod lhs[2].intVal == 0: + return inferStaticParam(lhs[1], rhs div lhs[2].intVal) + + of mDivI, mDivU: + if lhs[1].kind == nkIntLit: + if lhs[1].intVal mod rhs == 0: + return inferStaticParam(lhs[2], lhs[1].intVal div rhs) + elif lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], lhs[2].intVal * rhs) + + of mShlI: + if lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs shr lhs[2].intVal) + + of mShrI: + if lhs[2].kind == nkIntLit: + return inferStaticParam(lhs[1], rhs shl lhs[2].intVal) + + of mUnaryMinusI: + return inferStaticParam(lhs[1], -rhs) + + of mUnaryPlusI, mToInt, mToBiggestInt: + return inferStaticParam(lhs[1], rhs) + + else: discard + + elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil: + lhs.typ.n = newIntNode(nkIntLit, rhs) + return lhs.typ + + return nil + +proc failureToInferStaticParam(n: PNode) = + let staticParam = n.findUnresolvedStatic + let name = if staticParam != nil: staticParam.sym.name.s + else: "unknown" + localError(n.info, errCannotInferStaticParam, name) + +proc inferStaticsInRange(c: var TCandidate, + inferred, concrete: PType): TTypeRelation = + let lowerBound = tryResolvingStaticExpr(c, inferred.n[0], + allowUnresolved = true) + let upperBound = tryResolvingStaticExpr(c, inferred.n[1], + allowUnresolved = true) + + template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) = + var exp = e + var rhs = r + var inferred = inferStaticParam(exp, rhs) + if inferred != nil: + put(c.bindings, inferred, inferred) + return isGeneric + else: + failureToInferStaticParam exp + + if lowerBound.kind == nkIntLit: + if upperBound.kind == nkIntLit: + if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1: + return isGeneric + else: + return isNone + doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1) + elif upperBound.kind == nkIntLit: + doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete)) + template subtypeCheck() = if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}: result = isNone @@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = a.sons[1].skipTypes({tyTypeDesc})) if result < isGeneric: return isNone - proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) = - var (staticT, offset) = inferred.findUnresolvedStaticInRange - var - replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType]) - concreteUpperBound = concrete.n[1].intVal - # we must correct for the off-by-one discrepancy between - # ranges and static params: - replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset) - if tfInferrableStatic in staticT.flags: - staticT.n = replacementT.n - put(c.bindings, staticT, replacementT) - - if rangeHasUnresolvedStatic(fRange): - if tfUnresolved in fRange.flags: - # This is a range from an array instantiated with a generic - # static param. We must extract the static param here and bind - # it to the size of the currently supplied array. - inferStaticRange(c, fRange, aRange) - return isGeneric - - let len = tryResolvingStaticExpr(c, fRange.n[1]) - if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a): - return # if we get this far, the result is already good - else: - return isNone + if fRange.rangeHasUnresolvedStatic: + return inferStaticsInRange(c, fRange, a) elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic: - inferStaticRange(c, aRange, fRange) - return isGeneric + return inferStaticsInRange(c, aRange, f) elif lengthOrd(fRange) != lengthOrd(a): result = isNone else: discard diff --git a/compiler/types.nim b/compiler/types.nim index be7028f9c0..65eb6de61a 100644 --- a/compiler/types.nim +++ b/compiler/types.nim @@ -63,7 +63,7 @@ const abstractVarRange* = {tyGenericInst, tyRange, tyVar, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias, tyInferred} abstractInst* = {tyGenericInst, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias, - tyInferred} + tyTypeClasses + tyInferred} skipPtrs* = {tyVar, tyPtr, tyRef, tyGenericInst, tyTypeDesc, tyAlias, tyInferred} # typedescX is used if we're sure tyTypeDesc should be included (or skipped) diff --git a/doc/manual/generics.txt b/doc/manual/generics.txt index 962daa9c4b..f4afa0d11b 100644 --- a/doc/manual/generics.txt +++ b/doc/manual/generics.txt @@ -367,8 +367,8 @@ operator and also when types dependent on them are being matched: MyConcept[M, N: static[int]; T] = concept x x.foo(SquareMatrix[N, T]) is array[M, int] -Nim may include a simple linear equation solver in the future to help us -infer static params when arithmetic is involved. +The Nim compiler includes a simple linear equation solver, allowing it to +infer static params in some situations where integer arithmetic is involved. Just like in regular type classes, Nim discriminates between ``bind once`` and ``bind many`` types when matching the concept. You can add the ``distinct``