implements higher-order inline iterators and return type inference for iterators

This commit is contained in:
Zahary Karadjov
2014-03-08 22:57:06 +02:00
parent 2cbe46daff
commit 085b339b8b
9 changed files with 119 additions and 27 deletions

View File

@@ -96,8 +96,8 @@ type
errOnlyACallOpCanBeDelegator, errUsingNoSymbol,
errMacroBodyDependsOnGenericTypes,
errDestructorNotGenericEnough,
errXExpectsTwoArguments,
errInlineIteratorsAsProcParams,
errXExpectsTwoArguments,
errXExpectsObjectTypes, errXcanNeverBeOfThisSubtype, errTooManyIterations,
errCannotInterpretNodeX, errFieldXNotFound, errInvalidConversionFromTypeX,
errAssertionFailed, errCannotGenerateCodeForX, errXRequiresOneArgument,
@@ -331,6 +331,8 @@ const
"because the parameter '$1' has a generic type",
errDestructorNotGenericEnough: "Destructor signarue is too specific. " &
"A destructor must be associated will all instantiations of a generic type",
errInlineIteratorsAsProcParams: "inline iterators can be used as parameters only for " &
"templates, macros and other inline iterators",
errXExpectsTwoArguments: "\'$1\' expects two arguments",
errXExpectsObjectTypes: "\'$1\' expects object types",
errXcanNeverBeOfThisSubtype: "\'$1\' can never be of this subtype",

View File

@@ -1119,6 +1119,9 @@ proc asgnToResultVar(c: PContext, n, le, ri: PNode) {.inline.} =
n.sons[0] = x # 'result[]' --> 'result'
n.sons[1] = takeImplicitAddr(c, ri)
template resultTypeIsInferrable(typ: PType): expr =
typ.isMetaType and typ.kind != tyTypeDesc
proc semAsgn(c: PContext, n: PNode): PNode =
checkSonsLen(n, 2)
var a = n.sons[0]
@@ -1170,7 +1173,7 @@ proc semAsgn(c: PContext, n: PNode): PNode =
if lhsIsResult: {efAllowDestructor} else: {})
if lhsIsResult:
n.typ = enforceVoidContext
if lhs.sym.typ.isMetaType and lhs.sym.typ.kind != tyTypeDesc:
if resultTypeIsInferrable(lhs.sym.typ):
if cmpTypes(c, lhs.typ, rhs.typ) == isGeneric:
internalAssert c.p.resultSym != nil
lhs.typ = rhs.typ
@@ -1259,12 +1262,21 @@ proc semYield(c: PContext, n: PNode): PNode =
localError(n.info, errYieldNotAllowedInTryStmt)
elif n.sons[0].kind != nkEmpty:
n.sons[0] = semExprWithType(c, n.sons[0]) # check for type compatibility:
var restype = c.p.owner.typ.sons[0]
var iterType = c.p.owner.typ
var restype = iterType.sons[0]
if restype != nil:
let adjustedRes = if c.p.owner.kind == skIterator: restype.base
else: restype
n.sons[0] = fitNode(c, adjustedRes, n.sons[0])
if n.sons[0].typ == nil: internalError(n.info, "semYield")
if resultTypeIsInferrable(adjustedRes):
let inferred = n.sons[0].typ
if c.p.owner.kind == skIterator:
iterType.sons[0].sons[0] = inferred
else:
iterType.sons[0] = inferred
semYieldVarResult(c, n, adjustedRes)
else:
localError(n.info, errCannotReturnExpr)

View File

@@ -20,7 +20,7 @@ proc instantiateGenericParamList(c: PContext, n: PNode, pt: TIdTable,
if a.kind != nkSym:
internalError(a.info, "instantiateGenericParamList; no symbol")
var q = a.sym
if q.typ.kind notin {tyTypeDesc, tyGenericParam, tyStatic}+tyTypeClasses:
if q.typ.kind notin {tyTypeDesc, tyGenericParam, tyStatic, tyIter}+tyTypeClasses:
continue
var s = newSym(skType, q.name, getCurrOwner(), q.info)
s.flags = s.flags + {sfUsed, sfFromGeneric}

View File

@@ -661,10 +661,18 @@ proc semFor(c: PContext, n: PNode): PNode =
openScope(c)
n.sons[length-2] = semExprNoDeref(c, n.sons[length-2], {efWantIterator})
var call = n.sons[length-2]
if call.kind in nkCallKinds and call.sons[0].typ.callConv == ccClosure:
let isCallExpr = call.kind in nkCallKinds
if isCallExpr and call.sons[0].sym.magic != mNone:
if call.sons[0].sym.magic == mOmpParFor:
result = semForVars(c, n)
result.kind = nkParForStmt
else:
result = semForFields(c, n, call.sons[0].sym.magic)
elif (isCallExpr and call.sons[0].typ.callConv == ccClosure) or
call.typ.kind == tyIter:
# first class iterator:
result = semForVars(c, n)
elif call.kind notin nkCallKinds or call.sons[0].kind != nkSym or
elif not isCallExpr or call.sons[0].kind != nkSym or
call.sons[0].sym.kind notin skIterators:
if length == 3:
n.sons[length-2] = implicitIterator(c, "items", n.sons[length-2])
@@ -673,12 +681,6 @@ proc semFor(c: PContext, n: PNode): PNode =
else:
localError(n.sons[length-2].info, errIteratorExpected)
result = semForVars(c, n)
elif call.sons[0].sym.magic != mNone:
if call.sons[0].sym.magic == mOmpParFor:
result = semForVars(c, n)
result.kind = nkParForStmt
else:
result = semForFields(c, n, call.sons[0].sym.magic)
else:
result = semForVars(c, n)
# propagate any enforced VoidContext:

View File

@@ -721,7 +721,16 @@ proc liftParamType(c: PContext, procKind: TSymKind, genericParams: PNode,
allowMetaTypes = true)
result = newTypeWithSons(c, tyCompositeTypeClass, @[paramType, result])
result = addImplicitGeneric(result)
of tyIter:
if paramType.callConv == ccInline:
if procKind notin {skTemplate, skMacro, skIterator}:
localError(info, errInlineIteratorsAsProcParams)
if paramType.len == 1:
let lifted = liftingWalk(paramType.base)
if lifted != nil: paramType.sons[0] = lifted
result = addImplicitGeneric(paramType)
of tyGenericInst:
if paramType.lastSon.kind == tyUserTypeClass:
var cp = copyType(paramType, getCurrOwner(), false)
@@ -852,7 +861,11 @@ proc semProcTypeNode(c: PContext, n, genericParams: PNode,
if lifted != nil: r = lifted
r.flags.incl tfRetType
r = skipIntLit(r)
if kind == skIterator: r = newTypeWithSons(c, tyIter, @[r])
if kind == skIterator:
# see tchainediterators
# in cases like iterator foo(it: iterator): type(it)
# we don't need to change the return type to iter[T]
if not r.isInlineIterator: r = newTypeWithSons(c, tyIter, @[r])
result.sons[0] = r
res.typ = r
@@ -984,7 +997,8 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
of nkTypeOfExpr:
# for ``type(countup(1,3))``, see ``tests/ttoseq``.
checkSonsLen(n, 1)
result = semExprWithType(c, n.sons[0], {efInTypeof}).typ.skipTypes({tyIter})
let typExpr = semExprWithType(c, n.sons[0], {efInTypeof})
result = typExpr.typ.skipTypes({tyIter})
of nkPar:
if sonsLen(n) == 1: result = semTypeNode(c, n.sons[0], prev)
else:
@@ -1103,8 +1117,12 @@ proc semTypeNode(c: PContext, n: PNode, prev: PType): PType =
result = newConstraint(c, tyIter)
else:
result = semProcTypeWithScope(c, n, prev, skClosureIterator)
result.flags.incl(tfIterator)
result.callConv = ccClosure
if n.lastSon.kind == nkPragma and hasPragma(n.lastSon, wInline):
result.kind = tyIter
result.callConv = ccInline
else:
result.flags.incl(tfIterator)
result.callConv = ccClosure
of nkProcTy:
if n.sonsLen == 0:
result = newConstraint(c, tyProc)

View File

@@ -305,6 +305,11 @@ proc skipIntLiteralParams(t: PType) =
if skipped != p:
t.sons[i] = skipped
if i > 0: t.n.sons[i].sym.typ = skipped
# when the typeof operator is used on a static input
# param, the results gets infected with static as well:
if t.sons[0] != nil and t.sons[0].kind == tyStatic:
t.sons[0] = t.sons[0].base
proc propagateFieldFlags(t: PType, n: PNode) =
# This is meant for objects and tuples
@@ -323,7 +328,7 @@ proc replaceTypeVarsTAux(cl: var TReplTypeVars, t: PType): PType =
result = t
if t == nil: return
if t.kind in {tyStatic, tyGenericParam} + tyTypeClasses:
if t.kind in {tyStatic, tyGenericParam, tyIter} + tyTypeClasses:
let lookup = PType(idTableGet(cl.typeMap, t))
if lookup != nil: return lookup

View File

@@ -1014,6 +1014,10 @@ proc localConvMatch(c: PContext, m: var TCandidate, f, a: PType,
result.typ = getInstantiatedType(c, arg, m, base(f))
m.baseTypeMatch = true
proc isInlineIterator*(t: PType): bool =
result = t.kind == tyIter or
(t.kind == tyBuiltInTypeClass and t.base.kind == tyIter)
proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
argSemantized, argOrig: PNode): PNode =
var
@@ -1021,7 +1025,7 @@ proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
arg = argSemantized
argType = argType
c = m.c
if tfHasStatic in fMaybeStatic.flags:
# XXX: When implicit statics are the default
# this will be done earlier - we just have to
@@ -1060,7 +1064,14 @@ proc paramTypesMatchAux(m: var TCandidate, f, argType: PType,
return arg.typ.n
else:
return argOrig
if r != isNone and f.isInlineIterator:
var inlined = newTypeS(tyStatic, c)
inlined.sons = @[argType]
inlined.n = argSemantized
put(m.bindings, f, inlined)
return argSemantized
case r
of isConvertible:
inc(m.convMatches)
@@ -1188,7 +1199,9 @@ proc prepareOperand(c: PContext; formal: PType; a: PNode): PNode =
# a.typ == nil is valid
result = a
elif a.typ.isNil:
result = c.semOperand(c, a, {efDetermineType})
let flags = if formal.kind == tyIter: {efDetermineType, efWantIterator}
else: {efDetermineType}
result = c.semOperand(c, a, flags)
else:
result = a

View File

@@ -425,7 +425,7 @@ proc findWrongOwners(c: PTransf, n: PNode) =
x.sym.owner.name.s & " " & getCurrOwner(c).name.s)
else:
for i in 0 .. <safeLen(n): findWrongOwners(c, n.sons[i])
proc transformFor(c: PTransf, n: PNode): PTransNode =
# generate access statements for the parameters (unless they are constant)
# put mapping from formal parameters to actual parameters
@@ -433,12 +433,13 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
var length = sonsLen(n)
var call = n.sons[length - 2]
if call.kind notin nkCallKinds or call.sons[0].kind != nkSym or
call.sons[0].sym.kind != skIterator:
if call.typ.kind != tyIter and
(call.kind notin nkCallKinds or call.sons[0].kind != nkSym or
call.sons[0].sym.kind != skIterator):
n.sons[length-1] = transformLoopBody(c, n.sons[length-1]).PNode
return lambdalifting.liftForLoop(n).PTransNode
#InternalError(call.info, "transformFor")
#echo "transforming: ", renderTree(n)
result = newTransNode(nkStmtList, n.info, 0)
var loopBody = transformLoopBody(c, n.sons[length-1])
@@ -459,6 +460,7 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
for i in countup(1, sonsLen(call) - 1):
var arg = transform(c, call.sons[i]).PNode
var formal = skipTypes(iter.typ, abstractInst).n.sons[i].sym
if arg.typ.kind == tyIter: continue
case putArgInto(arg, formal.typ)
of paDirectMapping:
idNodeTablePut(newC.mapping, formal, arg)
@@ -480,7 +482,7 @@ proc transformFor(c: PTransf, n: PNode): PTransNode =
dec(c.inlining)
popInfoContext()
popTransCon(c)
#echo "transformed: ", renderTree(n)
# echo "transformed: ", result.PNode.renderTree
proc getMagicOp(call: PNode): TMagic =
if call.sons[0].kind == nkSym and

View File

@@ -0,0 +1,38 @@
discard """
output: '''16
32
48
64
128
192
'''
"""
iterator gaz(it: iterator{.inline.}): type(it) =
for x in it:
yield x*2
iterator baz(it: iterator{.inline.}): auto =
for x in gaz(it):
yield x*2
type T1 = auto
iterator bar(it: iterator: T1{.inline.}): T1 =
for x in baz(it):
yield x*2
iterator foo[T](x: iterator: T{.inline.}): T =
for e in bar(x):
yield e*2
var s = @[1, 2, 3]
# pass an interator several levels deep:
for x in s.items.foo:
echo x
# use some complex iterator as an input for another one:
for x in s.items.baz.foo:
echo x