From df4f707743879a0ea4363fcef446d89d8b421513 Mon Sep 17 00:00:00 2001 From: Imran Hendley Date: Mon, 12 Feb 2018 15:20:17 -0500 Subject: [PATCH] add more functionality from sets to intsets (#7185) * add more functionality from sets to intsets * remove -+- * < and == performance * don't hardcode s.a.len * remove shortcuts from < and == --- lib/pure/collections/intsets.nim | 177 +++++++++++++++++++++++++++---- 1 file changed, 156 insertions(+), 21 deletions(-) diff --git a/lib/pure/collections/intsets.nim b/lib/pure/collections/intsets.nim index 0852325645..bfecfe4477 100644 --- a/lib/pure/collections/intsets.nim +++ b/lib/pure/collections/intsets.nim @@ -108,6 +108,28 @@ proc contains*(s: IntSet, key: int): bool = else: result = false +iterator items*(s: IntSet): int {.inline.} = + ## iterates over any included element of `s`. + if s.elems <= s.a.len: + for i in 0..`_. + result = union(s1, s2) + +proc `*`*(s1, s2: IntSet): IntSet {.inline.} = + ## Alias for `intersection(s1, s2) <#intersection>`_. + result = intersection(s1, s2) + +proc `-`*(s1, s2: IntSet): IntSet {.inline.} = + ## Alias for `difference(s1, s2) <#difference>`_. + result = difference(s1, s2) + +proc disjoint*(s1, s2: IntSet): bool = + ## Returns true iff the sets `s1` and `s2` have no items in common. + for item in s1: + if contains(s2, item): + return false + return true + +proc len*(s: IntSet): int {.inline.} = + ## Returns the number of keys in `s`. + if s.elems < s.a.len: + result = s.elems else: - var r = s.head - while r != nil: - var i = 0 - while i <= high(r.bits): - var w = r.bits[i] - # taking a copy of r.bits[i] here is correct, because - # modifying operations are not allowed during traversation - var j = 0 - while w != 0: # test all remaining bits for zero - if (w and 1) != 0: # the bit is set! - yield (r.key shl TrunkShift) or (i shl IntShift +% j) - inc(j) - w = w shr 1 - inc(i) - r = r.next + result = 0 + for _ in s: + inc(result) + +proc card*(s: IntSet): int {.inline.} = + ## alias for `len() <#len>` _. + result = s.len() + +proc `<=`*(s1, s2: IntSet): bool = + ## Returns true iff `s1` is subset of `s2`. + for item in s1: + if not s2.contains(item): + return false + return true + +proc `<`*(s1, s2: IntSet): bool = + ## Returns true iff `s1` is proper subset of `s2`. + return s1 <= s2 and not (s2 <= s1) + +proc `==`*(s1, s2: IntSet): bool = + ## Returns true if both `s` and `t` have the same members and set size. + return s1 <= s2 and s2 <= s1 template dollarImpl(): untyped = result = "{" @@ -301,9 +381,64 @@ when isMainModule: ys.sort(cmp[int]) assert ys == @[1, 2, 7, 1056] + assert x == y + var z: IntSet for i in 0..1000: incl z, i + assert z.len() == i+1 for i in 0..1000: - assert i in z + assert z.contains(i) + var w = initIntSet() + w.incl(1) + w.incl(4) + w.incl(50) + w.incl(1001) + w.incl(1056) + + var xuw = x.union(w) + var xuws = toSeq(items(xuw)) + xuws.sort(cmp[int]) + assert xuws == @[1, 2, 4, 7, 50, 1001, 1056] + + var xiw = x.intersection(w) + var xiws = toSeq(items(xiw)) + xiws.sort(cmp[int]) + assert xiws == @[1, 1056] + + var xdw = x.difference(w) + var xdws = toSeq(items(xdw)) + xdws.sort(cmp[int]) + assert xdws == @[2, 7] + + var xsw = x.symmetricDifference(w) + var xsws = toSeq(items(xsw)) + xsws.sort(cmp[int]) + assert xsws == @[2, 4, 7, 50, 1001] + + x.incl(w) + xs = toSeq(items(x)) + xs.sort(cmp[int]) + assert xs == @[1, 2, 4, 7, 50, 1001, 1056] + + assert w <= x + + assert w < x + + assert(not disjoint(w, x)) + + var u = initIntSet() + u.incl(3) + u.incl(5) + u.incl(500) + assert disjoint(u, x) + + var v = initIntSet() + v.incl(2) + v.incl(50) + + x.excl(v) + xs = toSeq(items(x)) + xs.sort(cmp[int]) + assert xs == @[1, 4, 7, 1001, 1056]