correct code generation for tforstmt

This commit is contained in:
Araq
2014-05-30 13:15:54 +02:00
parent 6470bd8f87
commit ea16aca09e
2 changed files with 37 additions and 22 deletions

View File

@@ -160,8 +160,8 @@ We generate roughly this:
proc f_wrapper(args) =
barrierEnter(args.barrier) # for parallel statement
var a = args.a # copy strings/seqs; thread transfer; not generated for
# the 'parallel' statement
var a = args.a # thread transfer; deepCopy or shallowCopy or no copy
# depending on whether we're in a 'parallel' statement
var b = args.b
args.prom = nimCreatePromise(thread, sizeof(T)) # optional
@@ -199,9 +199,9 @@ proc createNimCreatePromiseCall(prom, threadParam: PNode): PNode =
proc createWrapperProc(f: PNode; threadParam, argsParam: PSym;
varSection, call, barrier, prom: PNode): PSym =
var body = newNodeI(nkStmtList, f.info)
body.add varSection
if barrier != nil:
body.add callCodeGenProc("barrierEnter", barrier)
body.add varSection
if prom != nil:
body.add createNimCreatePromiseCall(prom, threadParam.newSymNode)
if barrier == nil:
@@ -248,6 +248,17 @@ proc createCastExpr(argsParam: PSym; objType: PType): PNode =
result.typ = newType(tyPtr, objType.owner)
result.typ.rawAddSon(objType)
proc addLocalVar(varSection: PNode; owner: PSym; typ: PType; v: PNode): PSym =
## Declares a fresh temporary in `varSection` and returns its symbol.
##
## A new `skTemp` symbol named after `genPrefix` is created with `owner` as
## its owner, `typ` as its type and the source position of `varSection`.
## An `nkIdentDefs` entry for it, with `v` as the initializer, is then
## appended to `varSection`.
result = newSym(skTemp, getIdent(genPrefix), owner, varSection.info)
result.typ = typ
# NOTE(review): presumably sfFromGeneric marks the temp as compiler
# generated so later passes treat it specially -- confirm the exact reason.
incl(result.flags, sfFromGeneric)
# nkIdentDefs layout: [0] = name, [1] = type (left empty here), [2] = value.
var vpart = newNodeI(nkIdentDefs, varSection.info, 3)
vpart.sons[0] = newSymNode(result)
vpart.sons[1] = ast.emptyNode
vpart.sons[2] = v
varSection.add vpart
proc setupArgsForConcurrency(n: PNode; objType: PType; scratchObj: PSym,
castExpr, call, varSection, result: PNode) =
let formals = n[0].typ.n
@@ -267,16 +278,8 @@ proc setupArgsForConcurrency(n: PNode; objType: PType; scratchObj: PSym,
objType.addField(field)
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), n[i])
var temp = newSym(skTemp, tmpName, objType.owner, n.info)
temp.typ = argType
incl(temp.flags, sfFromGeneric)
var vpart = newNodeI(nkIdentDefs, n.info, 3)
vpart.sons[0] = newSymNode(temp)
vpart.sons[1] = ast.emptyNode
vpart.sons[2] = indirectAccess(castExpr, field, n.info)
varSection.add vpart
let temp = addLocalVar(varSection, objType.owner, argType,
indirectAccess(castExpr, field, n.info))
call.add(newSymNode(temp))
proc getRoot*(n: PNode): PSym =
@@ -310,9 +313,11 @@ proc genHigh(n: PNode): PNode =
result.sons[1] = n
proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
castExpr, call, result: PNode) =
castExpr, call, varSection, result: PNode) =
let formals = n[0].typ.n
let tmpName = getIdent(genPrefix)
# we need to copy the foreign scratch object fields into local variables
# for correctness: These are called 'threadLocal' here.
for i in 1 .. <n.len:
let n = n[i]
let argType = skipTypes(if i < formals.len: formals[i].typ else: n.typ,
@@ -344,7 +349,9 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldA), n[2])
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldB), n[3])
slice.sons[2] = indirectAccess(castExpr, fieldA, n.info)
let threadLocal = addLocalVar(varSection, objType.owner, fieldA.typ,
indirectAccess(castExpr, fieldA, n.info))
slice.sons[2] = threadLocal.newSymNode
else:
let a = genAddrOf(n)
field.typ = a.typ
@@ -353,9 +360,12 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldB), genHigh(n))
slice.sons[2] = newIntLit(0)
# the array itself does not need to go through a thread local variable:
slice.sons[1] = genDeref(indirectAccess(castExpr, field, n.info))
slice.sons[3] = indirectAccess(castExpr, fieldB, n.info)
let threadLocal = addLocalVar(varSection, objType.owner, fieldB.typ,
indirectAccess(castExpr, fieldB, n.info))
slice.sons[3] = threadLocal.newSymNode
call.add slice
elif (let size = computeSize(argType); size < 0 or size > 16) and
n.getRoot != nil:
@@ -364,13 +374,17 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
field.typ = a.typ
objType.addField(field)
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), a)
call.add(genDeref(indirectAccess(castExpr, field, n.info)))
let threadLocal = addLocalVar(varSection, objType.owner, field.typ,
indirectAccess(castExpr, field, n.info))
call.add(genDeref(threadLocal.newSymNode))
else:
# boring case
field.typ = argType
objType.addField(field)
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), n)
call.add(indirectAccess(castExpr, field, n.info))
let threadLocal = addLocalVar(varSection, objType.owner, field.typ,
indirectAccess(castExpr, field, n.info))
call.add(threadLocal.newSymNode)
proc wrapProcForSpawn*(owner: PSym; n: PNode; retType: PType;
barrier, dest: PNode = nil): PNode =
@@ -438,7 +452,7 @@ proc wrapProcForSpawn*(owner: PSym; n: PNode; retType: PType;
if barrier.isNil:
setupArgsForConcurrency(n, objType, scratchObj, castExpr, call, varSection, result)
else:
setupArgsForParallelism(n, objType, scratchObj, castExpr, call, result)
setupArgsForParallelism(n, objType, scratchObj, castExpr, call, varSection, result)
var barrierAsExpr: PNode = nil
if barrier != nil:

View File

@@ -7,14 +7,15 @@ discard """
sortoutput: true
"""
import threadpool, math
import threadpool, os
proc p(x: int) =
## Sleeps for `100 - x*10` milliseconds, then echoes `x`.
# NOTE(review): presumably the x-dependent delay staggers the spawned
# tasks so that larger values finish first -- confirm the intent.
os.sleep(100 - x*10)
echo x
proc testFor(a, b: int; foo: var openArray[int]) =
parallel:
for i in max(a, 0) .. min(b, foo.len-1):
for i in max(a, 0) .. min(b, foo.high):
spawn p(foo[i])
var arr = [0, 1, 2, 3, 4, 5, 6, 7]