fix a compilation error in linalg

This commit is contained in:
Zahary Karadjov
2017-04-16 02:44:58 +03:00
parent bf4ce87e5b
commit dfbafff2e7
2 changed files with 39 additions and 25 deletions

View File

@@ -739,7 +739,7 @@ proc tryResolvingStaticExpr(c: var TCandidate, n: PNode,
allowMetaTypes = allowUnresolved)
result = c.c.semExpr(c.c, instantiated)
proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
proc inferStaticParam*(c: var TCandidate, lhs: PNode, rhs: BiggestInt): bool =
# This is a simple integer arithimetic equation solver,
# capable of deriving the value of a static parameter in
# expressions such as (N + 5) / 2 = rhs
@@ -754,64 +754,65 @@ proc inferStaticParam*(lhs: PNode, rhs: BiggestInt): PType =
#
# Result:
#
# The proc will return the inferred static type with the `n` field
# populated with the inferred value.
#
# `nil` will be returned if the inference was not possible
# The proc will return true if the static types was successfully
# inferred. The result will be bound to the original static type
# in the TCandidate.
#
if lhs.kind in nkCallKinds and lhs[0].kind == nkSym:
case lhs[0].sym.magic
of mUnaryLt:
return inferStaticParam(lhs[1], rhs + 1)
return inferStaticParam(c, lhs[1], rhs + 1)
of mAddI, mAddU, mInc, mSucc:
if lhs[1].kind == nkIntLit:
return inferStaticParam(lhs[2], rhs - lhs[1].intVal)
return inferStaticParam(c, lhs[2], rhs - lhs[1].intVal)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs - lhs[2].intVal)
return inferStaticParam(c, lhs[1], rhs - lhs[2].intVal)
of mDec, mSubI, mSubU, mPred:
if lhs[1].kind == nkIntLit:
return inferStaticParam(lhs[2], lhs[1].intVal - rhs)
return inferStaticParam(c, lhs[2], lhs[1].intVal - rhs)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs + lhs[2].intVal)
return inferStaticParam(c, lhs[1], rhs + lhs[2].intVal)
of mMulI, mMulU:
if lhs[1].kind == nkIntLit:
if rhs mod lhs[1].intVal == 0:
return inferStaticParam(lhs[2], rhs div lhs[1].intVal)
return inferStaticParam(c, lhs[2], rhs div lhs[1].intVal)
elif lhs[2].kind == nkIntLit:
if rhs mod lhs[2].intVal == 0:
return inferStaticParam(lhs[1], rhs div lhs[2].intVal)
return inferStaticParam(c, lhs[1], rhs div lhs[2].intVal)
of mDivI, mDivU:
if lhs[1].kind == nkIntLit:
if lhs[1].intVal mod rhs == 0:
return inferStaticParam(lhs[2], lhs[1].intVal div rhs)
return inferStaticParam(c, lhs[2], lhs[1].intVal div rhs)
elif lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], lhs[2].intVal * rhs)
return inferStaticParam(c, lhs[1], lhs[2].intVal * rhs)
of mShlI:
if lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs shr lhs[2].intVal)
return inferStaticParam(c, lhs[1], rhs shr lhs[2].intVal)
of mShrI:
if lhs[2].kind == nkIntLit:
return inferStaticParam(lhs[1], rhs shl lhs[2].intVal)
return inferStaticParam(c, lhs[1], rhs shl lhs[2].intVal)
of mUnaryMinusI:
return inferStaticParam(lhs[1], -rhs)
return inferStaticParam(c, lhs[1], -rhs)
of mUnaryPlusI, mToInt, mToBiggestInt:
return inferStaticParam(lhs[1], rhs)
return inferStaticParam(c, lhs[1], rhs)
else: discard
elif lhs.kind == nkSym and lhs.typ.kind == tyStatic and lhs.typ.n == nil:
lhs.typ.n = newIntNode(nkIntLit, rhs)
return lhs.typ
var inferred = newTypeWithSons(c.c, tyStatic, lhs.typ.sons)
inferred.n = newIntNode(nkIntLit, rhs)
put(c, lhs.typ, inferred)
return true
return nil
return false
proc failureToInferStaticParam(n: PNode) =
let staticParam = n.findUnresolvedStatic
@@ -825,13 +826,10 @@ proc inferStaticsInRange(c: var TCandidate,
allowUnresolved = true)
let upperBound = tryResolvingStaticExpr(c, inferred.n[1],
allowUnresolved = true)
template doInferStatic(e: PNode, r: BiggestInt) =
var exp = e
var rhs = r
var inferred = inferStaticParam(exp, rhs)
if inferred != nil:
put(c, inferred, inferred)
if inferStaticParam(c, exp, rhs):
return isGeneric
else:
failureToInferStaticParam exp

View File

@@ -13,3 +13,19 @@ const
]
echo "perm: ", a.perm, " det: ", a.det
# This tests multiple instantiations of a generic
# proc involving static params:
type
Vector64*[N: static[int]] = ref array[N, float64]
Array64[N: static[int]] = array[N, float64]
proc vector*[N: static[int]](xs: Array64[N]): Vector64[N] =
new result
for i in 0 .. < N:
result[i] = xs[i]
let v1 = vector([1.0, 2.0, 3.0, 4.0, 5.0])
let v2 = vector([1.0, 2.0, 3.0, 4.0, 5.0])
let v3 = vector([1.0, 2.0, 3.0, 4.0])