From 000c413f35d304576d77f3bebcf3193cc99cd411 Mon Sep 17 00:00:00 2001
From: Araq
Date: Wed, 23 Sep 2015 19:47:48 +0200
Subject: [PATCH] disjoint checker is smarter (and slower)

Extend the guards prover so that more slice-bound inequalities can be
proven inside 'parallel' sections:

* guards.nim: add the constant-folding helpers `|-|` and `|div|`;
  'reassociation' now simplifies x+0, x*1 and x*0 and rebuilds the
  someMul case with opMul instead of opAdd; 'canon' also expands 'let'
  symbols bound to a 'div' expression and rewrites 'c <= y + d' as
  'c - d <= y' (and the analogous form for '-'); 'usefulFact' keeps
  facts that mention 'len'; 'ple' gains rules for '*', '+' and 'div'
  that are written with the new pattern-matching macro `=~`.
* The `=~` macro passes the matched node to 'isIntVal' for integer
  literal patterns (the template takes two parameters).
* semparallel.nim: bounds-check diagnostics are tagged "(bounds check)".
* tests/parallel/tparfind.nim: new test exercising the added rules.

---
 compiler/guards.nim         | 120 +++++++++++++++++++++++++++++++++---
 compiler/semparallel.nim    |   4 +-
 tests/parallel/tparfind.nim |  28 +++++++++
 3 files changed, 143 insertions(+), 9 deletions(-)
 create mode 100644 tests/parallel/tparfind.nim

diff --git a/compiler/guards.nim b/compiler/guards.nim
index eeadcb6c7c..5ad932e484 100644
--- a/compiler/guards.nim
+++ b/compiler/guards.nim
@@ -37,7 +37,7 @@ const
   someMod = {mModI}
   someMax = {mMaxI, mMaxF64}
   someMin = {mMinI, mMinF64}
-  someBinaryOp = someAdd+someSub+someMul+someDiv+someMod+someMax+someMin
+  someBinaryOp = someAdd+someSub+someMul+someMax+someMin
 
 proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit}
 proc isLocation(n: PNode): bool = not n.isValue
@@ -166,11 +166,21 @@ proc `|+|`(a, b: PNode): PNode =
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
   else: result.floatVal = a.floatVal + b.floatVal
 
+proc `|-|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |-| b.intVal
+  else: result.floatVal = a.floatVal - b.floatVal
+
 proc `|*|`(a, b: PNode): PNode =
   result = copyNode(a)
   if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
   else: result.floatVal = a.floatVal * b.floatVal
 
+proc `|div|`(a, b: PNode): PNode =
+  result = copyNode(a)
+  if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal div b.intVal
+  else: result.floatVal = a.floatVal / b.floatVal
+
 proc negate(a, b, res: PNode): PNode =
   if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
     var b = copyNode(b)
@@ -214,10 +224,16 @@ proc reassociation(n: PNode): PNode =
     if result[2].isValue and
         result[1].getMagic in someAdd and result[1][2].isValue:
       result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
+      if result[2].intVal == 0:
+        result = result[1]
   of someMul:
     if result[2].isValue and
         result[1].getMagic in someMul and result[1][2].isValue:
-      result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
+      result = opMul.buildCall(result[1][1], result[1][2] |*| result[2])
+      if result[2].intVal == 1:
+        result = result[1]
+      elif result[2].intVal == 0:
+        result = zero()
   else: discard
 
 proc pred(n: PNode): PNode =
@@ -235,7 +251,7 @@ proc canon*(n: PNode): PNode =
       result.sons[i] = canon(n.sons[i])
   elif n.kind == nkSym and n.sym.kind == skLet and
       n.sym.ast.getMagic in (someEq + someAdd + someMul + someMin +
-      someMax + someHigh + {mUnaryLt} + someSub + someLen):
+      someMax + someHigh + {mUnaryLt} + someSub + someLen + someDiv):
     result = n.sym.ast.copyTree
   else:
     result = n
@@ -249,7 +265,7 @@ proc canon*(n: PNode): PNode =
     # high == len+(-1)
     result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne())
   of mUnaryLt:
-    result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
+    result = buildCall(opAdd, result[1], minusOne())
   of someSub:
     # x - 4 --> x + (-4)
     result = negate(result[1], result[2], result)
@@ -295,6 +311,16 @@ proc canon*(n: PNode): PNode =
         if plus != nil and not isLetLocation(x, true):
           result = buildCall(result[0].sym, plus, y[1])
       else: discard
+    elif x.isValue and y.getMagic in someAdd and y[2].isValue:
+      # 0 <= a.len + 3
+      # -3 <= a.len
+      result.sons[1] = x |-| y[2]
+      result.sons[2] = y[1]
+    elif x.isValue and y.getMagic in someSub and y[2].isValue:
+      # 0 <= a.len - 3
+      # 3 <= a.len
+      result.sons[1] = x |+| y[2]
+      result.sons[2] = y[1]
   else: discard
 
 proc `+@`*(a: PNode; b: BiggestInt): PNode =
@@ -314,6 +340,9 @@ proc usefulFact(n: PNode): PNode =
     if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true):
       # XXX algebraic simplifications! 'i-1 < a.len' --> 'i < a.len+1'
       result = n
+    elif n[1].getMagic in someLen or n[2].getMagic in someLen:
+      # XXX Rethink this whole idea of 'usefulFact' for semparallel
+      result = n
   of mIsNil:
     if isLetLocation(n.sons[1], false) or isVar(n.sons[1]):
       result = n
@@ -367,8 +396,8 @@ proc usefulFact(n: PNode): PNode =
 type
   TModel* = seq[PNode] # the "knowledge base"
 
-proc addFact*(m: var TModel, n: PNode) =
-  let n = usefulFact(n)
+proc addFact*(m: var TModel, nn: PNode) =
+  let n = usefulFact(nn)
   if n != nil: m.add n
 
 proc addFactNeg*(m: var TModel, n: PNode) =
@@ -698,10 +727,57 @@ proc simpleSlice*(a, b: PNode): BiggestInt =
   else:
     result = -1
 
+
+template isMul(x): expr = x.getMagic in someMul
+template isDiv(x): expr = x.getMagic in someDiv
+template isAdd(x): expr = x.getMagic in someAdd
+template isSub(x): expr = x.getMagic in someSub
+template isVal(x): expr = x.kind in {nkCharLit..nkUInt64Lit}
+template isIntVal(x, y): expr = x.intVal == y
+
+import macros
+
+macro `=~`(x: PNode, pat: untyped): bool =
+  proc m(x, pat, conds: NimNode) =
+    case pat.kind
+    of nnkInfix:
+      case $pat[0]
+      of "*": conds.add getAst(isMul(x))
+      of "/": conds.add getAst(isDiv(x))
+      of "+": conds.add getAst(isAdd(x))
+      of "-": conds.add getAst(isSub(x))
+      else:
+        error("invalid pattern")
+      m(newTree(nnkBracketExpr, x, newLit(1)), pat[1], conds)
+      m(newTree(nnkBracketExpr, x, newLit(2)), pat[2], conds)
+    of nnkPar:
+      if pat.len == 1:
+        m(x, pat[0], conds)
+      else:
+        error("invalid pattern")
+    of nnkIdent:
+      let c = newTree(nnkStmtListExpr, newLetStmt(pat, x))
+      conds.add c
+      if ($pat)[^1] == 'c': c.add(getAst(isVal(pat)))
+      else: c.add bindSym"true"
+    of nnkIntLit:
+      conds.add(getAst(isIntVal(x, pat.intVal)))
+    else:
+      error("invalid pattern")
+
+  var conds = newTree(nnkBracket)
+  m(x, pat, conds)
+  result = nestList(!"and", conds)
+
+
+proc isMinusOne(n: PNode): bool =
+  n.kind in {nkCharLit..nkUInt64Lit} and n.intVal == -1
+
 proc pleViaModel(model: TModel; aa, bb: PNode): TImplication
 
 proc ple(m: TModel; a, b: PNode): TImplication =
   template `<=?`(a,b): expr = ple(m,a,b) == impYes
+  template `>=?`(a,b): expr = ple(m, nkIntLit.newIntNode(b), a) == impYes
 
   # 0 <= 3
   if a.isValue and b.isValue:
@@ -732,14 +808,44 @@ proc ple(m: TModel; a, b: PNode): TImplication =
   if b.getMagic in someMul:
     if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes
 
-  # x+c <= x+d if c <= d. Same for *, div etc.
+
+  if a.getMagic in someMul and a[2].isValue and a[1].getMagic in someDiv and
+      a[1][2].isValue:
+    # simplify (x div 4) * 2 <= y to x div (c div d) <= y
+    if ple(m, buildCall(opDiv, a[1][1], `|div|`(a[1][2], a[2])), b) == impYes:
+      return impYes
+
+  # x*3 + x == x*4. It follows that:
+  # x*3 + y <= x*4 if y <= x and 3 <= 4
+  if a =~ x*dc + y and b =~ x2*ec:
+    if sameTree(x, x2):
+      let ec1 = opAdd.buildCall(ec, minusOne())
+      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? x:
+        return impYes
+  elif a =~ x*dc and b =~ x2*ec + y:
+    # NOTE(review): sound only if -x <= y; 'y <= 0' below looks too weak
+    if sameTree(x, x2):
+      let ec1 = opAdd.buildCall(ec, minusOne())
+      if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? zero():
+        return impYes
+
+  # x+c <= x+d if c <= d. Same for *, - etc.
   if a.getMagic in someBinaryOp and a.getMagic == b.getMagic:
     if sameTree(a[1], b[1]) and a[2] <=? b[2]: return impYes
+    elif sameTree(a[2], b[2]) and a[1] <=? b[1]: return impYes
 
   # x div c <= y if 1 <= c and 0 <= y and x <= y:
   if a.getMagic in someDiv:
     if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes
 
+    # x div c <= x div d if d <= c
+    if b.getMagic in someDiv:
+      if sameTree(a[1], b[1]) and b[2] <=? a[2]: return impYes
+
+    # x div z <= x - 1 if z <= x (NOTE(review): not for z == 1; confirm)
+    if a[2].isValue and b.getMagic in someAdd and b[2].isMinusOne:
+      if a[2] <=? a[1] and sameTree(a[1], b[1]): return impYes
+
   # slightly subtle:
   # x <= max(y, z) iff x <= y or x <= z
   # note that 'x <= max(x, z)' is a special case of the above rule
diff --git a/compiler/semparallel.nim b/compiler/semparallel.nim
index c1609a1698..b04ba4657b 100644
--- a/compiler/semparallel.nim
+++ b/compiler/semparallel.nim
@@ -128,10 +128,10 @@ template `?`(x): expr = x.renderTree
 proc checkLe(c: AnalysisCtx; a, b: PNode) =
   case proveLe(c.guards, a, b)
   of impUnknown:
-    localError(a.info, "cannot prove: " & ?a & " <= " & ?b)
+    localError(a.info, "cannot prove: " & ?a & " <= " & ?b & " (bounds check)")
   of impYes: discard
   of impNo:
-    localError(a.info, "can prove: " & ?a & " > " & ?b)
+    localError(a.info, "can prove: " & ?a & " > " & ?b & " (bounds check)")
 
 proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
   checkLe(c, arr.lowBound, idx)
diff --git a/tests/parallel/tparfind.nim b/tests/parallel/tparfind.nim
new file mode 100644
index 0000000000..9de5012f58
--- /dev/null
+++ b/tests/parallel/tparfind.nim
@@ -0,0 +1,28 @@
+discard """
+  output: "500"
+"""
+
+import threadpool, sequtils
+
+{.experimental.}
+
+proc linearFind(a: openArray[int]; x, offset: int): int =
+  for i, y in a:
+    if y == x: return i+offset
+  result = -1
+
+proc parFind(a: seq[int]; x: int): int =
+  var results: array[4, int]
+  parallel:
+    if a.len >= 4:
+      let chunk = a.len div 4
+      results[0] = spawn linearFind(a[0 ..< chunk], x, 0)
+      results[1] = spawn linearFind(a[chunk ..< chunk*2], x, chunk)
+      results[2] = spawn linearFind(a[chunk*2 ..< chunk*3], x, chunk*2)
+      results[3] = spawn linearFind(a[chunk*3 ..< a.len], x, chunk*3)
+  result = max(results)
+
+
+let data = toSeq(0..1000)
+echo parFind(data, 500)
+