'parallel' statement: next steps

This commit is contained in:
Araq
2014-05-14 23:36:28 +02:00
parent c43e8df90c
commit 31b8fd66b1
7 changed files with 221 additions and 50 deletions

View File

@@ -1,7 +1,7 @@
#
#
# The Nimrod Compiler
# (c) Copyright 2013 Andreas Rumpf
# (c) Copyright 2014 Andreas Rumpf
#
# See the file "copying.txt", included in this
# distribution, for details about the copyright.
@@ -165,9 +165,6 @@ proc buildCall(op: PSym; a, b: PNode): PNode =
result.sons[1] = a
result.sons[2] = b
proc `+@`*(a: PNode; b: BiggestInt): PNode =
(if b != 0: opAdd.buildCall(a, nkIntLit.newIntNode(b)) else: a)
proc `|+|`(a, b: PNode): PNode =
result = copyNode(a)
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
@@ -178,22 +175,56 @@ proc `|*|`(a, b: PNode): PNode =
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
else: result.floatVal = a.floatVal * b.floatVal
proc negate(a, b, res: PNode): PNode =
  ## Rewrites the subtraction `a - b` as an addition of the negated literal.
  ##
  ## * `b` is an integer literal other than `low(BiggestInt)`: the literal is
  ##   negated; if `a` is an integer literal too the two are folded with the
  ##   saturating `|+|`, otherwise the call `a + (-b)` is built.
  ## * `b` is a float literal: the call `a + (-b)` is built.
  ## * anything else: `res` is returned unchanged (callers may pass nil).
  const intLits = {nkCharLit..nkUInt64Lit}
  if b.kind in intLits and b.intVal != low(BiggestInt):
    # low(BiggestInt) is excluded because its negation is not representable.
    var negated = copyNode(b)
    negated.intVal = -negated.intVal
    if a.kind in intLits:
      negated.intVal = negated.intVal |+| a.intVal
      result = negated
    else:
      result = buildCall(opAdd, a, negated)
  elif b.kind in {nkFloatLit..nkFloat64Lit}:
    var negated = copyNode(b)
    negated.floatVal = -negated.floatVal
    result = buildCall(opAdd, a, negated)
  else:
    result = res
proc zero(): PNode =
  ## Fresh integer literal node holding 0.
  result = newIntNode(nkIntLit, 0)
proc one(): PNode =
  ## Fresh integer literal node holding 1.
  result = newIntNode(nkIntLit, 1)
proc minusOne(): PNode =
  ## Fresh integer literal node holding -1.
  result = newIntNode(nkIntLit, -1)
proc lowBound*(x: PNode): PNode = nkIntLit.newIntNode(firstOrd(x.typ))
proc lowBound*(x: PNode): PNode =
result = nkIntLit.newIntNode(firstOrd(x.typ))
result.info = x.info
proc highBound*(x: PNode): PNode =
if x.typ.skipTypes(abstractInst).kind == tyArray:
nkIntLit.newIntNode(lastOrd(x.typ))
else:
opAdd.buildCall(opLen.buildCall(x), minusOne())
result = if x.typ.skipTypes(abstractInst).kind == tyArray:
nkIntLit.newIntNode(lastOrd(x.typ))
else:
opAdd.buildCall(opLen.buildCall(x), minusOne())
result.info = x.info
proc reassociation(n: PNode): PNode =
  ## Folds two nested constant operands of the same operator into one:
  ## ``(foo+5)+5 --> foo+10`` and ``(foo*5)*5 --> foo*25``.
  ## Returns `n` itself when the pattern does not match.
  result = n
  case result.getMagic
  of someAdd:
    if result[2].isValue and
        result[1].getMagic in someAdd and result[1][2].isValue:
      result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
  of someMul:
    if result[2].isValue and
        result[1].getMagic in someMul and result[1][2].isValue:
      # The folded product must stay a multiplication; rebuilding it with
      # opAdd would turn (foo*5)*5 into foo+25.
      # NOTE(review): assumes opMul is declared next to opAdd/opSub -- confirm.
      result = opMul.buildCall(result[1][1], result[1][2] |*| result[2])
  else: discard
proc canon*(n: PNode): PNode =
# XXX for now only the new code in 'semparallel' uses this
if n.safeLen >= 1:
result = newNodeI(n.kind, n.info, n.len)
for i in 0 .. < n.safeLen:
result = shallowCopy(n)
for i in 0 .. < n.len:
result.sons[i] = canon(n.sons[i])
else:
result = n
@@ -210,32 +241,12 @@ proc canon*(n: PNode): PNode =
result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
of someSub:
# x - 4 --> x + (-4)
var b = result[2]
if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
b = copyNode(b)
b.intVal = -b.intVal
result = buildCall(opAdd, result[1], b)
elif b.kind in {nkFloatLit..nkFloat64Lit}:
b = copyNode(b)
b.floatVal = -b.floatVal
result = buildCall(opAdd, result[1], b)
result = negate(result[1], result[2], result)
of someLen:
result.sons[0] = opLen.newSymNode
else: discard
# re-association:
# (foo+5)+5 --> foo+10; same for '*'
case result.getMagic
of someAdd:
if result[2].isValue and
result[1].getMagic in someAdd and result[1][2].isValue:
result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
of someMul:
if result[2].isValue and
result[1].getMagic in someMul and result[1][2].isValue:
result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
else: discard
result = reassociation(result)
# most important rule: (x-4) < a.len --> x < a.len+4
case result.getMagic
of someLe, someLt:
@@ -245,21 +256,32 @@ proc canon*(n: PNode): PNode =
isLetLocation(x[1], true):
case x.getMagic
of someSub:
result = buildCall(result[0].sym, x[1], opAdd.buildCall(y, x[2]))
result = buildCall(result[0].sym, x[1],
reassociation(opAdd.buildCall(y, x[2])))
of someAdd:
result = buildCall(result[0].sym, x[1], opSub.buildCall(y, x[2]))
# Rule A:
let plus = negate(y, x[2], nil).reassociation
if plus != nil: result = buildCall(result[0].sym, x[1], plus)
else: discard
elif y.kind in nkCallKinds and y.len == 3 and y[2].isValue and
isLetLocation(y[1], true):
# a.len < x-3
case y.getMagic
of someSub:
result = buildCall(result[0].sym, y[1], opAdd.buildCall(x, y[2]))
result = buildCall(result[0].sym, y[1],
reassociation(opAdd.buildCall(x, y[2])))
of someAdd:
result = buildCall(result[0].sym, y[1], opSub.buildCall(x, y[2]))
let plus = negate(x, y[2], nil).reassociation
# ensure that Rule A will not trigger afterwards with the
# additional 'not isLetLocation' constraint:
if plus != nil and not isLetLocation(x, true):
result = buildCall(result[0].sym, plus, y[1])
else: discard
else: discard
proc `+@`*(a: PNode; b: BiggestInt): PNode =
  ## Returns the canonical form (see `canon`) of `a + b`; for `b == 0` no
  ## addition is built and `a` alone is canonicalised.
  let sum =
    if b == 0: a
    else: buildCall(opAdd, a, newIntNode(nkIntLit, b))
  result = canon(sum)
proc usefulFact(n: PNode): PNode =
case n.getMagic
of someEq:
@@ -639,8 +661,20 @@ proc doesImply*(facts: TModel, prop: PNode): TImplication =
proc impliesNotNil*(facts: TModel, arg: PNode): TImplication =
  ## Reports what `facts` imply about the proposition `not isNil(arg)`.
  let notNil = neg(buildCall(opIsNil, arg))
  result = doesImply(facts, notNil)
proc simpleSlice*(a, b: PNode): BiggestInt =
  ## Recognises single-element slices of the form ``(i+c)..(i+c)``.
  ##
  ## Returns the literal `c` when both bounds are the same tree and that tree
  ## is an addition with an integer literal as second operand; returns 0 when
  ## both bounds are the same tree of any other shape (``(i)..(i)`` counts as
  ## ``(i+0)..(i+0)``); returns -1 when the bounds differ.
  if not guards.sameTree(a, b):
    return -1
  let hasLiteralOffset =
    a.getMagic in someAdd and a[2].kind in {nkCharLit..nkUInt64Lit}
  result = if hasLiteralOffset: a[2].intVal else: 0
proc proveLe*(m: TModel; a, b: PNode): TImplication =
let res = canon(opLe.buildCall(a, b))
#echo renderTree(res)
# we hardcode lots of axioms here:
let a = res[1]
let b = res[2]
@@ -662,6 +696,10 @@ proc proveLe*(m: TModel; a, b: PNode): TImplication =
if b.getMagic in someAdd and sameTree(a, b[1]):
return proveLe(m, zero(), b[2])
# x+c <= x iff c <= 0
if a.getMagic in someAdd and sameTree(b, a[1]):
return proveLe(m, a[2], zero())
# x <= x*c if 1 <= c and 0 <= x:
if b.getMagic in someMul and sameTree(a, b[1]):
if proveLe(m, one(), b[2]) == impYes and proveLe(m, zero(), a) == impYes:

View File

@@ -9,6 +9,8 @@
## Semantic checking for 'parallel'.
# - codegen needs to support mSlice
# - lowerings must not perform unnecessary copies
# - slices should become "nocopy" to openArray (+)
# - need to perform bound checks (+)
#
@@ -153,6 +155,8 @@ proc addLowerBoundAsFacts(c: var AnalysisCtx) =
proc addSlice(c: var AnalysisCtx; n: PNode; x, le, ri: PNode) =
checkLocal(c, n)
let le = le.canon
let ri = ri.canon
# perform static bounds checking here; and not later!
let oldState = c.guards.len
addLowerBoundAsFacts(c)
@@ -166,16 +170,16 @@ proc overlap(m: TModel; x,y,c,d: PNode) =
case proveLe(m, x, d)
of impUnknown:
localError(x.info,
"cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
"cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
[?x, ?d, ?x, ?y, ?c, ?d])
of impYes:
case proveLe(m, c, y)
of impUnknown:
localError(x.info,
"cannot prove: $# > $#; required for $#..$# disjoint from $#..$#" %
"cannot prove: $# > $#; required for ($#)..($#) disjoint from ($#)..($#)" %
[?y, ?d, ?x, ?y, ?c, ?d])
of impYes:
localError(x.info, "$#..$# not disjoint from $#..$#" % [?x, ?y, ?c, ?d])
localError(x.info, "($#)..($#) not disjoint from ($#)..($#)" % [?x, ?y, ?c, ?d])
of impNo: discard
of impNo: discard
@@ -220,14 +224,25 @@ proc checkSlicesAreDisjoint(c: var AnalysisCtx) =
let x = c.slices[i]
let y = c.slices[j]
if x.spawnId != y.spawnId and guards.sameTree(x.x, y.x):
if not x.inLoop and not y.inLoop:
if not x.inLoop or not y.inLoop:
# XXX strictly speaking, 'or' is not correct here and it needs to
# be 'and'. However, 'and' would reject too many obviously correct
# programs like f(a[0..x]); for i in x+1 .. a.high: f(a[i])
overlap(c.guards, x.a, x.b, y.a, y.b)
elif (let k = simpleSlice(x.a, x.b); let m = simpleSlice(y.a, y.b);
k >= 0 and m >= 0):
# ah I cannot resist the temptation and add another sweet heuristic:
# if both slices have the form (i+k)..(i+k) and (i+m)..(i+m) we
# check they are disjoint and k < stride and m < stride:
overlap(c.guards, x.a, x.b, y.a, y.b)
let stride = min(c.stride(x.a), c.stride(y.a))
if k < stride and m < stride:
discard
else:
localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
[?x.a, ?x.b, ?y.a, ?y.b])
else:
# ah I cannot resists the temptation and add another sweet heuristic:
# if both slices have the form (i+c)..(i+c) and (i+d)..(i+d) we
# check they are disjoint and c <= stride and d <= stride:
# XXX
localError(x.x.info, "cannot prove $#..$# disjoint from $#..$#" %
localError(x.x.info, "cannot prove ($#)..($#) disjoint from ($#)..($#)" %
[?x.a, ?x.b, ?y.a, ?y.b])
proc analyse(c: var AnalysisCtx; n: PNode)
@@ -369,9 +384,9 @@ proc transformSlices(n: PNode): PNode =
result.add n[2][2]
return result
if n.safeLen > 0:
result = copyNode(n)
result = shallowCopy(n)
for i in 0 .. < n.len:
result.add transformSlices(n.sons[i])
result.sons[i] = transformSlices(n.sons[i])
else:
result = n
@@ -383,9 +398,9 @@ proc transformSpawn(owner: PSym; n, barrier: PNode): PNode =
result = transformSlices(n)
return wrapProcForSpawn(owner, result[1], barrier)
elif n.safeLen > 0:
result = copyNode(n)
result = shallowCopy(n)
for i in 0 .. < n.len:
result.add transformSpawn(owner, n.sons[i], barrier)
result.sons[i] = transformSpawn(owner, n.sons[i], barrier)
else:
result = n

View File

@@ -0,0 +1,21 @@
import threadpool
proc f(a: openArray[int]) =
for x in a: echo x
proc f(a: int) = echo a
proc main() =
var a: array[0..30, int]
parallel:
#spawn f(a[0..15])
#spawn f(a[16..30])
var i = 0
while i <= 29:
spawn f(a[i])
spawn f(a[i+1])
inc i, 2
# is correct here
main()

View File

@@ -0,0 +1,21 @@
import threadpool
proc f(a: openArray[int]) =
for x in a: echo x
proc f(a: int) = echo a
proc main() =
var a: array[0..30, int]
parallel:
spawn f(a[0..15])
#spawn f(a[16..30])
var i = 16
while i <= 29:
spawn f(a[i])
spawn f(a[i+1])
inc i, 2
# is correct here
main()

View File

@@ -0,0 +1,25 @@
discard """
errormsg: "cannot prove: i + 1 <= 30"
line: 21
"""
import threadpool
proc f(a: openArray[int]) =
for x in a: echo x
proc f(a: int) = echo a
proc main() =
var a: array[0..30, int]
parallel:
spawn f(a[0..15])
spawn f(a[16..30])
var i = 0
while i <= 30:
spawn f(a[i])
spawn f(a[i+1])
inc i
#inc i # inc i, 2 would be correct here
main()

View File

@@ -0,0 +1,26 @@
discard """
errormsg: "invalid usage of counter after increment"
line: 21
"""
import threadpool
proc f(a: openArray[int]) =
for x in a: echo x
proc f(a: int) = echo a
proc main() =
var a: array[0..30, int]
parallel:
spawn f(a[0..15])
spawn f(a[16..30])
var i = 0
while i <= 30:
inc i
spawn f(a[i])
inc i
#spawn f(a[i+1])
#inc i # inc i, 2 would be correct here
main()

View File

@@ -0,0 +1,25 @@
discard """
errormsg: "cannot prove (i)..(i) disjoint from (i + 1)..(i + 1)"
line: 20
"""
import threadpool
proc f(a: openArray[int]) =
for x in a: echo x
proc f(a: int) = echo a
proc main() =
var a: array[0..30, int]
parallel:
#spawn f(a[0..15])
#spawn f(a[16..30])
var i = 0
while i <= 29:
spawn f(a[i])
spawn f(a[i+1])
inc i
#inc i # inc i, 2 would be correct here
main()