diff --git a/compiler/ast.nim b/compiler/ast.nim index f7681222c6..bd244fb976 100644 --- a/compiler/ast.nim +++ b/compiler/ast.nim @@ -490,6 +490,7 @@ type tfHasStatic tfGenericTypeParam tfImplicitTypeParam + tfInferrableStatic tfWildcard # consider a proc like foo[T, I](x: Type[T, I]) # T and I here can bind to both typedesc and static types # before this is determined, we'll consider them to be a @@ -1073,6 +1074,9 @@ proc isMetaType*(t: PType): bool = (t.kind == tyStatic and t.n == nil) or tfHasMeta in t.flags +proc isUnresolvedStatic*(t: PType): bool = + return t.kind == tyStatic and t.n == nil + proc linkTo*(t: PType, s: PSym): PType {.discardable.} = t.sym = s s.typ = t diff --git a/compiler/ccgtypes.nim b/compiler/ccgtypes.nim index 0bbb6e4145..d62eab8ac7 100644 --- a/compiler/ccgtypes.nim +++ b/compiler/ccgtypes.nim @@ -164,6 +164,9 @@ proc mapType(typ: PType): TCTypeKind = of tySet: result = mapSetType(typ) of tyOpenArray, tyArray, tyVarargs: result = ctArray of tyObject, tyTuple: result = ctStruct + of tyUserTypeClass, tyUserTypeClassInst: + internalAssert typ.isResolvedUserTypeClass + return mapType(typ.lastSon) of tyGenericBody, tyGenericInst, tyGenericParam, tyDistinct, tyOrdinal, tyTypeDesc, tyAlias, tyInferred: result = mapType(lastSon(typ)) diff --git a/compiler/semdata.nim b/compiler/semdata.nim index dcd1e04b4a..5cd7556075 100644 --- a/compiler/semdata.nim +++ b/compiler/semdata.nim @@ -273,7 +273,8 @@ proc newTypeWithSons*(c: PContext, kind: TTypeKind, proc makeStaticExpr*(c: PContext, n: PNode): PNode = result = newNodeI(nkStaticExpr, n.info) result.sons = @[n] - result.typ = newTypeWithSons(c, tyStatic, @[n.typ]) + result.typ = if n.typ != nil and n.typ.kind == tyStatic: n.typ + else: newTypeWithSons(c, tyStatic, @[n.typ]) proc makeAndType*(c: PContext, t1, t2: PType): PType = result = newTypeS(tyAnd, c) @@ -317,16 +318,23 @@ proc makeRangeWithStaticExpr*(c: PContext, n: PNode): PType = let intType = getSysType(tyInt) result = newTypeS(tyRange, c) result.sons = @[intType] + if n.typ.n == nil: result.flags.incl tfUnresolved result.n = newNode(nkRange, n.info, @[ newIntTypeNode(nkIntLit, 0, intType), makeStaticExpr(c, n.nMinusOne)]) -template rangeHasStaticIf*(t: PType): bool = +template rangeHasUnresolvedStatic*(t: PType): bool = # this accepts the ranges's node t.n != nil and t.n.len > 1 and t.n[1].kind == nkStaticExpr -template getStaticTypeFromRange*(t: PType): PType = - t.n[1][0][1].typ +proc findUnresolvedStaticInRange*(t: PType): (PType, int) = + assert t.kind == tyRange + # XXX: This really needs to become more sophisticated + let upperBound = t.n[1] + if upperBound[0].kind == nkCall: + return (upperBound[0][1].typ, 1) + else: + return (upperBound.typ, 0) proc errorType*(c: PContext): PType = ## creates a type representing an error state diff --git a/compiler/semstmts.nim b/compiler/semstmts.nim index 33b684e912..64449dda4b 100644 --- a/compiler/semstmts.nim +++ b/compiler/semstmts.nim @@ -1553,6 +1553,11 @@ proc usesResult(n: PNode): bool = for c in n: if usesResult(c): return true +proc inferConceptStaticParam(c: PContext, typ: PType, n: PNode) = + let res = semConstExpr(c, n) + if not sameType(res.typ, typ.base): localError(n.info, "") + typ.n = res + proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode = # these must be last statements in a block: const @@ -1604,10 +1609,19 @@ proc semStmtList(c: PContext, n: PNode, flags: TExprFlags): PNode = n.typ = n.sons[i].typ return else: - n.sons[i] = semExpr(c, n.sons[i]) - if c.inTypeClass > 0 and n[i].typ != nil: - case n[i].typ.kind + var expr = semExpr(c, n.sons[i]) + n.sons[i] = expr + if c.inTypeClass > 0 and expr.typ != nil: + case expr.typ.kind of tyBool: + if expr.kind == nkInfix and expr[0].sym.name.s == "==": + if expr[1].typ.isUnresolvedStatic: + inferConceptStaticParam(c, expr[1].typ, expr[2]) + continue + elif expr[2].typ.isUnresolvedStatic: + inferConceptStaticParam(c, expr[2].typ, expr[1]) + continue + let verdict = semConstExpr(c, n[i]) if verdict.intVal == 0: localError(result.info, "type class predicate failed") diff --git a/compiler/semtypinst.nim b/compiler/semtypinst.nim index 9a1ace42e2..7e114afb8a 100644 --- a/compiler/semtypinst.nim +++ b/compiler/semtypinst.nim @@ -122,6 +122,7 @@ proc isTypeParam(n: PNode): bool = proc hasGenericArguments*(n: PNode): bool = if n.kind == nkSym: return n.sym.kind == skGenericParam or + tfInferrableStatic in n.sym.typ.flags or (n.sym.kind == skType and n.sym.typ.flags * {tfGenericTypeParam, tfImplicitTypeParam} != {}) else: diff --git a/compiler/sigmatch.nim b/compiler/sigmatch.nim index 0dee70139d..f33ac76e78 100644 --- a/compiler/sigmatch.nim +++ b/compiler/sigmatch.nim @@ -605,7 +605,10 @@ proc matchUserTypeClass*(c: PContext, m: var TCandidate, of tyStatic: param = paramSym skConst param.typ = typ.exactReplica - param.ast = typ.n + if typ.n == nil: + param.typ.flags.incl tfInferrableStatic + else: + param.ast = typ.n of tyUnknown: param = paramSym skVar param.typ = typ.exactReplica @@ -619,8 +622,6 @@ proc matchUserTypeClass*(c: PContext, m: var TCandidate, addDecl(c, param) typeParams.safeAdd((param, typ)) - #echo "A ", param.name.s, " ", typeToString(param.typ), " ", param.kind - for param in body.n[0]: var dummyName: PNode @@ -647,8 +648,6 @@ proc matchUserTypeClass*(c: PContext, m: var TCandidate, dummyParam.typ = dummyType addDecl(c, dummyParam) - #echo "B ", dummyName.ident.s, " ", typeToString(dummyType), " ", dummyparam.kind - var checkedBody = c.semTryExpr(c, body.n[3].copyTree) if checkedBody == nil: return nil @@ -733,9 +732,6 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = aOrig if aOrig.kind == tyInferred: - # echo "INFER A" - # debug f - # debug aOrig let prev = aOrig.previouslyInferred if prev != nil: return typeRel(c, f, prev) @@ -886,6 +882,7 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = case a.kind of tyArray: var fRange = f.sons[0] + var aRange = a.sons[0] if fRange.kind == tyGenericParam: var prev = PType(idTableGet(c.bindings, fRange)) if prev == nil: @@ -893,21 +890,28 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = fRange = a else: fRange = prev - result = typeRel(c, f.sons[1], a.sons[1]) + result = typeRel(c, f.sons[1].skipTypes({tyTypeDesc}), + a.sons[1].skipTypes({tyTypeDesc})) if result < isGeneric: return isNone - if rangeHasStaticIf(fRange): + + proc inferStaticRange(c: var TCandidate, inferred, concrete: PType) = + var (staticT, offset) = inferred.findUnresolvedStaticInRange + var + replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType]) + concreteUpperBound = concrete.n[1].intVal + # we must correct for the off-by-one discrepancy between + # ranges and static params: + replacementT.n = newIntNode(nkIntLit, concreteUpperBound + offset) + if tfInferrableStatic in staticT.flags: + staticT.n = replacementT.n + put(c.bindings, staticT, replacementT) + + if rangeHasUnresolvedStatic(fRange): if tfUnresolved in fRange.flags: # This is a range from an array instantiated with a generic # static param. We must extract the static param here and bind # it to the size of the currently supplied array. - var - rangeStaticT = fRange.getStaticTypeFromRange - replacementT = newTypeWithSons(c.c, tyStatic, @[tyInt.getSysType]) - inputUpperBound = a.sons[0].n[1].intVal - # we must correct for the off-by-one discrepancy between - # ranges and static params: - replacementT.n = newIntNode(nkIntLit, inputUpperBound + 1) - put(c, rangeStaticT, replacementT) + inferStaticRange(c, fRange, aRange) return isGeneric let len = tryResolvingStaticExpr(c, fRange.n[1]) @@ -915,6 +919,9 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = return # if we get this far, the result is already good else: return isNone + elif c.c.inTypeClass > 0 and aRange.rangeHasUnresolvedStatic: + inferStaticRange(c, aRange, fRange) + return isGeneric elif lengthOrd(fRange) != lengthOrd(a): result = isNone else: discard @@ -1307,9 +1314,6 @@ proc typeRel(c: var TCandidate, f, aOrig: PType, doBind = true): TTypeRelation = result = isNone of tyInferred: - # echo "INFER F" - # debug f - # debug a let prev = f.previouslyInferred if prev != nil: result = typeRel(c, prev, a) diff --git a/tests/concepts/tmatrixconcept.nim b/tests/concepts/tmatrixconcept.nim new file mode 100644 index 0000000000..4bc002dd45 --- /dev/null +++ b/tests/concepts/tmatrixconcept.nim @@ -0,0 +1,67 @@ +discard """ +output: "0\n0" +msg: ''' +R=3 C=3 TE=9 FF=14 FC=20 T=int +''' +""" + +import typetraits + +template ok(x) = assert x +template no(x) = assert(not x) + +const C = 10 + +type + Matrix[Rows, Cols, TotalElements, FromFoo, FromConst: static[int]; T] = concept m, var mvar, type M + M.M == Rows + Cols == M.N + M.T is T + + m[int, int] is T + mvar[int, int] = T + + FromConst == C * 2 + + # more complicated static param inference cases + m.data is array[TotalElements, T] + M.foo(array[0..FromFoo, type m[int, 10]]) + + MyMatrix[M, K: static[int]; T] = object + data: array[M*K, T] + +# adaptor for the concept's non-matching expectations +template N(M: type MyMatrix): expr = M.K + +proc `[]`(m: MyMatrix; r, c: int): m.T = + m.data[r * m.K + c] + +proc `[]=`(m: var MyMatrix; r, c: int, v: m.T) = + m.data[r * m.K + c] = v + +proc foo(x: MyMatrix, arr: array[15, x.T]) = discard + +proc matrixProc[R, C, TE, FF, FC, T](m: Matrix[R, C, TE, FF, FC, T]): T = + static: + echo "R=", R, " C=", C, " TE=", TE, " FF=", FF, " FC=", FC, " T=", T.name + + m[0, 0] + +proc myMatrixProc(x: MyMatrix): MyMatrix.T = matrixProc(x) + +var x: MyMatrix[3, 3, int] + +static: + # ok x is Matrix + ok x is Matrix[3, 3, 9, 14, 20, int] + + no x is Matrix[3, 3, 8, 15, 20, int] + no x is Matrix[3, 3, 9, 10, 20, int] + no x is Matrix[3, 3, 9, 15, 21, int] + no x is Matrix[3, 3, 9, 15, 20, float] + no x is Matrix[4, 3, 9, 15, 20, int] + no x is Matrix[3, 4, 9, 15, 20, int] + +echo x.myMatrixProc +echo x.matrixProc + diff --git a/tests/concepts/tvectorspace.nim b/tests/concepts/tvectorspace.nim new file mode 100644 index 0000000000..74423e0d25 --- /dev/null +++ b/tests/concepts/tvectorspace.nim @@ -0,0 +1,15 @@ +type VectorSpace[K] = concept x, y + x + y is type(x) + zero(type(x)) is type(x) + -x is type(x) + x - y is type(x) + var k: K + k * x is type(x) + +proc zero(T: typedesc): T = 0 + +static: + assert float is VectorSpace[float] + # assert float is VectorSpace[int] + # assert int is VectorSpace +