infer static parameters even when more complicated arithmetic is involved

This commit is contained in:
Zahary Karadjov
2016-08-12 03:25:59 +03:00
parent 0b0a3e5f20
commit 0f2c4be129
10 changed files with 139 additions and 48 deletions

View File

@@ -1581,7 +1581,7 @@ proc hasPattern*(s: PSym): bool {.inline.} =
result = isRoutine(s) and s.ast.sons[patternPos].kind != nkEmpty
iterator items*(n: PNode): PNode =
for i in 0.. <n.len: yield n.sons[i]
for i in 0.. <n.safeLen: yield n.sons[i]
iterator pairs*(n: PNode): tuple[i: int, n: PNode] =
for i in 0.. <n.len: yield (i, n.sons[i])
@@ -1624,6 +1624,16 @@ proc toObject*(typ: PType): PType =
if result.kind == tyRef:
result = result.lastSon
proc findUnresolvedStatic*(n: PNode): PNode =
if n.kind == nkSym and n.typ.kind == tyStatic and n.typ.n == nil:
return n
for son in n:
let n = son.findUnresolvedStatic
if n != nil: return n
return nil
when false:
proc containsNil*(n: PNode): bool =
# only for debugging

View File

@@ -109,6 +109,7 @@ type
errXCannotBeClosure, errXMustBeCompileTime,
errCannotInferTypeOfTheLiteral,
errCannotInferReturnType,
errCannotInferStaticParam,
errGenericLambdaNotAllowed,
errProcHasNoConcreteType,
errCompilerDoesntSupportTarget,
@@ -373,6 +374,7 @@ const
errXMustBeCompileTime: "'$1' can only be used in compile-time context",
errCannotInferTypeOfTheLiteral: "cannot infer the type of the $1",
errCannotInferReturnType: "cannot infer the return type of the proc",
errCannotInferStaticParam: "cannot infer the value of the static param `$1`",
errGenericLambdaNotAllowed: "A nested proc can have generic parameters only when " &
"it is used as an operand to another routine and the types " &
"of the generic paramers can be inferred from the expected signature.",

View File

@@ -318,23 +318,14 @@ proc makeRangeWithStaticExpr*(c: PContext, n: PNode): PType =
let intType = getSysType(tyInt)
result = newTypeS(tyRange, c)
result.sons = @[intType]
if n.typ.n == nil: result.flags.incl tfUnresolved
if n.typ != nil and n.typ.n == nil:
result.flags.incl tfUnresolved
result.n = newNode(nkRange, n.info, @[
newIntTypeNode(nkIntLit, 0, intType),
makeStaticExpr(c, n.nMinusOne)])
template rangeHasUnresolvedStatic*(t: PType): bool =
# this accepts the ranges's node
t.n != nil and t.n.len > 1 and t.n[1].kind == nkStaticExpr
proc findUnresolvedStaticInRange*(t: PType): (PType, int) =
assert t.kind == tyRange
# XXX: This really needs to become more sophisticated
let upperBound = t.n[1]
if upperBound[0].kind == nkCall:
return (upperBound[0][1].typ, 1)
else:
return (upperBound.typ, 0)
tfUnresolved in t.flags
proc errorType*(c: PContext): PType =
## creates a type representing an error state

View File

@@ -627,6 +627,7 @@ proc evalAtCompileTime(c: PContext, n: PNode): PNode =
proc semStaticExpr(c: PContext, n: PNode): PNode =
let a = semExpr(c, n.sons[0])
if a.findUnresolvedStatic != nil: return a
result = evalStaticExpr(c.module, c.cache, a, c.p.owner)
if result.isNil:
localError(n.info, errCannotInterpretNodeX, renderTree(n))

View File

@@ -1621,7 +1621,7 @@ proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode =
elif expr[2].typ.isUnresolvedStatic:
inferConceptStaticParam(c, expr[2].typ, expr[1])
continue
let verdict = semConstExpr(c, n[i])
if verdict.intVal == 0:
localError(result.info, "type class predicate failed")

View File

@@ -195,6 +195,7 @@ proc semRangeAux(c: PContext, n: PNode, prev: PType): PType =
for i in 0..1:
if hasGenericArguments(range[i]):
result.n.addSon makeStaticExpr(c, range[i])
result.flags.incl tfUnresolved
else:
result.n.addSon semConstExpr(c, range[i])

View File

@@ -445,7 +445,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType =
elif t.sons[0].kind != tyNone:
result = makeTypeDesc(cl.c, replaceTypeVarsT(cl, t.sons[0]))
of tyUserTypeClass:
of tyUserTypeClass, tyStatic:
result = t
of tyGenericInst, tyUserTypeClassInst:
@@ -502,8 +502,9 @@ proc initTypeVars*(p: PContext, pt: TIdTable, info: TLineInfo;
result.owner = owner
proc replaceTypesInBody*(p: PContext, pt: TIdTable, n: PNode;
owner: PSym): PNode =
owner: PSym, allowMetaTypes = false): PNode =
var cl = initTypeVars(p, pt, n.info, owner)
cl.allowMetaTypes = allowMetaTypes
pushInfoContext(n.info)
result = replaceTypeVarsN(cl, n)
popInfoContext()

View File

@@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType =
else:
result = t
proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode =
proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
allowUnresolved = false): PNode =
# Consider this example:
# type Value[N: static[int]] = object
# proc foo[N](a: Value[N], r: range[0..(N-1)])
# Here, N-1 will be initially nkStaticExpr that can be evaluated only after
# N is bound to a concrete value during the matching of the first param.
# This proc is used to evaluate such static expressions.
let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil)
let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil,
allowMetaTypes = allowUnresolved)
result = c.c.semExpr(c.c, instantiated)
proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
# This is a simple integer arithimetic equation solver,
# capable of deriving the value of a static parameter in
# expressions such as (N + 5) / 2 = rhs
#
# Preconditions:
#
# * The input of this proc must be semantized
# - all templates should be expanded
# - aby constant folding possible should already be performed
#
# * There must be exactly one unresolved static parameter
#
# Result:
#
# The proc will return the inferred static type with the `n` field
# populated with the inferred value.
#
# `nil` will be returned if the inference was not possible
#
if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
case lhs[0].sym.magic
of mUnaryLt:
return inferStaticParam(lhs[1], rhs + 1)
of mAddI, mAddU, mInc, mSucc:
if lhs[1].kind == nkIntLit:
return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
of mDec, mSubI, mSubU, mPred:
if lhs[1].kind == nkIntLit:
return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
of mMulI, mMulU:
if lhs[1].kind == nkIntLit:
if rhs mod lhs[1].intVal == 0:
return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
elif lhs[2].kind == nkIntLit:
if rhs mod lhs[2].intVal == 0:
return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
of mDivI, mDivU:
if lhs[1].kind == nkIntLit:
if lhs[1].intVal mod rhs == 0:
return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
of mShlI:
if lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
of mShrI:
if lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
of mUnaryMinusI:
return inferStaticParam(lhs[1], -rhs)
of mUnaryPlusI, mToInt, mToBiggestInt:
return inferStaticParam(lhs[1], rhs)
else: discard
elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
lhs.typ.n = newIntNode(nkIntLit, rhs)
return lhs.typ
return nil
proc failureToInferStaticParam(n: PNode) =
let staticParam = n.findUnresolvedStatic
let name = if staticParam != nil: staticParam.sym.name.s
else: "unknown"
localError(n.info, errCannotInferStaticParam, name)
proc inferStaticsInRange(c: var TCandidate,
inferred, concrete: PType): TTypeRelation =
let lowerBound = tryResolvingStaticExpr(c, inferred.n[0],
allowUnresolved = true)
let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
allowUnresolved = true)
template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) =
var exp = e
var rhs = r
var inferred = inferStaticParam(exp, rhs)
if inferred != nil:
put(c.bindings, inferred, inferred)
return isGeneric
else:
failureToInferStaticParam exp
if lowerBound.kind == nkIntLit:
if upperBound.kind == nkIntLit:
if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1:
return isGeneric
else:
return isNone
doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1)
elif upperBound.kind == nkIntLit:
doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete))
template subtypeCheck() =
if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}:
result = isNone
@@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation =
a.sons[1].skipTypes({tyTypeDesc}))
if result < isGeneric: return isNone
proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) =
var (staticT, offset) = inferred.findUnresolvedStaticInRange
var
replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType])
concreteUpperBound = concrete.n[1].intVal
# we must correct for the off-by-one discrepancy between
# ranges and static params:
replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset)
if tfInferrableStatic in staticT.flags:
staticT.n = replacementT.n
put(c.bindings, staticT, replacementT)
if rangeHasUnresolvedStatic(fRange):
if tfUnresolved in fRange.flags:
# This is a range from an array instantiated with a generic
# static param. We must extract the static param here and bind
# it to the size of the currently supplied array.
inferStaticRange(c, fRange, aRange)
return isGeneric
let len = tryResolvingStaticExpr(c, fRange.n[1])
if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a):
return # if we get this far, the result is already good
else:
return isNone
if fRange.rangeHasUnresolvedStatic:
return inferStaticsInRange(c, fRange, a)
elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic:
inferStaticRange(c, aRange, fRange)
return isGeneric
return inferStaticsInRange(c, aRange, f)
elif lengthOrd(fRange) != lengthOrd(a):
result = isNone
else: discard

View File

@@ -63,7 +63,7 @@ const
abstractVarRange* = {tyGenericInst, tyRange, tyVar, tyDistinct, tyOrdinal,
tyTypeDesc, tyAlias, tyInferred}
abstractInst* = {tyGenericInst, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias,
tyInferred} + tyTypeClasses
tyInferred}
skipPtrs* = {tyVar, tyPtr, tyRef, tyGenericInst, tyTypeDesc, tyAlias,
tyInferred}
# typedescX is used if we're sure tyTypeDesc should be included (or skipped)

View File

@@ -367,8 +367,8 @@ operator and also when types dependent on them are being matched:
MyConcept[M, N: static[int]; T] = concept x
x.foo(SquareMatrix[N, T]) is array[M, int]
Nim may include a simple linear equation solver in the future to help us
infer static params when arithmetic is involved.
The Nim compiler includes a simple linear equation solver, allowing it to
infer static params in some situations where integer arithmetic is involved.
Just like in regular type classes, Nim discriminates between ``bind once``
and ``bind many`` types when matching the concept. You can add the ``distinct``