mirror of
https://github.com/nim-lang/Nim.git
synced 2025-12-28 08:54:53 +00:00
correct code generation for tforstmt
This commit is contained in:
@@ -160,8 +160,8 @@ We generate roughly this:
|
||||
|
||||
proc f_wrapper(args) =
|
||||
barrierEnter(args.barrier) # for parallel statement
|
||||
var a = args.a # copy strings/seqs; thread transfer; not generated for
|
||||
# the 'parallel' statement
|
||||
var a = args.a # thread transfer; deepCopy or shallowCopy or no copy
|
||||
# depending on whether we're in a 'parallel' statement
|
||||
var b = args.b
|
||||
|
||||
args.prom = nimCreatePromise(thread, sizeof(T)) # optional
|
||||
@@ -199,9 +199,9 @@ proc createNimCreatePromiseCall(prom, threadParam: PNode): PNode =
|
||||
proc createWrapperProc(f: PNode; threadParam, argsParam: PSym;
|
||||
varSection, call, barrier, prom: PNode): PSym =
|
||||
var body = newNodeI(nkStmtList, f.info)
|
||||
body.add varSection
|
||||
if barrier != nil:
|
||||
body.add callCodeGenProc("barrierEnter", barrier)
|
||||
body.add varSection
|
||||
if prom != nil:
|
||||
body.add createNimCreatePromiseCall(prom, threadParam.newSymNode)
|
||||
if barrier == nil:
|
||||
@@ -248,6 +248,17 @@ proc createCastExpr(argsParam: PSym; objType: PType): PNode =
|
||||
result.typ = newType(tyPtr, objType.owner)
|
||||
result.typ.rawAddSon(objType)
|
||||
|
||||
proc addLocalVar(varSection: PNode; owner: PSym; typ: PType; v: PNode): PSym =
|
||||
result = newSym(skTemp, getIdent(genPrefix), owner, varSection.info)
|
||||
result.typ = typ
|
||||
incl(result.flags, sfFromGeneric)
|
||||
|
||||
var vpart = newNodeI(nkIdentDefs, varSection.info, 3)
|
||||
vpart.sons[0] = newSymNode(result)
|
||||
vpart.sons[1] = ast.emptyNode
|
||||
vpart.sons[2] = v
|
||||
varSection.add vpart
|
||||
|
||||
proc setupArgsForConcurrency(n: PNode; objType: PType; scratchObj: PSym,
|
||||
castExpr, call, varSection, result: PNode) =
|
||||
let formals = n[0].typ.n
|
||||
@@ -267,16 +278,8 @@ proc setupArgsForConcurrency(n: PNode; objType: PType; scratchObj: PSym,
|
||||
objType.addField(field)
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), n[i])
|
||||
|
||||
var temp = newSym(skTemp, tmpName, objType.owner, n.info)
|
||||
temp.typ = argType
|
||||
incl(temp.flags, sfFromGeneric)
|
||||
|
||||
var vpart = newNodeI(nkIdentDefs, n.info, 3)
|
||||
vpart.sons[0] = newSymNode(temp)
|
||||
vpart.sons[1] = ast.emptyNode
|
||||
vpart.sons[2] = indirectAccess(castExpr, field, n.info)
|
||||
varSection.add vpart
|
||||
|
||||
let temp = addLocalVar(varSection, objType.owner, argType,
|
||||
indirectAccess(castExpr, field, n.info))
|
||||
call.add(newSymNode(temp))
|
||||
|
||||
proc getRoot*(n: PNode): PSym =
|
||||
@@ -310,9 +313,11 @@ proc genHigh(n: PNode): PNode =
|
||||
result.sons[1] = n
|
||||
|
||||
proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
|
||||
castExpr, call, result: PNode) =
|
||||
castExpr, call, varSection, result: PNode) =
|
||||
let formals = n[0].typ.n
|
||||
let tmpName = getIdent(genPrefix)
|
||||
# we need to copy the foreign scratch object fields into local variables
|
||||
# for correctness: These are called 'threadLocal' here.
|
||||
for i in 1 .. <n.len:
|
||||
let n = n[i]
|
||||
let argType = skipTypes(if i < formals.len: formals[i].typ else: n.typ,
|
||||
@@ -344,7 +349,9 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldA), n[2])
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldB), n[3])
|
||||
|
||||
slice.sons[2] = indirectAccess(castExpr, fieldA, n.info)
|
||||
let threadLocal = addLocalVar(varSection, objType.owner, fieldA.typ,
|
||||
indirectAccess(castExpr, fieldA, n.info))
|
||||
slice.sons[2] = threadLocal.newSymNode
|
||||
else:
|
||||
let a = genAddrOf(n)
|
||||
field.typ = a.typ
|
||||
@@ -353,9 +360,12 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, fieldB), genHigh(n))
|
||||
|
||||
slice.sons[2] = newIntLit(0)
|
||||
|
||||
# the array itself does not need to go through a thread local variable:
|
||||
slice.sons[1] = genDeref(indirectAccess(castExpr, field, n.info))
|
||||
slice.sons[3] = indirectAccess(castExpr, fieldB, n.info)
|
||||
|
||||
let threadLocal = addLocalVar(varSection, objType.owner, fieldB.typ,
|
||||
indirectAccess(castExpr, fieldB, n.info))
|
||||
slice.sons[3] = threadLocal.newSymNode
|
||||
call.add slice
|
||||
elif (let size = computeSize(argType); size < 0 or size > 16) and
|
||||
n.getRoot != nil:
|
||||
@@ -364,13 +374,17 @@ proc setupArgsForParallelism(n: PNode; objType: PType; scratchObj: PSym;
|
||||
field.typ = a.typ
|
||||
objType.addField(field)
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), a)
|
||||
call.add(genDeref(indirectAccess(castExpr, field, n.info)))
|
||||
let threadLocal = addLocalVar(varSection, objType.owner, field.typ,
|
||||
indirectAccess(castExpr, field, n.info))
|
||||
call.add(genDeref(threadLocal.newSymNode))
|
||||
else:
|
||||
# boring case
|
||||
field.typ = argType
|
||||
objType.addField(field)
|
||||
result.add newFastAsgnStmt(newDotExpr(scratchObj, field), n)
|
||||
call.add(indirectAccess(castExpr, field, n.info))
|
||||
let threadLocal = addLocalVar(varSection, objType.owner, field.typ,
|
||||
indirectAccess(castExpr, field, n.info))
|
||||
call.add(threadLocal.newSymNode)
|
||||
|
||||
proc wrapProcForSpawn*(owner: PSym; n: PNode; retType: PType;
|
||||
barrier, dest: PNode = nil): PNode =
|
||||
@@ -438,7 +452,7 @@ proc wrapProcForSpawn*(owner: PSym; n: PNode; retType: PType;
|
||||
if barrier.isNil:
|
||||
setupArgsForConcurrency(n, objType, scratchObj, castExpr, call, varSection, result)
|
||||
else:
|
||||
setupArgsForParallelism(n, objType, scratchObj, castExpr, call, result)
|
||||
setupArgsForParallelism(n, objType, scratchObj, castExpr, call, varSection, result)
|
||||
|
||||
var barrierAsExpr: PNode = nil
|
||||
if barrier != nil:
|
||||
|
||||
@@ -7,14 +7,15 @@ discard """
|
||||
sortoutput: true
|
||||
"""
|
||||
|
||||
import threadpool, math
|
||||
import threadpool, os
|
||||
|
||||
proc p(x: int) =
|
||||
os.sleep(100 - x*10)
|
||||
echo x
|
||||
|
||||
proc testFor(a, b: int; foo: var openArray[int]) =
|
||||
parallel:
|
||||
for i in max(a, 0) .. min(b, foo.len-1):
|
||||
for i in max(a, 0) .. min(b, foo.high):
|
||||
spawn p(foo[i])
|
||||
|
||||
var arr = [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
|
||||
Reference in New Issue
Block a user