mirror of
https://github.com/nim-lang/Nim.git
synced 2026-01-03 03:32:32 +00:00
infer static parameters even when more complicated arithmetic is involved
This commit is contained in:
@@ -1581,7 +1581,7 @@ proc hasPattern*(s: PSym): bool {.inline.} =
|
||||
result = isRoutine(s) and s.ast.sons[patternPos].kind != nkEmpty
|
||||
|
||||
iterator items*(n: PNode): PNode =
|
||||
for i in 0.. <n.len: yield n.sons[i]
|
||||
for i in 0.. <n.safeLen: yield n.sons[i]
|
||||
|
||||
iterator pairs*(n: PNode): tuple[i: int, n: PNode] =
|
||||
for i in 0.. <n.len: yield (i, n.sons[i])
|
||||
@@ -1624,6 +1624,16 @@ proc toObject*(typ: PType): PType =
|
||||
if result.kind == tyRef:
|
||||
result = result.lastSon
|
||||
|
||||
proc findUnresolvedStatic*(n: PNode): PNode =
|
||||
if n.kind == nkSym and n.typ.kind == tyStatic and n.typ.n == nil:
|
||||
return n
|
||||
|
||||
for son in n:
|
||||
let n = son.findUnresolvedStatic
|
||||
if n != nil: return n
|
||||
|
||||
return nil
|
||||
|
||||
when false:
|
||||
proc containsNil*(n: PNode): bool =
|
||||
# only for debugging
|
||||
|
||||
@@ -109,6 +109,7 @@ type
|
||||
errXCannotBeClosure, errXMustBeCompileTime,
|
||||
errCannotInferTypeOfTheLiteral,
|
||||
errCannotInferReturnType,
|
||||
errCannotInferStaticParam,
|
||||
errGenericLambdaNotAllowed,
|
||||
errProcHasNoConcreteType,
|
||||
errCompilerDoesntSupportTarget,
|
||||
@@ -373,6 +374,7 @@ const
|
||||
errXMustBeCompileTime: "'$1' can only be used in compile-time context",
|
||||
errCannotInferTypeOfTheLiteral: "cannot infer the type of the $1",
|
||||
errCannotInferReturnType: "cannot infer the return type of the proc",
|
||||
errCannotInferStaticParam: "cannot infer the value of the static param `$1`",
|
||||
errGenericLambdaNotAllowed: "A nested proc can have generic parameters only when " &
|
||||
"it is used as an operand to another routine and the types " &
|
||||
"of the generic paramers can be inferred from the expected signature.",
|
||||
|
||||
@@ -318,23 +318,14 @@ proc makeRangeWithStaticExpr*(c: PContext, n: PNode): PType =
|
||||
let intType = getSysType(tyInt)
|
||||
result = newTypeS(tyRange, c)
|
||||
result.sons = @[intType]
|
||||
if n.typ.n == nil: result.flags.incl tfUnresolved
|
||||
if n.typ != nil and n.typ.n == nil:
|
||||
result.flags.incl tfUnresolved
|
||||
result.n = newNode(nkRange, n.info, @[
|
||||
newIntTypeNode(nkIntLit, 0, intType),
|
||||
makeStaticExpr(c, n.nMinusOne)])
|
||||
|
||||
template rangeHasUnresolvedStatic*(t: PType): bool =
|
||||
# this accepts the ranges's node
|
||||
t.n != nil and t.n.len > 1 and t.n[1].kind == nkStaticExpr
|
||||
|
||||
proc findUnresolvedStaticInRange*(t: PType): (PType, int) =
|
||||
assert t.kind == tyRange
|
||||
# XXX: This really needs to become more sophisticated
|
||||
let upperBound = t.n[1]
|
||||
if upperBound[0].kind == nkCall:
|
||||
return (upperBound[0][1].typ, 1)
|
||||
else:
|
||||
return (upperBound.typ, 0)
|
||||
tfUnresolved in t.flags
|
||||
|
||||
proc errorType*(c: PContext): PType =
|
||||
## creates a type representing an error state
|
||||
|
||||
@@ -627,6 +627,7 @@ proc evalAtCompileTime(c: PContext, n: PNode): PNode =
|
||||
|
||||
proc semStaticExpr(c: PContext, n: PNode): PNode =
|
||||
let a = semExpr(c, n.sons[0])
|
||||
if a.findUnresolvedStatic != nil: return a
|
||||
result = evalStaticExpr(c.module, c.cache, a, c.p.owner)
|
||||
if result.isNil:
|
||||
localError(n.info, errCannotInterpretNodeX, renderTree(n))
|
||||
|
||||
@@ -1621,7 +1621,7 @@ proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode =
|
||||
elif expr[2].typ.isUnresolvedStatic:
|
||||
inferConceptStaticParam(c, expr[2].typ, expr[1])
|
||||
continue
|
||||
|
||||
|
||||
let verdict = semConstExpr(c, n[i])
|
||||
if verdict.intVal == 0:
|
||||
localError(result.info, "type class predicate failed")
|
||||
|
||||
@@ -195,6 +195,7 @@ proc semRangeAux(c: PContext, n: PNode, prev: PType): PType =
|
||||
for i in 0..1:
|
||||
if hasGenericArguments(range[i]):
|
||||
result.n.addSon makeStaticExpr(c, range[i])
|
||||
result.flags.incl tfUnresolved
|
||||
else:
|
||||
result.n.addSon semConstExpr(c, range[i])
|
||||
|
||||
|
||||
@@ -445,7 +445,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType =
|
||||
elif t.sons[0].kind != tyNone:
|
||||
result = makeTypeDesc(cl.c, replaceTypeVarsT(cl, t.sons[0]))
|
||||
|
||||
of tyUserTypeClass:
|
||||
of tyUserTypeClass, tyStatic:
|
||||
result = t
|
||||
|
||||
of tyGenericInst, tyUserTypeClassInst:
|
||||
@@ -502,8 +502,9 @@ proc initTypeVars*(p: PContext, pt: TIdTable, info: TLineInfo;
|
||||
result.owner = owner
|
||||
|
||||
proc replaceTypesInBody*(p: PContext, pt: TIdTable, n: PNode;
|
||||
owner: PSym): PNode =
|
||||
owner: PSym, allowMetaTypes = false): PNode =
|
||||
var cl = initTypeVars(p, pt, n.info, owner)
|
||||
cl.allowMetaTypes = allowMetaTypes
|
||||
pushInfoContext(n.info)
|
||||
result = replaceTypeVarsN(cl, n)
|
||||
popInfoContext()
|
||||
|
||||
@@ -681,16 +681,125 @@ proc maybeSkipDistinct(t: PType, callee: PSym): PType =
|
||||
else:
|
||||
result = t
|
||||
|
||||
proc tryResolvingStaticExpr(c: var TCandidate, n: PNode): PNode =
|
||||
proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
|
||||
allowUnresolved = false): PNode =
|
||||
# Consider this example:
|
||||
# type Value[N: static[int]] = object
|
||||
# proc foo[N](a: Value[N], r: range[0..(N-1)])
|
||||
# Here, N-1 will be initially nkStaticExpr that can be evaluated only after
|
||||
# N is bound to a concrete value during the matching of the first param.
|
||||
# This proc is used to evaluate such static expressions.
|
||||
let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil)
|
||||
let instantiated = replaceTypesInBody(c.c, c.bindings, n, nil,
|
||||
allowMetaTypes = allowUnresolved)
|
||||
result = c.c.semExpr(c.c, instantiated)
|
||||
|
||||
proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
|
||||
# This is a simple integer arithimetic equation solver,
|
||||
# capable of deriving the value of a static parameter in
|
||||
# expressions such as (N + 5) / 2 = rhs
|
||||
#
|
||||
# Preconditions:
|
||||
#
|
||||
# * The input of this proc must be semantized
|
||||
# - all templates should be expanded
|
||||
# - aby constant folding possible should already be performed
|
||||
#
|
||||
# * There must be exactly one unresolved static parameter
|
||||
#
|
||||
# Result:
|
||||
#
|
||||
# The proc will return the inferred static type with the `n` field
|
||||
# populated with the inferred value.
|
||||
#
|
||||
# `nil` will be returned if the inference was not possible
|
||||
#
|
||||
if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
|
||||
case lhs[0].sym.magic
|
||||
of mUnaryLt:
|
||||
return inferStaticParam(lhs[1], rhs + 1)
|
||||
|
||||
of mAddI, mAddU, mInc, mSucc:
|
||||
if lhs[1].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
|
||||
elif lhs[2].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
|
||||
|
||||
of mDec, mSubI, mSubU, mPred:
|
||||
if lhs[1].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
|
||||
elif lhs[2].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
|
||||
|
||||
of mMulI, mMulU:
|
||||
if lhs[1].kind == nkIntLit:
|
||||
if rhs mod lhs[1].intVal == 0:
|
||||
return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
|
||||
elif lhs[2].kind == nkIntLit:
|
||||
if rhs mod lhs[2].intVal == 0:
|
||||
return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
|
||||
|
||||
of mDivI, mDivU:
|
||||
if lhs[1].kind == nkIntLit:
|
||||
if lhs[1].intVal mod rhs == 0:
|
||||
return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
|
||||
elif lhs[2].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
|
||||
|
||||
of mShlI:
|
||||
if lhs[2].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
|
||||
|
||||
of mShrI:
|
||||
if lhs[2].kind == nkIntLit:
|
||||
return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
|
||||
|
||||
of mUnaryMinusI:
|
||||
return inferStaticParam(lhs[1], -rhs)
|
||||
|
||||
of mUnaryPlusI, mToInt, mToBiggestInt:
|
||||
return inferStaticParam(lhs[1], rhs)
|
||||
|
||||
else: discard
|
||||
|
||||
elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
|
||||
lhs.typ.n = newIntNode(nkIntLit, rhs)
|
||||
return lhs.typ
|
||||
|
||||
return nil
|
||||
|
||||
proc failureToInferStaticParam(n: PNode) =
|
||||
let staticParam = n.findUnresolvedStatic
|
||||
let name = if staticParam != nil: staticParam.sym.name.s
|
||||
else: "unknown"
|
||||
localError(n.info, errCannotInferStaticParam, name)
|
||||
|
||||
proc inferStaticsInRange(c: var TCandidate,
|
||||
inferred, concrete: PType): TTypeRelation =
|
||||
let lowerBound = tryResolvingStaticExpr(c, inferred.n[0],
|
||||
allowUnresolved = true)
|
||||
let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
|
||||
allowUnresolved = true)
|
||||
|
||||
template doInferStatic(c: var TCandidate, e: PNode, r: BiggestInt) =
|
||||
var exp = e
|
||||
var rhs = r
|
||||
var inferred = inferStaticParam(exp, rhs)
|
||||
if inferred != nil:
|
||||
put(c.bindings, inferred, inferred)
|
||||
return isGeneric
|
||||
else:
|
||||
failureToInferStaticParam exp
|
||||
|
||||
if lowerBound.kind == nkIntLit:
|
||||
if upperBound.kind == nkIntLit:
|
||||
if lengthOrd(concrete) == upperBound.intVal - lowerBound.intVal + 1:
|
||||
return isGeneric
|
||||
else:
|
||||
return isNone
|
||||
doInferStatic(c, upperBound, lengthOrd(concrete) + lowerBound.intVal - 1)
|
||||
elif upperBound.kind == nkIntLit:
|
||||
doInferStatic(c, lowerBound, upperBound.intVal + 1 - lengthOrd(concrete))
|
||||
|
||||
template subtypeCheck() =
|
||||
if result <= isSubrange and f.lastSon.skipTypes(abstractInst).kind in {tyRef, tyPtr, tyVar}:
|
||||
result = isNone
|
||||
@@ -894,34 +1003,10 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation =
|
||||
a.sons[1].skipTypes({tyTypeDesc}))
|
||||
if result < isGeneric: return isNone
|
||||
|
||||
proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) =
|
||||
var (staticT, offset) = inferred.findUnresolvedStaticInRange
|
||||
var
|
||||
replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType])
|
||||
concreteUpperBound = concrete.n[1].intVal
|
||||
# we must correct for the off-by-one discrepancy between
|
||||
# ranges and static params:
|
||||
replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset)
|
||||
if tfInferrableStatic in staticT.flags:
|
||||
staticT.n = replacementT.n
|
||||
put(c.bindings, staticT, replacementT)
|
||||
|
||||
if rangeHasUnresolvedStatic(fRange):
|
||||
if tfUnresolved in fRange.flags:
|
||||
# This is a range from an array instantiated with a generic
|
||||
# static param. We must extract the static param here and bind
|
||||
# it to the size of the currently supplied array.
|
||||
inferStaticRange(c, fRange, aRange)
|
||||
return isGeneric
|
||||
|
||||
let len = tryResolvingStaticExpr(c, fRange.n[1])
|
||||
if len.kind == nkIntLit and len.intVal+1 == lengthOrd(a):
|
||||
return # if we get this far, the result is already good
|
||||
else:
|
||||
return isNone
|
||||
if fRange.rangeHasUnresolvedStatic:
|
||||
return inferStaticsInRange(c, fRange, a)
|
||||
elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic:
|
||||
inferStaticRange(c, aRange, fRange)
|
||||
return isGeneric
|
||||
return inferStaticsInRange(c, aRange, f)
|
||||
elif lengthOrd(fRange) != lengthOrd(a):
|
||||
result = isNone
|
||||
else: discard
|
||||
|
||||
@@ -63,7 +63,7 @@ const
|
||||
abstractVarRange* = {tyGenericInst, tyRange, tyVar, tyDistinct, tyOrdinal,
|
||||
tyTypeDesc, tyAlias, tyInferred}
|
||||
abstractInst* = {tyGenericInst, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias,
|
||||
tyInferred} + tyTypeClasses
|
||||
tyInferred}
|
||||
skipPtrs* = {tyVar, tyPtr, tyRef, tyGenericInst, tyTypeDesc, tyAlias,
|
||||
tyInferred}
|
||||
# typedescX is used if we're sure tyTypeDesc should be included (or skipped)
|
||||
|
||||
@@ -367,8 +367,8 @@ operator and also when types dependent on them are being matched:
|
||||
MyConcept[M, N: static[int]; T] = concept x
|
||||
x.foo(SquareMatrix[N, T]) is array[M, int]
|
||||
|
||||
Nim may include a simple linear equation solver in the future to help us
|
||||
infer static params when arithmetic is involved.
|
||||
The Nim compiler includes a simple linear equation solver, allowing it to
|
||||
infer static params in some situations where integer arithmetic is involved.
|
||||
|
||||
Just like in regular type classes, Nim discriminates between ``bind once``
|
||||
and ``bind many`` types when matching the concept. You can add the ``distinct``
|
||||
|
||||
Reference in New Issue
Block a user