disjoint checker is smarter (and slower)

This commit is contained in:
Araq
2015-09-23 19:47:48 +02:00
parent f937637a92
commit 000c413f35
3 changed files with 143 additions and 9 deletions

View File

@@ -37,7 +37,7 @@ const
someMod = {mModI}
someMax = {mMaxI, mMaxF64}
someMin = {mMinI, mMinF64}
someBinaryOp = someAdd+someSub+someMul+someDiv+someMod+someMax+someMin
someBinaryOp = someAdd+someSub+someMul+someMax+someMin
proc isValue(n: PNode): bool = n.kind in {nkCharLit..nkNilLit}
proc isLocation(n: PNode): bool = not n.isValue
@@ -166,11 +166,21 @@ proc `|+|`(a, b: PNode): PNode =
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |+| b.intVal
else: result.floatVal = a.floatVal + b.floatVal
proc `|-|`(a, b: PNode): PNode =
result = copyNode(a)
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |-| b.intVal
else: result.floatVal = a.floatVal - b.floatVal
proc `|*|`(a, b: PNode): PNode =
result = copyNode(a)
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal |*| b.intVal
else: result.floatVal = a.floatVal * b.floatVal
proc `|div|`(a, b: PNode): PNode =
result = copyNode(a)
if a.kind in {nkCharLit..nkUInt64Lit}: result.intVal = a.intVal div b.intVal
else: result.floatVal = a.floatVal / b.floatVal
proc negate(a, b, res: PNode): PNode =
if b.kind in {nkCharLit..nkUInt64Lit} and b.intVal != low(BiggestInt):
var b = copyNode(b)
@@ -214,10 +224,16 @@ proc reassociation(n: PNode): PNode =
if result[2].isValue and
result[1].getMagic in someAdd and result[1][2].isValue:
result = opAdd.buildCall(result[1][1], result[1][2] |+| result[2])
if result[2].intVal == 0:
result = result[1]
of someMul:
if result[2].isValue and
result[1].getMagic in someMul and result[1][2].isValue:
result = opAdd.buildCall(result[1][1], result[1][2] |*| result[2])
result = opMul.buildCall(result[1][1], result[1][2] |*| result[2])
if result[2].intVal == 1:
result = result[1]
elif result[2].intVal == 0:
result = zero()
else: discard
proc pred(n: PNode): PNode =
@@ -235,7 +251,7 @@ proc canon*(n: PNode): PNode =
result.sons[i] = canon(n.sons[i])
elif n.kind == nkSym and n.sym.kind == skLet and
n.sym.ast.getMagic in (someEq + someAdd + someMul + someMin +
someMax + someHigh + {mUnaryLt} + someSub + someLen):
someMax + someHigh + {mUnaryLt} + someSub + someLen + someDiv):
result = n.sym.ast.copyTree
else:
result = n
@@ -249,7 +265,7 @@ proc canon*(n: PNode): PNode =
# high == len+(-1)
result = opAdd.buildCall(opLen.buildCall(result[1]), minusOne())
of mUnaryLt:
result = buildCall(opAdd, result[1], newIntNode(nkIntLit, -1))
result = buildCall(opAdd, result[1], minusOne())
of someSub:
# x - 4 --> x + (-4)
result = negate(result[1], result[2], result)
@@ -295,6 +311,16 @@ proc canon*(n: PNode): PNode =
if plus != nil and not isLetLocation(x, true):
result = buildCall(result[0].sym, plus, y[1])
else: discard
elif x.isValue and y.getMagic in someAdd and y[2].isValue:
# 0 <= a.len + 3
# -3 <= a.len
result.sons[1] = x |-| y[2]
result.sons[2] = y[1]
elif x.isValue and y.getMagic in someSub and y[2].isValue:
# 0 <= a.len - 3
# 3 <= a.len
result.sons[1] = x |+| y[2]
result.sons[2] = y[1]
else: discard
proc `+@`*(a: PNode; b: BiggestInt): PNode =
@@ -314,6 +340,9 @@ proc usefulFact(n: PNode): PNode =
if isLetLocation(n.sons[1], true) or isLetLocation(n.sons[2], true):
# XXX algebraic simplifications! 'i-1 < a.len' --> 'i < a.len+1'
result = n
elif n[1].getMagic in someLen or n[2].getMagic in someLen:
# XXX Rethink this whole idea of 'usefulFact' for semparallel
result = n
of mIsNil:
if isLetLocation(n.sons[1], false) or isVar(n.sons[1]):
result = n
@@ -367,8 +396,8 @@ proc usefulFact(n: PNode): PNode =
type
TModel* = seq[PNode] # the "knowledge base"
proc addFact*(m: var TModel, n: PNode) =
let n = usefulFact(n)
proc addFact*(m: var TModel, nn: PNode) =
let n = usefulFact(nn)
if n != nil: m.add n
proc addFactNeg*(m: var TModel, n: PNode) =
@@ -698,10 +727,57 @@ proc simpleSlice*(a, b: PNode): BiggestInt =
else:
result = -1
template isMul(x): expr = x.getMagic in someMul
template isDiv(x): expr = x.getMagic in someDiv
template isAdd(x): expr = x.getMagic in someAdd
template isSub(x): expr = x.getMagic in someSub
template isVal(x): expr = x.kind in {nkCharLit..nkUInt64Lit}
template isIntVal(x, y): expr = x.intVal == y
import macros
macro `=~`(x: PNode, pat: untyped): bool =
proc m(x, pat, conds: NimNode) =
case pat.kind
of nnkInfix:
case $pat[0]
of "*": conds.add getAst(isMul(x))
of "/": conds.add getAst(isDiv(x))
of "+": conds.add getAst(isAdd(x))
of "-": conds.add getAst(isSub(x))
else:
error("invalid pattern")
m(newTree(nnkBracketExpr, x, newLit(1)), pat[1], conds)
m(newTree(nnkBracketExpr, x, newLit(2)), pat[2], conds)
of nnkPar:
if pat.len == 1:
m(x, pat[0], conds)
else:
error("invalid pattern")
of nnkIdent:
let c = newTree(nnkStmtListExpr, newLetStmt(pat, x))
conds.add c
if ($pat)[^1] == 'c': c.add(getAst(isVal(pat)))
else: c.add bindSym"true"
of nnkIntLit:
conds.add(getAst(isIntVal(pat.intVal)))
else:
error("invalid pattern")
var conds = newTree(nnkBracket)
m(x, pat, conds)
result = nestList(!"and", conds)
proc isMinusOne(n: PNode): bool =
n.kind in {nkCharLit..nkUInt64Lit} and n.intVal == -1
proc pleViaModel(model: TModel; aa, bb: PNode): TImplication
proc ple(m: TModel; a, b: PNode): TImplication =
template `<=?`(a,b): expr = ple(m,a,b) == impYes
template `>=?`(a,b): expr = ple(m, nkIntLit.newIntNode(b), a) == impYes
# 0 <= 3
if a.isValue and b.isValue:
@@ -732,14 +808,44 @@ proc ple(m: TModel; a, b: PNode): TImplication =
if b.getMagic in someMul:
if a <=? b[1] and one() <=? b[2] and zero() <=? b[1]: return impYes
# x+c <= x+d if c <= d. Same for *, div etc.
if a.getMagic in someMul and a[2].isValue and a[1].getMagic in someDiv and
a[1][2].isValue:
# simplify (x div 4) * 2 <= y to x div (c div d) <= y
if ple(m, buildCall(opDiv, a[1][1], `|div|`(a[1][2], a[2])), b) == impYes:
return impYes
# x*3 + x == x*4. It follows that:
# x*3 + y <= x*4 if y <= x and 3 <= 4
if a =~ x*dc + y and b =~ x2*ec:
if sameTree(x, x2):
let ec1 = opAdd.buildCall(ec, minusOne())
if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? x:
return impYes
elif a =~ x*dc and b =~ x2*ec + y:
#echo "BUG cam ehrer e ", a, " <=? ", b
if sameTree(x, x2):
let ec1 = opAdd.buildCall(ec, minusOne())
if x >=? 1 and ec >=? 1 and dc >=? 1 and dc <=? ec1 and y <=? zero():
return impYes
# x+c <= x+d if c <= d. Same for *, - etc.
if a.getMagic in someBinaryOp and a.getMagic == b.getMagic:
if sameTree(a[1], b[1]) and a[2] <=? b[2]: return impYes
elif sameTree(a[2], b[2]) and a[1] <=? b[1]: return impYes
# x div c <= y if 1 <= c and 0 <= y and x <= y:
if a.getMagic in someDiv:
if one() <=? a[2] and zero() <=? b and a[1] <=? b: return impYes
# x div c <= x div d if d <= c
if b.getMagic in someDiv:
if sameTree(a[1], b[1]) and b[2] <=? a[2]: return impYes
# x div z <= x - 1 if z <= x
if a[2].isValue and b.getMagic in someAdd and b[2].isMinusOne:
if a[2] <=? a[1] and sameTree(a[1], b[1]): return impYes
# slightly subtle:
# x <= max(y, z) iff x <= y or x <= z
# note that 'x <= max(x, z)' is a special case of the above rule

View File

@@ -128,10 +128,10 @@ template `?`(x): expr = x.renderTree
proc checkLe(c: AnalysisCtx; a, b: PNode) =
case proveLe(c.guards, a, b)
of impUnknown:
localError(a.info, "cannot prove: " & ?a & " <= " & ?b)
localError(a.info, "cannot prove: " & ?a & " <= " & ?b & " (bounds check)")
of impYes: discard
of impNo:
localError(a.info, "can prove: " & ?a & " > " & ?b)
localError(a.info, "can prove: " & ?a & " > " & ?b & " (bounds check)")
proc checkBounds(c: AnalysisCtx; arr, idx: PNode) =
checkLe(c, arr.lowBound, idx)

View File

@@ -0,0 +1,28 @@
discard """
output: "500"
"""
import threadpool, sequtils
{.experimental.}
proc linearFind(a: openArray[int]; x, offset: int): int =
for i, y in a:
if y == x: return i+offset
result = -1
proc parFind(a: seq[int]; x: int): int =
var results: array[4, int]
parallel:
if a.len >= 4:
let chunk = a.len div 4
results[0] = spawn linearFind(a[0 ..< chunk], x, 0)
results[1] = spawn linearFind(a[chunk ..< chunk*2], x, chunk)
results[2] = spawn linearFind(a[chunk*2 ..< chunk*3], x, chunk*2)
results[3] = spawn linearFind(a[chunk*3 ..< a.len], x, chunk*3)
result = max(results)
let data = toSeq(0..1000)
echo parFind(data, 500)