symmetric difference operation for sets via xor (#24286)

closes https://github.com/nim-lang/RFCs/issues/554

Adds a symmetric difference operation to the language bitset type. This
maps to a simple `xor` operation on the backend and thus is likely
faster than the current alternatives, namely `(a - b) + (b - a)` or `a +
b - a * b`. The compiler VM implementation of bitsets already
implemented this via `symdiffSets` but it was never used.

The standalone binary operation is added to `setutils`, named
`symmetricDifference` in line with [hash
sets](https://nim-lang.org/docs/sets.html#symmetricDifference%2CHashSet%5BA%5D%2CHashSet%5BA%5D).
An operator version `-+-` and an in-place version like `toggle` as
described in the RFC are also added, implemented as trivial sugar.
This commit is contained in:
metagn
2024-10-19 11:07:00 +03:00
committed by GitHub
parent 0a058a6b8f
commit ae9287c4f3
12 changed files with 85 additions and 7 deletions

View File

@@ -14,6 +14,11 @@ rounding guarantees (via the
## Standard library additions and changes
[//]: # "Additions:"
- `setutils.symmetricDifference` along with its operator version
`` setutils.`-+-` `` and in-place version `setutils.toggle` have been added
to more efficiently calculate the symmetric difference of bitsets.
[//]: # "Changes:"
- `std/math` The `^` symbol now supports floating-point as exponent in addition to the Natural type.

View File

@@ -491,7 +491,7 @@ type
mAnd, mOr,
mImplies, mIff, mExists, mForall, mOld,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mSlice,
mDotDot, # this one is only necessary to give nice compile time warnings
mFields, mFieldPairs, mOmpParFor,
@@ -559,7 +559,7 @@ const
mStrToStr, mEnumToStr,
mAnd, mOr,
mEqStr, mLeStr, mLtStr,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet,
mEqSet, mLeSet, mLtSet, mMulSet, mPlusSet, mMinusSet, mXorSet,
mConStrStr, mAppendStrCh, mAppendStrStr, mAppendSeqElem,
mInSet, mRepr, mOpenArrayToSeq}

View File

@@ -2044,7 +2044,7 @@ proc genInOp(p: BProc, e: PNode, d: var TLoc) =
proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
const
lookupOpr: array[mLeSet..mMinusSet, string] = [
lookupOpr: array[mLeSet..mXorSet, string] = [
"for ($1 = 0; $1 < $2; $1++) { $n" &
" $3 = (($4[$1] & ~ $5[$1]) == 0);$n" &
" if (!$3) break;}$n",
@@ -2054,7 +2054,8 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
"if ($3) $3 = (#nimCmpMem($4, $5, $2) != 0);$n",
"&",
"|",
"& ~"]
"& ~",
"^"]
var a, b: TLoc
var i: TLoc
var setType = skipTypes(e[1].typ, abstractVar)
@@ -2085,6 +2086,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mMulSet: binaryExpr(p, e, d, "($1 & $2)")
of mPlusSet: binaryExpr(p, e, d, "($1 | $2)")
of mMinusSet: binaryExpr(p, e, d, "($1 & ~ $2)")
of mXorSet: binaryExpr(p, e, d, "($1 ^ $2)")
of mInSet:
genInOp(p, e, d)
else: internalError(p.config, e.info, "genSetOp()")
@@ -2112,7 +2114,7 @@ proc genSetOp(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
var a = initLocExpr(p, e[1])
var b = initLocExpr(p, e[2])
putIntoDest(p, d, e, ropecg(p.module, "(#nimCmpMem($1, $2, $3)==0)", [a.rdCharLoc, b.rdCharLoc, size]))
of mMulSet, mPlusSet, mMinusSet:
of mMulSet, mPlusSet, mMinusSet, mXorSet:
# we inline the simple for loop for better code generation:
i = getTemp(p, getSysType(p.module.g.graph, unknownLineInfo, tyInt)) # our counter
a = initLocExpr(p, e[1])
@@ -2548,7 +2550,7 @@ proc genMagicExpr(p: BProc, e: PNode, d: var TLoc, op: TMagic) =
of mSetLengthStr: genSetLengthStr(p, e, d)
of mSetLengthSeq: genSetLengthSeq(p, e, d)
of mIncl, mExcl, mCard, mLtSet, mLeSet, mEqSet, mMulSet, mPlusSet, mMinusSet,
mInSet:
mInSet, mXorSet:
genSetOp(p, e, d, op)
of mNewString, mNewStringOfCap, mExit, mParseBiggestFloat:
var opr = e[0].sym

View File

@@ -170,3 +170,4 @@ proc initDefines*(symbols: StringTableRef) =
defineSymbol("nimHasGenericsOpenSym3")
defineSymbol("nimHasJsNoLambdaLifting")
defineSymbol("nimHasDefaultFloatRoundtrip")
defineSymbol("nimHasXorSet")

View File

@@ -2458,6 +2458,7 @@ proc genMagic(p: PProc, n: PNode, r: var TCompRes) =
of mMulSet: binaryExpr(p, n, r, "SetMul", "SetMul($1, $2)")
of mPlusSet: binaryExpr(p, n, r, "SetPlus", "SetPlus($1, $2)")
of mMinusSet: binaryExpr(p, n, r, "SetMinus", "SetMinus($1, $2)")
of mXorSet: binaryExpr(p, n, r, "SetXor", "SetXor($1, $2)")
of mIncl: binaryExpr(p, n, r, "", "$1[$2] = true")
of mExcl: binaryExpr(p, n, r, "", "delete $1[$2]")
of mInSet:

View File

@@ -302,6 +302,9 @@ proc evalOp(m: TMagic, n, a, b, c: PNode; idgen: IdGenerator; g: ModuleGraph): P
of mMinusSet:
result = nimsets.diffSets(g.config, a, b)
result.info = n.info
of mXorSet:
result = nimsets.symdiffSets(g.config, a, b)
result.info = n.info
of mConStrStr: result = newStrNodeT(getStrOrChar(a) & getStrOrChar(b), n, g)
of mInSet: result = newIntNodeT(toInt128(ord(inSet(a, b))), n, idgen, g)
of mRepr:

View File

@@ -1276,6 +1276,11 @@ proc rawExecute(c: PCtx, start: int, tos: PStackFrame): TFullReg =
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.diffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcXorSet:
decodeBC(rkNode)
createSet(regs[ra])
move(regs[ra].node.sons,
nimsets.symdiffSets(c.config, regs[rb].node, regs[rc].node).sons)
of opcConcatStr:
decodeBC(rkNode)
createStr regs[ra]

View File

@@ -100,7 +100,7 @@ type
opcEqRef, opcEqNimNode, opcSameNodeType,
opcXor, opcNot, opcUnaryMinusInt, opcUnaryMinusFloat, opcBitnotInt,
opcEqStr, opcEqCString, opcLeStr, opcLtStr, opcEqSet, opcLeSet, opcLtSet,
opcMulSet, opcPlusSet, opcMinusSet, opcConcatStr,
opcMulSet, opcPlusSet, opcMinusSet, opcXorSet, opcConcatStr,
opcContainsSet, opcRepr, opcSetLenStr, opcSetLenSeq,
opcIsNil, opcOf, opcIs,
opcParseFloat, opcConv, opcCast,

View File

@@ -1212,6 +1212,7 @@ proc genMagic(c: PCtx; n: PNode; dest: var TDest; flags: TGenFlags = {}, m: TMag
of mMulSet: genBinarySet(c, n, dest, opcMulSet)
of mPlusSet: genBinarySet(c, n, dest, opcPlusSet)
of mMinusSet: genBinarySet(c, n, dest, opcMinusSet)
of mXorSet: genBinarySet(c, n, dest, opcXorSet)
of mConStrStr: genVarargsABC(c, n, dest, opcConcatStr)
of mInSet: genBinarySet(c, n, dest, opcContainsSet)
of mRepr: genUnaryABC(c, n, dest, opcRepr)

View File

@@ -75,3 +75,31 @@ func `[]=`*[T](t: var set[T], key: T, val: bool) {.inline.} =
s[a3] = true
assert s == {a2, a3}
if val: t.incl key else: t.excl key
when defined(nimHasXorSet):
func symmetricDifference*[T](x, y: set[T]): set[T] {.magic: "XorSet".} =
## This operator computes the symmetric difference of two sets,
## equivalent to but more efficient than `x + y - x * y` or
## `(x - y) + (y - x)`.
runnableExamples:
assert symmetricDifference({1, 2, 3}, {2, 3, 4}) == {1, 4}
else:
func symmetricDifference*[T](x, y: set[T]): set[T] {.inline.} =
result = x + y - (x * y)
proc `-+-`*[T](x, y: set[T]): set[T] {.inline.} =
## Operator alias for `symmetricDifference`.
runnableExamples:
assert {1, 2, 3} -+- {2, 3, 4} == {1, 4}
result = symmetricDifference(x, y)
proc toggle*[T](x: var set[T], y: set[T]) {.inline.} =
## Toggles the existence of each value of `y` in `x`.
## If any element in `y` is also in `x`, it is excluded from `x`;
## otherwise it is included.
## Equivalent to `x = symmetricDifference(x, y)`.
runnableExamples:
var x = {1, 2, 3}
x.toggle({2, 3, 4})
assert x == {1, 4}
x = symmetricDifference(x, y)

View File

@@ -337,6 +337,18 @@ proc SetMinus(a, b: int): int {.compilerproc, asmNoStackFrame.} =
return result;
""".}
proc SetXor(a, b: int): int {.compilerproc, asmNoStackFrame.} =
{.emit: """
var result = {};
for (var elem in `a`) {
if (!`b`[elem]) { result[elem] = true; }
}
for (var elem in `b`) {
if (!`a`[elem]) { result[elem] = true; }
}
return result;
""".}
proc cmpStrings(a, b: string): int {.asmNoStackFrame, compilerproc.} =
{.emit: """
if (`a` == `b`) return 0;

View File

@@ -44,6 +44,26 @@ template main =
s[a2] = true
s[a3] = true
doAssert s == {a2, a3}
block: # set symmetric difference (xor), https://github.com/nim-lang/RFCs/issues/554
type T = set[range[0..15]]
let x: T = {1, 4, 5, 8, 9}
let y: T = {0, 2..6, 9}
let res = symmetricDifference(x, y)
doAssert res == {0, 1, 2, 3, 6, 8}
doAssert res == (x + y - x * y)
doAssert res == ((x - y) + (y - x))
var z = x
doAssert z == {1, 4, 5, 8, 9}
doAssert z == x
z.toggle(y)
doAssert z == res
z.toggle(y)
doAssert z == x
z.toggle({1, 5})
doAssert z == {4, 8, 9}
z.toggle({3, 8})
doAssert z == {3, 4, 9}
main()
static: main()