diff --git a/compiler/lowerings.nim b/compiler/lowerings.nim index b159502dc6..d370f21f05 100644 --- a/compiler/lowerings.nim +++ b/compiler/lowerings.nim @@ -197,18 +197,21 @@ proc createNimCreatePromiseCall(prom, threadParam: PNode): PNode = result = newFastAsgnStmt(prom, castExpr) proc createWrapperProc(f: PNode; threadParam, argsParam: PSym; - varSection, call, barrier, prom: PNode): PSym = + varSection, call, barrier, prom: PNode; + spawnKind: TSpawnResult): PSym = var body = newNodeI(nkStmtList, f.info) if barrier != nil: body.add callCodeGenProc("barrierEnter", barrier) body.add varSection - if prom != nil: + if prom != nil and spawnKind != srByVar: body.add createNimCreatePromiseCall(prom, threadParam.newSymNode) if barrier == nil: body.add callCodeGenProc("nimPromiseCreateCondVar", prom) body.add callCodeGenProc("nimArgsPassingDone", threadParam.newSymNode) - if prom != nil: + if spawnKind == srByVar: + body.add newAsgnStmt(genDeref(prom), call) + elif prom != nil: let fk = prom.typ.sons[1].promiseKind if fk == promInvalid: localError(f.info, "cannot create a promise of type: " & @@ -471,9 +474,16 @@ proc wrapProcForSpawn*(owner: PSym; n: PNode; retType: PType; objType.addField(field) promField = newDotExpr(scratchObj, field) promAsExpr = indirectAccess(castExpr, field, n.info) + elif spawnKind == srByVar: + var field = newSym(skField, getIdent"prom", owner, n.info) + field.typ = newType(tyPtr, objType.owner) + field.typ.rawAddSon(retType) + objType.addField(field) + promAsExpr = indirectAccess(castExpr, field, n.info) + result.add newFastAsgnStmt(newDotExpr(scratchObj, field), genAddrOf(dest)) let wrapper = createWrapperProc(fn, threadParam, argsParam, varSection, call, - barrierAsExpr, promAsExpr) + barrierAsExpr, promAsExpr, spawnKind) result.add callCodeGenProc("nimSpawn", wrapper.newSymNode, genAddrOf(scratchObj.newSymNode)) diff --git a/compiler/semdata.nim b/compiler/semdata.nim index 987a70a419..19181d98e0 100644 --- a/compiler/semdata.nim +++ b/compiler/semdata.nim @@ -91,6 +91,7 @@ type generics*: seq[TInstantiationPair] # pending list of instantiated generics to compile lastGenericIdx*: int # used for the generics stack hloLoopDetector*: int # used to prevent endless loops in the HLO + inParallelStmt*: int proc makeInstPair*(s: PSym, inst: PInstantiation): TInstantiationPair = result.genericSym = s diff --git a/compiler/semexprs.nim b/compiler/semexprs.nim index 8f4cce547a..e507e711f3 100644 --- a/compiler/semexprs.nim +++ b/compiler/semexprs.nim @@ -1615,13 +1615,17 @@ proc semMagic(c: PContext, n: PNode, s: PSym, flags: TExprFlags): PNode = result = setMs(n, s) var x = n.lastSon if x.kind == nkDo: x = x.sons[bodyPos] + inc c.inParallelStmt result.sons[1] = semStmt(c, x) + dec c.inParallelStmt of mSpawn: result = setMs(n, s) result.sons[1] = semExpr(c, n.sons[1]) - # later passes may transform the type 'Promise[T]' back into 'T' if not result[1].typ.isEmptyType: - result.typ = createPromise(c, result[1].typ, n.info) + if c.inParallelStmt > 0: + result.typ = result[1].typ + else: + result.typ = createPromise(c, result[1].typ, n.info) else: result = semDirectOp(c, n, flags) proc semWhen(c: PContext, n: PNode, semCheck = true): PNode = diff --git a/tests/parallel/tpi.nim b/tests/parallel/tpi.nim new file mode 100644 index 0000000000..de5aa9a514 --- /dev/null +++ b/tests/parallel/tpi.nim @@ -0,0 +1,22 @@ + +import strutils, math, threadpool + +proc term(k: float): float = 4 * math.pow(-1, k) / (2*k + 1) + +proc piU(n: int): float = + var ch = newSeq[Promise[float]](n+1) + for k in 0..n: + ch[k] = spawn term(float(k)) + for k in 0..n: + result += ^ch[k] + +proc piS(n: int): float = + var ch = newSeq[float](n+1) + parallel: + for k in 0..ch.high: + ch[k] = spawn term(float(k)) + for k in 0..ch.high: + result += ch[k] + +echo formatFloat(piU(5000)) +echo formatFloat(piS(5000))