closure iterators almost work

This commit is contained in:
Araq
2014-01-23 01:41:26 +01:00
parent 37229df7fc
commit 3f87326247
2 changed files with 190 additions and 192 deletions

View File

@@ -130,25 +130,16 @@ type
TOuterContext {.final.} = object
fn: PSym # may also be a module!
currentEnv: PEnv
isIter: bool # first class iterator?
capturedVars, processed: TIntSet
localsToEnv: TIdTable # PSym->PEnv mapping
localsToAccess: TIdNodeTable
lambdasToEnv: TIdTable # PSym->PEnv mapping
up: POuterContext
proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext =
new(result)
result.fn = fn
result.capturedVars = initIntSet()
result.processed = initIntSet()
initIdNodeTable(result.localsToAccess)
initIdTable(result.localsToEnv)
initIdTable(result.lambdasToEnv)
proc newInnerContext(fn: PSym): PInnerContext =
new(result)
result.fn = fn
initIdNodeTable(result.localsToAccess)
closureParam, state, resultSym: PSym # only if isIter
tup: PType # only if isIter
proc getStateType(iter: PSym): PType =
var n = newNodeI(nkRange, iter.info)
@@ -162,6 +153,92 @@ proc createStateField(iter: PSym): PSym =
result = newSym(skField, getIdent(":state"), iter, iter.info)
result.typ = getStateType(iter)
proc newIterResult(iter: PSym): PSym =
if resultPos < iter.ast.len:
result = iter.ast.sons[resultPos].sym
else:
# XXX a bit hacky:
result = newSym(skResult, getIdent":result", iter, iter.info)
result.typ = iter.typ.sons[0]
incl(result.flags, sfUsed)
iter.ast.add newSymNode(result)
proc addHiddenParam(routine: PSym, param: PSym) =
var params = routine.ast.sons[paramsPos]
# -1 is correct here as param.position is 0 based but we have at position 0
# some nkEffect node:
param.position = params.len-1
addSon(params, newSymNode(param))
incl(routine.typ.flags, tfCapturesEnv)
#echo "produced environment: ", param.id, " for ", routine.name.s
proc getHiddenParam(routine: PSym): PSym =
let params = routine.ast.sons[paramsPos]
let hidden = lastSon(params)
assert hidden.kind == nkSym
result = hidden.sym
proc getEnvParam(routine: PSym): PSym =
let params = routine.ast.sons[paramsPos]
let hidden = lastSon(params)
if hidden.kind == nkSym and hidden.sym.name.s == paramName:
result = hidden.sym
proc addField(tup: PType, s: PSym) =
var field = newSym(skField, s.name, s.owner, s.info)
let t = skipIntLit(s.typ)
field.typ = t
field.position = sonsLen(tup)
addSon(tup.n, newSymNode(field))
rawAddSon(tup, t)
proc initIterContext(c: POuterContext, iter: PSym) =
c.fn = iter
c.capturedVars = initIntSet()
var cp = getEnvParam(iter)
if cp == nil:
c.tup = newType(tyTuple, iter)
c.tup.n = newNodeI(nkRecList, iter.info)
cp = newSym(skParam, getIdent(paramName), iter, iter.info)
incl(cp.flags, sfFromGeneric)
cp.typ = newType(tyRef, iter)
rawAddSon(cp.typ, c.tup)
addHiddenParam(iter, cp)
c.state = createStateField(iter)
addField(c.tup, c.state)
else:
c.tup = cp.typ.sons[0]
assert c.tup.kind == tyTuple
if c.tup.len > 0:
c.state = c.tup.n[0].sym
else:
c.state = createStateField(iter)
addField(c.tup, c.state)
c.closureParam = cp
if iter.typ.sons[0] != nil:
c.resultSym = newIterResult(iter)
#iter.ast.add(newSymNode(c.resultSym))
proc newOuterContext(fn: PSym, up: POuterContext = nil): POuterContext =
new(result)
result.fn = fn
result.capturedVars = initIntSet()
result.processed = initIntSet()
initIdNodeTable(result.localsToAccess)
initIdTable(result.localsToEnv)
initIdTable(result.lambdasToEnv)
result.isIter = fn.kind == skIterator and fn.typ.callConv == ccClosure
if result.isIter: initIterContext(result, fn)
proc newInnerContext(fn: PSym): PInnerContext =
new(result)
result.fn = fn
initIdNodeTable(result.localsToAccess)
proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
new(result)
result.deps = @[]
@@ -171,14 +248,6 @@ proc newEnv(outerProc: PSym, up: PEnv, n: PNode): PEnv =
result.up = up
result.attachedNode = n
proc addField(tup: PType, s: PSym) =
var field = newSym(skField, s.name, s.owner, s.info)
let t = skipIntLit(s.typ)
field.typ = t
field.position = sonsLen(tup)
addSon(tup.n, newSymNode(field))
rawAddSon(tup, t)
proc addCapturedVar(e: PEnv, v: PSym) =
for x in e.capturedVars:
if x == v: return
@@ -221,27 +290,6 @@ proc newCall(a, b: PSym): PNode =
result.add newSymNode(a)
result.add newSymNode(b)
proc addHiddenParam(routine: PSym, param: PSym) =
var params = routine.ast.sons[paramsPos]
# -1 is correct here as param.position is 0 based but we have at position 0
# some nkEffect node:
param.position = params.len-1
addSon(params, newSymNode(param))
incl(routine.typ.flags, tfCapturesEnv)
#echo "produced environment: ", param.id, " for ", routine.name.s
proc getHiddenParam(routine: PSym): PSym =
let params = routine.ast.sons[paramsPos]
let hidden = lastSon(params)
assert hidden.kind == nkSym
result = hidden.sym
proc getEnvParam(routine: PSym): PSym =
let params = routine.ast.sons[paramsPos]
let hidden = lastSon(params)
if hidden.kind == nkSym and hidden.sym.name.s == paramName:
result = hidden.sym
proc isInnerProc(s, outerProc: PSym): bool {.inline.} =
result = (s.kind in {skProc, skMethod, skConverter} or
s.kind == skIterator and s.typ.callConv == ccClosure) and
@@ -334,7 +382,9 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) =
var s = n.sym
if interestingVar(s) and i.fn.id != s.owner.id:
captureVar(o, i, s, n.info)
elif isInnerProc(s, o.fn) and tfCapturesEnv in s.typ.flags and s != i.fn:
elif s.kind in {skProc, skMethod, skConverter} and
s.skipGenericOwner == o.fn and
tfCapturesEnv in s.typ.flags and s != i.fn:
# call to some other inner proc; we need to track the dependencies for
# this:
let env = PEnv(idTableGet(o.lambdasToEnv, i.fn))
@@ -342,7 +392,7 @@ proc gatherVars(o: POuterContext, i: PInnerContext, n: PNode) =
if o.currentEnv != env:
discard addDep(o.currentEnv, env, i.fn)
internalError(n.info, "too complex environment handling required")
of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit, nkClosure: discard
else:
for k in countup(0, sonsLen(n) - 1):
gatherVars(o, i, n.sons[k])
@@ -398,7 +448,8 @@ proc transformInnerProc(o: POuterContext, i: PInnerContext, n: PNode): PNode =
of nkLambdaKinds, nkIteratorDef:
if n.typ != nil:
result = transformInnerProc(o, i, n.sons[namePos])
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
nkClosure:
# don't recurse here:
discard
else:
@@ -467,7 +518,8 @@ proc searchForInnerProcs(o: POuterContext, n: PNode) =
searchForInnerProcs(o, it.sons[L-1])
else:
internalError(it.info, "transformOuter")
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
nkClosure:
# don't recurse here:
# XXX recurse here and setup 'up' pointers
discard
@@ -526,12 +578,61 @@ proc generateClosureCreation(o: POuterContext, scope: PEnv): PNode =
result.add(newAsgnStmt(indirectAccess(env, field, env.info),
newSymNode(getClosureVar(o, e)), env.info))
proc interestingIterVar(s: PSym): bool {.inline.} =
result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
proc transformOuterProc(o: POuterContext, n: PNode): PNode
proc transformYield(c: POuterContext, n: PNode): PNode =
inc c.state.typ.n.sons[1].intVal
let stateNo = c.state.typ.n.sons[1].intVal
var stateAsgnStmt = newNodeI(nkAsgn, n.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
var retStmt = newNodeI(nkReturnStmt, n.info)
if n.sons[0].kind != nkEmpty:
var a = newNodeI(nkAsgn, n.sons[0].info)
var retVal = transformOuterProc(c, n.sons[0])
addSon(a, newSymNode(c.resultSym))
addSon(a, if retVal.isNil: n.sons[0] else: retVal)
retStmt.add(a)
else:
retStmt.add(emptyNode)
var stateLabelStmt = newNodeI(nkState, n.info)
stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
result = newNodeI(nkStmtList, n.info)
result.add(stateAsgnStmt)
result.add(retStmt)
result.add(stateLabelStmt)
proc transformReturn(c: POuterContext, n: PNode): PNode =
result = newNodeI(nkStmtList, n.info)
var stateAsgnStmt = newNodeI(nkAsgn, n.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
result.add(stateAsgnStmt)
result.add(n)
proc outerProcSons(o: POuterContext, n: PNode) =
for i in countup(0, sonsLen(n) - 1):
let x = transformOuterProc(o, n.sons[i])
if x != nil: n.sons[i] = x
proc transformOuterProc(o: POuterContext, n: PNode): PNode =
if n == nil: return nil
case n.kind
of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
of nkSym:
var local = n.sym
if o.isIter and interestingIterVar(local) and o.fn.id == local.owner.id:
if not containsOrIncl(o.capturedVars, local.id): addField(o.tup, local)
return indirectAccess(newSymNode(o.closureParam), local, n.info)
var closure = PEnv(idTableGet(o.lambdasToEnv, local))
if closure != nil:
# we need to replace the lambda with '(lambda, env)':
@@ -567,17 +668,44 @@ proc transformOuterProc(o: POuterContext, n: PNode): PNode =
of nkLambdaKinds, nkIteratorDef:
if n.typ != nil:
result = transformOuterProc(o, n.sons[namePos])
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef:
of nkProcDef, nkMethodDef, nkConverterDef, nkMacroDef, nkTemplateDef,
nkClosure:
# don't recurse here:
discard
of nkHiddenStdConv, nkHiddenSubConv, nkConv:
let x = transformOuterProc(o, n.sons[1])
if x != nil: n.sons[1] = x
result = transformOuterConv(n)
of nkYieldStmt:
if o.isIter: result = transformYield(o, n)
else: outerProcSons(o, n)
of nkReturnStmt:
if o.isIter: result = transformReturn(o, n)
else: outerProcSons(o, n)
else:
for i in countup(0, sonsLen(n) - 1):
let x = transformOuterProc(o, n.sons[i])
if x != nil: n.sons[i] = x
outerProcSons(o, n)
proc liftIterator(c: POuterContext, body: PNode): PNode =
let iter = c.fn
result = newNodeI(nkStmtList, iter.info)
var gs = newNodeI(nkGotoState, iter.info)
gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info))
result.add(gs)
var state0 = newNodeI(nkState, iter.info)
state0.add(newIntNode(nkIntLit, 0))
result.add(state0)
let newBody = transformOuterProc(c, body)
if newBody != nil:
result.add(newBody)
else:
result.add(body)
var stateAsgnStmt = newNodeI(nkAsgn, iter.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),
c.state,iter.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
result.add(stateAsgnStmt)
proc liftLambdas*(fn: PSym, body: PNode): PNode =
# XXX gCmd == cmdCompileToJS does not suffice! The compiletime stuff needs
@@ -601,8 +729,11 @@ proc liftLambdas*(fn: PSym, body: PNode): PNode =
if resultPos < sonsLen(ast) and ast.sons[resultPos].kind == nkSym:
idTablePut(o.localsToEnv, ast.sons[resultPos].sym, o.currentEnv)
searchForInnerProcs(o, body)
discard transformOuterProc(o, body)
result = ex
if o.isIter:
result = liftIterator(o, ex)
else:
discard transformOuterProc(o, body)
result = ex
proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode =
if body.kind == nkEmpty or gCmd == cmdCompileToJS:
@@ -617,144 +748,11 @@ proc liftLambdasForTopLevel*(module: PSym, body: PNode): PNode =
# ------------------- iterator transformation --------------------------------
discard """
iterator chain[S, T](a, b: *S->T, args: *S): T =
for x in a(args): yield x
for x in b(args): yield x
let c = chain(f, g)
for x in c: echo x
# translated to:
let c = chain( (f, newClosure(f)), (g, newClosure(g)), newClosure(chain))
"""
type
TIterContext {.final, pure.} = object
iter, closureParam, state, resultSym: PSym
capturedVars: TIntSet
tup: PType
proc newIterResult(iter: PSym): PSym =
if resultPos < iter.ast.len:
result = iter.ast.sons[resultPos].sym
else:
# XXX a bit hacky:
result = newSym(skResult, getIdent":result", iter, iter.info)
result.typ = iter.typ.sons[0]
incl(result.flags, sfUsed)
iter.ast.add newSymNode(result)
proc interestingIterVar(s: PSym): bool {.inline.} =
result = s.kind in {skVar, skLet, skTemp, skForVar} and sfGlobal notin s.flags
proc transfIterBody(c: var TIterContext, n: PNode): PNode =
# gather used vars for closure generation
if n == nil: return nil
case n.kind
of nkSym:
var s = n.sym
if interestingIterVar(s) and c.iter.id == s.owner.id:
if not containsOrIncl(c.capturedVars, s.id): addField(c.tup, s)
result = indirectAccess(newSymNode(c.closureParam), s, n.info)
of nkEmpty..pred(nkSym), succ(nkSym)..nkNilLit: discard
of nkYieldStmt:
inc c.state.typ.n.sons[1].intVal
let stateNo = c.state.typ.n.sons[1].intVal
var stateAsgnStmt = newNodeI(nkAsgn, n.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
var retStmt = newNodeI(nkReturnStmt, n.info)
if n.sons[0].kind != nkEmpty:
var a = newNodeI(nkAsgn, n.sons[0].info)
var retVal = transfIterBody(c, n.sons[0])
addSon(a, newSymNode(c.resultSym))
addSon(a, if retVal.isNil: n.sons[0] else: retVal)
retStmt.add(a)
else:
retStmt.add(emptyNode)
var stateLabelStmt = newNodeI(nkState, n.info)
stateLabelStmt.add(newIntTypeNode(nkIntLit, stateNo, getSysType(tyInt)))
result = newNodeI(nkStmtList, n.info)
result.add(stateAsgnStmt)
result.add(retStmt)
result.add(stateLabelStmt)
of nkReturnStmt:
result = newNodeI(nkStmtList, n.info)
var stateAsgnStmt = newNodeI(nkAsgn, n.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),c.state,n.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
result.add(stateAsgnStmt)
result.add(n)
else:
for i in countup(0, sonsLen(n)-1):
let x = transfIterBody(c, n.sons[i])
if x != nil: n.sons[i] = x
proc initIterContext(c: var TIterContext, iter: PSym) =
c.iter = iter
c.capturedVars = initIntSet()
var cp = getEnvParam(iter)
if cp == nil:
c.tup = newType(tyTuple, iter)
c.tup.n = newNodeI(nkRecList, iter.info)
cp = newSym(skParam, getIdent(paramName), iter, iter.info)
incl(cp.flags, sfFromGeneric)
cp.typ = newType(tyRef, iter)
rawAddSon(cp.typ, c.tup)
addHiddenParam(iter, cp)
c.state = createStateField(iter)
addField(c.tup, c.state)
else:
c.tup = cp.typ.sons[0]
assert c.tup.kind == tyTuple
if c.tup.len > 0:
c.state = c.tup.n[0].sym
else:
c.state = createStateField(iter)
addField(c.tup, c.state)
c.closureParam = cp
if iter.typ.sons[0] != nil:
c.resultSym = newIterResult(iter)
#iter.ast.add(newSymNode(c.resultSym))
proc liftIterator*(iter: PSym, body: PNode): PNode =
var c: TIterContext
initIterContext c, iter
result = newNodeI(nkStmtList, iter.info)
var gs = newNodeI(nkGotoState, iter.info)
gs.add(indirectAccess(newSymNode(c.closureParam), c.state, iter.info))
result.add(gs)
var state0 = newNodeI(nkState, iter.info)
state0.add(newIntNode(nkIntLit, 0))
result.add(state0)
let newBody = transfIterBody(c, body)
if newBody != nil:
result.add(newBody)
else:
result.add(body)
var stateAsgnStmt = newNodeI(nkAsgn, iter.info)
stateAsgnStmt.add(indirectAccess(newSymNode(c.closureParam),
c.state,iter.info))
stateAsgnStmt.add(newIntTypeNode(nkIntLit, -1, getSysType(tyInt)))
result.add(stateAsgnStmt)
proc liftIterSym*(n: PNode): PNode =
# transforms (iter) to (let env = newClosure[iter](); (iter, env))
let iter = n.sym
assert iter.kind == skIterator
if sfClosureCreated in iter.flags: return n
#if sfClosureCreated in iter.flags: return n
result = newNodeIT(nkStmtListExpr, n.info, n.typ)

View File

@@ -113,8 +113,8 @@ proc newAsgnStmt(c: PTransf, le: PNode, ri: PTransNode): PTransNode =
result[1] = ri
proc transformSymAux(c: PTransf, n: PNode): PNode =
if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure:
return liftIterSym(n)
#if n.sym.kind == skIterator and n.sym.typ.callConv == ccClosure:
# return liftIterSym(n)
var b: PNode
var tc = c.transCon
if sfBorrow in n.sym.flags:
@@ -636,8 +636,8 @@ proc transform(c: PTransf, n: PNode): PTransNode =
s.ast.sons[bodyPos] = n.sons[bodyPos]
#n.sons[bodyPos] = liftLambdas(s, n)
#if n.kind == nkMethodDef: methodDef(s, false)
if n.kind == nkIteratorDef and n.typ != nil:
return liftIterSym(n.sons[namePos]).PTransNode
#if n.kind == nkIteratorDef and n.typ != nil:
# return liftIterSym(n.sons[namePos]).PTransNode
result = PTransNode(n)
of nkMacroDef:
# XXX no proper closure support yet:
@@ -741,8 +741,8 @@ proc transformBody*(module: PSym, n: PNode, prc: PSym): PNode =
var c = openTransf(module, "")
result = processTransf(c, n, prc)
result = liftLambdas(prc, result)
if prc.kind == skIterator and prc.typ.callConv == ccClosure:
result = lambdalifting.liftIterator(prc, result)
#if prc.kind == skIterator and prc.typ.callConv == ccClosure:
# result = lambdalifting.liftIterator(prc, result)
incl(result.flags, nfTransf)
when useEffectSystem: trackProc(prc, result)