mirror of
https://github.com/nim-lang/Nim.git
synced 2026-06-10 13:48:10 +00:00
'parallel' statement: next steps
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
#
|
||||
#
|
||||
# The Nimrod Compiler
|
||||
# (c) Copyright 2013 Andreas Rumpf
|
||||
# (c) Copyright 2014 Andreas Rumpf
|
||||
#
|
||||
# See the file "copying.txt", included in this
|
||||
# distribution, for details about the copyright.
|
||||
@@ -165,9 +165,6 @@ proc buildCall(op: PSym; a, b: PNode): PNode =
|
||||
result.sons[1] = a
|
||||
result.sons[2] = b
|
||||
|
||||
proc `+@`*(a: PNode; b: BiggestInt): PNode =
|
||||
(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
|
||||
|
||||
proc `|+|`(a, b: PNode): PNode =
|
||||
result = copyNode(a)
|
||||
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
|
||||
@@ -178,22 +175,56 @@ proc `|*|`(a, b: PNode): PNode =
|
||||
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
|
||||
else: result.floatVal = a.floatVal * b.floatVal
|
||||
|
||||
proc negate(a, b, res: PNode): PNode =
|
||||
if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
|
||||
var b = copyNode(b)
|
||||
b.intVal = -b.intVal
|
||||
if a.kind in {nkCharLit..nkUInt64Lit}:
|
||||
b.intVal = b.intVal |+| a.intVal
|
||||
result = b
|
||||
else:
|
||||
result = buildCall(opAdd, a, b)
|
||||
elif b.kind in {nkFloatLit..nkFloat64Lit}:
|
||||
var b = copyNode(b)
|
||||
b.floatVal = -b.floatVal
|
||||
result = buildCall(opAdd, a, b)
|
||||
else:
|
||||
result = res
|
||||
|
||||
proc zero(): PNode = nkIntLit.newIntNode(0)
|
||||
proc one(): PNode = nkIntLit.newIntNode(1)
|
||||
proc minusOne(): PNode = nkIntLit.newIntNode(-1)
|
||||
|
||||
proc lowBound*(x: PNode): PNode = nkIntLit.newIntNode(firstOrd(x.typ))
|
||||
proc lowBound*(x: PNode): PNode =
|
||||
result = nkIntLit.newIntNode(firstOrd(x.typ))
|
||||
result.info = x.info
|
||||
|
||||
proc highBound*(x: PNode): PNode =
|
||||
if x.typ.skipTypes(abstractInst).kind == tyArray:
|
||||
nkIntLit.newIntNode(lastOrd(x.typ))
|
||||
else:
|
||||
opAdd.buildCall(opLen.buildCall(x), minusOne())
|
||||
result = if x.typ.skipTypes(abstractInst).kind == tyArray:
|
||||
nkIntLit.newIntNode(lastOrd(x.typ))
|
||||
else:
|
||||
opAdd.buildCall(opLen.buildCall(x), minusOne())
|
||||
result.info = x.info
|
||||
|
||||
proc reassociation(n: PNode): PNode =
|
||||
result = n
|
||||
# (foo+5)+5 --> foo+10; same for '*'
|
||||
case result.getMagic
|
||||
of someAdd:
|
||||
if result[2].isValue and
|
||||
result[1].getMagic in someAdd and result[1][2].isValue:
|
||||
result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
|
||||
of someMul:
|
||||
if result[2].isValue and
|
||||
result[1].getMagic in someMul and result[1][2].isValue:
|
||||
result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
|
||||
else: discard
|
||||
|
||||
proc canon*(n: PNode): PNode =
|
||||
# XXX for now only the new code in 'semparallel' uses this
|
||||
if n.safeLen >= 1:
|
||||
result = newNodeI(n.kind, n.info, n.len)
|
||||
for i in 0 .. < n.safeLen:
|
||||
result = shallowCopy(n)
|
||||
for i in 0 .. < n.len:
|
||||
result.sons[i] = canon(n.sons[i])
|
||||
else:
|
||||
result = n
|
||||
@@ -210,32 +241,12 @@ proc canon*(n: PNode): PNode =
|
||||
result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
|
||||
of someSub:
|
||||
# x - 4 --> x + (-4)
|
||||
var b = result[2]
|
||||
if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
|
||||
b = copyNode(b)
|
||||
b.intVal = -b.intVal
|
||||
result = buildCall(opAdd, result[1], b)
|
||||
elif b.kind in {nkFloatLit..nkFloat64Lit}:
|
||||
b = copyNode(b)
|
||||
b.floatVal = -b.floatVal
|
||||
result = buildCall(opAdd, result[1], b)
|
||||
result = negate(result[1], result[2], result)
|
||||
of someLen:
|
||||
result.sons[0] = opLen.newSymNode
|
||||
else: discard
|
||||
|
||||
# re-association:
|
||||
# (foo+5)+5 --> foo+10; same for '*'
|
||||
case result.getMagic
|
||||
of someAdd:
|
||||
if result[2].isValue and
|
||||
result[1].getMagic in someAdd and result[1][2].isValue:
|
||||
result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
|
||||
of someMul:
|
||||
if result[2].isValue and
|
||||
result[1].getMagic in someMul and result[1][2].isValue:
|
||||
result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
|
||||
else: discard
|
||||
|
||||
result = reassociation(result)
|
||||
# most important rule: (x-4) < a.len --> x < a.len+4
|
||||
case result.getMagic
|
||||
of someLe, someLt:
|
||||
@@ -245,21 +256,32 @@ proc canon*(n: PNode): PNode =
|
||||
isLetLocation(x[1], true):
|
||||
case x.getMagic
|
||||
of someSub:
|
||||
result = buildCall(result[0].sym, x[1], opAdd.buildCall(y, x[2]))
|
||||
result = buildCall(result[0].sym, x[1],
|
||||
reassociation(opAdd.buildCall(y, x[2])))
|
||||
of someAdd:
|
||||
result = buildCall(result[0].sym, x[1], opSub.buildCall(y, x[2]))
|
||||
# Rule A:
|
||||
let plus = negate(y, x[2], nil).reassociation
|
||||
if plus != nil: result = buildCall(result[0].sym, x[1], plus)
|
||||
else: discard
|
||||
elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and
|
||||
isLetLocation(y[1], true):
|
||||
# a.len < x-3
|
||||
case y.getMagic
|
||||
of someSub:
|
||||
result = buildCall(result[0].sym, y[1], opAdd.buildCall(x, y[2]))
|
||||
result = buildCall(result[0].sym, y[1],
|
||||
reassociation(opAdd.buildCall(x, y[2])))
|
||||
of someAdd:
|
||||
result = buildCall(result[0].sym, y[1], opSub.buildCall(x, y[2]))
|
||||
let plus = negate(x, y[2], nil).reassociation
|
||||
# ensure that Rule A will not trigger afterwards with the
|
||||
# additional 'not isLetLocation' constraint:
|
||||
if plus != nil and not isLetLocation(x, true):
|
||||
result = buildCall(result[0].sym, plus, y[1])
|
||||
else: discard
|
||||
else: discard
|
||||
|
||||
proc `+@`*(a: PNode; b: BiggestInt): PNode =
|
||||
canon(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
|
||||
|
||||
proc usefulFact(n: PNode): PNode =
|
||||
case n.getMagic
|
||||
of someEq:
|
||||
@@ -639,8 +661,20 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication =
|
||||
proc impliesNotNil*(facts: TModel, arg: PNode): TImplication =
|
||||
result = doesImply(facts, opIsNil.buildCall(arg).neg)
|
||||
|
||||
proc simpleSlice*(a, b: PNode): BiggestInt =
|
||||
# returns 'c' if a..b matches (i+c)..(i+c), -1 otherwise. (i)..(i) is matched
|
||||
# as if it is (i+0)..(i+0).
|
||||
if guards.sameTree(a, b):
|
||||
if a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}:
|
||||
result = a[2].intVal
|
||||
else:
|
||||
result = 0
|
||||
else:
|
||||
result = -1
|
||||
|
||||
proc proveLe*(m: TModel; a, b: PNode): TImplication =
|
||||
let res = canon(opLe.buildCall(a, b))
|
||||
#echo renderTree(res)
|
||||
# we hardcode lots of axioms here:
|
||||
let a = res[1]
|
||||
let b = res[2]
|
||||
@@ -662,6 +696,10 @@ proc proveLe*(m: TModel; a, b: PNode): TImplication =
|
||||
if b.getMagic in someAdd and sameTree(a, b[1]):
|
||||
return proveLe(m, zero(), b[2])
|
||||
|
||||
# x+c <= x iff c <= 0
|
||||
if a.getMagic in someAdd and sameTree(b, a[1]):
|
||||
return proveLe(m, a[2], zero())
|
||||
|
||||
# x <= x*c if 1 <= c and 0 <= x:
|
||||
if b.getMagic in someMul and sameTree(a, b[1]):
|
||||
if proveLe(m, one(), b[2]) == impYes and proveLe(m, zero(), a) == impYes:
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
|
||||
## Semantic checking for 'parallel'.
|
||||
|
||||
# - codegen needs to support mSlice
|
||||
# - lowerings must not perform unnecessary copies
|
||||
# - slices should become "nocopy" to openArray (+)
|
||||
# - need to perform bound checks (+)
|
||||
#
|
||||
@@ -153,6 +155,8 @@ proc addLowerBoundAsFacts(c: var AnalysisCtx) =
|
||||
|
||||
proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
|
||||
checkLocal(c, n)
|
||||
let le = le.canon
|
||||
let ri = ri.canon
|
||||
# perform static bounds checking here; and not later!
|
||||
let oldState = c.guards.len
|
||||
addLowerBoundAsFacts(c)
|
||||
@@ -166,16 +170,16 @@ proc overlap(m: TModel; x,y,c,d: PNode) =
|
||||
case proveLe(m, x, d)
|
||||
of impUnknown:
|
||||
localError(x.info,
|
||||
"cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
|
||||
"cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
|
||||
[?x, ?d, ?x, ?y, ?c, ?d])
|
||||
of impYes:
|
||||
case proveLe(m, c, y)
|
||||
of impUnknown:
|
||||
localError(x.info,
|
||||
"cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
|
||||
"cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
|
||||
[?y, ?d, ?x, ?y, ?c, ?d])
|
||||
of impYes:
|
||||
localError(x.info, "$#..$# not disjoint from $#..$#" % [?x, ?y, ?c, ?d])
|
||||
localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d])
|
||||
of impNo: discard
|
||||
of impNo: discard
|
||||
|
||||
@@ -220,14 +224,25 @@ proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
|
||||
let x = c.slices[i]
|
||||
let y = c.slices[j]
|
||||
if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
|
||||
if not x.inLoop and not y.inLoop:
|
||||
if not x.inLoop or not y.inLoop:
|
||||
# XXX strictly speaking, 'or' is not correct here and it needs to
|
||||
# be 'and'. However this prevents too many obviously correct programs
|
||||
# like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
|
||||
overlap(c.guards, x.a, x.b, y.a, y.b)
|
||||
elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
|
||||
k >= 0 and m >= 0):
|
||||
# ah I cannot resist the temptation and add another sweet heuristic:
|
||||
# if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
|
||||
# check they are disjoint and k < stride and m < stride:
|
||||
overlap(c.guards, x.a, x.b, y.a, y.b)
|
||||
let stride = min(c.stride(x.a), c.stride(y.a))
|
||||
if k < stride and m < stride:
|
||||
discard
|
||||
else:
|
||||
localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
|
||||
[?x.a, ?x.b, ?y.a, ?y.b])
|
||||
else:
|
||||
# ah I cannot resists the temptation and add another sweet heuristic:
|
||||
# if both slices have the form (i+c)..(i+c) and (i+d)..(i+d) we
|
||||
# check they are disjoint and c <= stride and d <= stride:
|
||||
# XXX
|
||||
localError(x.x.info, "cannot prove $#..$# disjoint from $#..$#" %
|
||||
localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
|
||||
[?x.a, ?x.b, ?y.a, ?y.b])
|
||||
|
||||
proc analyse(c: var AnalysisCtx; n: PNode)
|
||||
@@ -369,9 +384,9 @@ proc transformSlices(n: PNode): PNode =
|
||||
result.add n[2][2]
|
||||
return result
|
||||
if n.safeLen > 0:
|
||||
result = copyNode(n)
|
||||
result = shallowCopy(n)
|
||||
for i in 0 .. < n.len:
|
||||
result.add transformSlices(n.sons[i])
|
||||
result.sons[i] = transformSlices(n.sons[i])
|
||||
else:
|
||||
result = n
|
||||
|
||||
@@ -383,9 +398,9 @@ proc transformSpawn(owner: PSym; n, barrier: PNode): PNode =
|
||||
result = transformSlices(n)
|
||||
return wrapProcForSpawn(owner, result[1], barrier)
|
||||
elif n.safeLen > 0:
|
||||
result = copyNode(n)
|
||||
result = shallowCopy(n)
|
||||
for i in 0 .. < n.len:
|
||||
result.add transformSpawn(owner, n.sons[i], barrier)
|
||||
result.sons[i] = transformSpawn(owner, n.sons[i], barrier)
|
||||
else:
|
||||
result = n
|
||||
|
||||
|
||||
21
tests/parallel/tdisjoint_slice1.nim
Normal file
21
tests/parallel/tdisjoint_slice1.nim
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
import threadpool
|
||||
|
||||
proc f(a: openArray[int]) =
|
||||
for x in a: echo x
|
||||
|
||||
proc f(a: int) = echo a
|
||||
|
||||
proc main() =
|
||||
var a: array[0..30, int]
|
||||
parallel:
|
||||
#spawn f(a[0..15])
|
||||
#spawn f(a[16..30])
|
||||
var i = 0
|
||||
while i <= 29:
|
||||
spawn f(a[i])
|
||||
spawn f(a[i+1])
|
||||
inc i, 2
|
||||
# is correct here
|
||||
|
||||
main()
|
||||
21
tests/parallel/tdisjoint_slice2.nim
Normal file
21
tests/parallel/tdisjoint_slice2.nim
Normal file
@@ -0,0 +1,21 @@
|
||||
|
||||
import threadpool
|
||||
|
||||
proc f(a: openArray[int]) =
|
||||
for x in a: echo x
|
||||
|
||||
proc f(a: int) = echo a
|
||||
|
||||
proc main() =
|
||||
var a: array[0..30, int]
|
||||
parallel:
|
||||
spawn f(a[0..15])
|
||||
#spawn f(a[16..30])
|
||||
var i = 16
|
||||
while i <= 29:
|
||||
spawn f(a[i])
|
||||
spawn f(a[i+1])
|
||||
inc i, 2
|
||||
# is correct here
|
||||
|
||||
main()
|
||||
25
tests/parallel/tinvalid_array_bounds.nim
Normal file
25
tests/parallel/tinvalid_array_bounds.nim
Normal file
@@ -0,0 +1,25 @@
|
||||
discard """
|
||||
errormsg: "cannot prove: i + 1 <= 30"
|
||||
line: 21
|
||||
"""
|
||||
|
||||
import threadpool
|
||||
|
||||
proc f(a: openArray[int]) =
|
||||
for x in a: echo x
|
||||
|
||||
proc f(a: int) = echo a
|
||||
|
||||
proc main() =
|
||||
var a: array[0..30, int]
|
||||
parallel:
|
||||
spawn f(a[0..15])
|
||||
spawn f(a[16..30])
|
||||
var i = 0
|
||||
while i <= 30:
|
||||
spawn f(a[i])
|
||||
spawn f(a[i+1])
|
||||
inc i
|
||||
#inc i # inc i, 2 would be correct here
|
||||
|
||||
main()
|
||||
26
tests/parallel/tinvalid_counter_usage.nim
Normal file
26
tests/parallel/tinvalid_counter_usage.nim
Normal file
@@ -0,0 +1,26 @@
|
||||
discard """
|
||||
errormsg: "invalid usage of counter after increment"
|
||||
line: 21
|
||||
"""
|
||||
|
||||
import threadpool
|
||||
|
||||
proc f(a: openArray[int]) =
|
||||
for x in a: echo x
|
||||
|
||||
proc f(a: int) = echo a
|
||||
|
||||
proc main() =
|
||||
var a: array[0..30, int]
|
||||
parallel:
|
||||
spawn f(a[0..15])
|
||||
spawn f(a[16..30])
|
||||
var i = 0
|
||||
while i <= 30:
|
||||
inc i
|
||||
spawn f(a[i])
|
||||
inc i
|
||||
#spawn f(a[i+1])
|
||||
#inc i # inc i, 2 would be correct here
|
||||
|
||||
main()
|
||||
25
tests/parallel/tnon_disjoint_slice1.nim
Normal file
25
tests/parallel/tnon_disjoint_slice1.nim
Normal file
@@ -0,0 +1,25 @@
|
||||
discard """
|
||||
errormsg: "cannot prove (i)..(i) disjoint from (i + 1)..(i + 1)"
|
||||
line: 20
|
||||
"""
|
||||
|
||||
import threadpool
|
||||
|
||||
proc f(a: openArray[int]) =
|
||||
for x in a: echo x
|
||||
|
||||
proc f(a: int) = echo a
|
||||
|
||||
proc main() =
|
||||
var a: array[0..30, int]
|
||||
parallel:
|
||||
#spawn f(a[0..15])
|
||||
#spawn f(a[16..30])
|
||||
var i = 0
|
||||
while i <= 29:
|
||||
spawn f(a[i])
|
||||
spawn f(a[i+1])
|
||||
inc i
|
||||
#inc i # inc i, 2 would be correct here
|
||||
|
||||
main()
|
||||
Reference in New Issue
Block a user