# Source code for whoosh.idsets

"""
An implementation of an object that acts like a collection of on/off bits.
"""

import operator
from array import array
from bisect import bisect_left, bisect_right
from itertools import zip_longest

from whoosh.util.numeric import bytes_for_bits

# Number of '1' bits in each byte (0-255)
_1SPERBYTE = array(
    "B",
    [
        0,
        1,
        1,
        2,
        1,
        2,
        2,
        3,
        1,
        2,
        2,
        3,
        2,
        3,
        3,
        4,
        1,
        2,
        2,
        3,
        2,
        3,
        3,
        4,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        1,
        2,
        2,
        3,
        2,
        3,
        3,
        4,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        1,
        2,
        2,
        3,
        2,
        3,
        3,
        4,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        4,
        5,
        5,
        6,
        5,
        6,
        6,
        7,
        1,
        2,
        2,
        3,
        2,
        3,
        3,
        4,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        4,
        5,
        5,
        6,
        5,
        6,
        6,
        7,
        2,
        3,
        3,
        4,
        3,
        4,
        4,
        5,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        4,
        5,
        5,
        6,
        5,
        6,
        6,
        7,
        3,
        4,
        4,
        5,
        4,
        5,
        5,
        6,
        4,
        5,
        5,
        6,
        5,
        6,
        6,
        7,
        4,
        5,
        5,
        6,
        5,
        6,
        6,
        7,
        5,
        6,
        6,
        7,
        6,
        7,
        7,
        8,
    ],
)


class DocIdSet:
    """Base class for a set of positive integers, implementing a subset of the
    built-in ``set`` type's interface with extra docid-related methods.

    This is a superclass for alternative set implementations to the built-in
    ``set`` which are more memory-efficient and specialized toward storing
    sorted lists of positive integers, though they will inevitably be slower
    than ``set`` for most operations since they're pure Python.

    Subclasses implement ``__len__``, ``__iter__``, ``__contains__``,
    ``copy``, ``add`` and ``discard``; the other set operations here are
    built on top of those.
    """

    def __eq__(self, other):
        # Element-by-element comparison in iteration order. zip_longest (not
        # zip) ensures a set is never "equal" to a longer set that merely
        # starts with the same elements.
        missing = object()
        for a, b in zip_longest(self, other, fillvalue=missing):
            if a != b:
                return False
        return True

    def __ne__(self, other):
        return not self.__eq__(other)

    # Backwards-compatible alias for the historical misspelling of __ne__.
    __neq__ = __ne__

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __contains__(self, i):
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def __and__(self, other):
        return self.intersection(other)

    def __sub__(self, other):
        return self.difference(other)

    def copy(self):
        raise NotImplementedError

    def add(self, n):
        raise NotImplementedError

    def discard(self, n):
        raise NotImplementedError

    def update(self, other):
        """Adds every number in the iterable ``other`` to this set."""
        add = self.add
        for i in other:
            add(i)

    def intersection_update(self, other):
        """Removes every number from this set that is not in ``other``."""
        # Snapshot the members first: discarding while iterating a
        # list/array-backed set would otherwise skip elements.
        for n in list(self):
            if n not in other:
                self.discard(n)

    def difference_update(self, other):
        """Removes every number in ``other`` from this set."""
        # Snapshot ``other`` in case it is this very set.
        for n in list(other):
            self.discard(n)

    def invert_update(self, size):
        """Updates the set in-place to contain numbers in the range
        ``[0 - size)`` except numbers that are in this set.
        """
        for i in range(size):
            if i in self:
                self.discard(i)
            else:
                self.add(i)

    def intersection(self, other):
        c = self.copy()
        c.intersection_update(other)
        return c

    def union(self, other):
        c = self.copy()
        c.update(other)
        return c

    def difference(self, other):
        c = self.copy()
        c.difference_update(other)
        return c

    def invert(self, size):
        c = self.copy()
        c.invert_update(size)
        return c

    def isdisjoint(self, other):
        # Iterate the smaller set and probe the larger one.
        a = self
        b = other
        if len(other) < len(self):
            a, b = other, self
        for num in a:
            if num in b:
                return False
        return True

    def before(self, i):
        """Returns the previous integer in the set before ``i``, or None."""
        raise NotImplementedError

    def after(self, i):
        """Returns the next integer in the set after ``i``, or None."""
        raise NotImplementedError

    def first(self):
        """Returns the first (lowest) integer in the set."""
        raise NotImplementedError

    def last(self):
        """Returns the last (highest) integer in the set."""
        raise NotImplementedError
class BaseBitSet(DocIdSet):
    """Base class for DocIdSets stored as an array of bits, eight per byte:
    integer ``i`` is bit ``i & 7`` of byte ``i // 8``.

    Subclasses provide the storage by overriding :meth:`byte_count`,
    :meth:`_get_byte` and :meth:`_iter_bytes`.
    """

    # Methods to override

    def byte_count(self):
        """Returns the number of bytes in the underlying bit array."""
        raise NotImplementedError

    def _get_byte(self, i):
        """Returns the value of the byte at index ``i``."""
        raise NotImplementedError

    def _iter_bytes(self):
        """Yields the value of every byte in the bit array, in order."""
        raise NotImplementedError

    # Base implementations

    def __len__(self):
        # Population count using the per-byte lookup table.
        return sum(_1SPERBYTE[b] for b in self._iter_bytes())

    def __iter__(self):
        base = 0
        for byte in self._iter_bytes():
            for i in range(8):
                if byte & (1 << i):
                    yield base + i
            base += 8

    def __nonzero__(self):
        return any(n for n in self._iter_bytes())

    __bool__ = __nonzero__

    def __contains__(self, i):
        # Negative numbers are never members; without this guard a negative
        # bucket would read a byte counted from the end of the array.
        if i < 0:
            return False
        bucket = i // 8
        if bucket >= self.byte_count():
            return False
        return bool(self._get_byte(bucket) & (1 << (i & 7)))

    def first(self):
        return self.after(-1)

    def last(self):
        return self.before(self.byte_count() * 8 + 1)

    def before(self, i):
        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i <= 0:
            return None
        elif i >= size:
            i = size - 1
        else:
            i -= 1
        bucket = i // 8

        while i >= 0:
            byte = _get_byte(bucket)
            if not byte:
                # Skip a whole empty byte at once.
                bucket -= 1
                i = bucket * 8 + 7
                continue
            if byte & (1 << (i & 7)):
                return i
            if i % 8 == 0:
                bucket -= 1
            i -= 1

        return None

    def after(self, i):
        _get_byte = self._get_byte
        size = self.byte_count() * 8

        if i >= size:
            return None
        elif i < 0:
            i = 0
        else:
            i += 1
        bucket = i // 8

        while i < size:
            byte = _get_byte(bucket)
            if not byte:
                # Skip a whole empty byte at once.
                bucket += 1
                i = bucket * 8
                continue
            if byte & (1 << (i & 7)):
                return i
            i += 1
            if i % 8 == 0:
                bucket += 1

        return None
class OnDiskBitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits on disk.

    >>> st = RamStorage()
    >>> f = st.create_file("test.bin")
    >>> bs = BitSet([1, 10, 15, 7, 2])
    >>> bytecount = bs.to_disk(f)
    >>> f.close()
    >>> # ...
    >>> f = st.open_file("test.bin")
    >>> odbs = OnDiskBitSet(f, 0, bytecount)
    >>> list(odbs)
    [1, 2, 7, 10, 15]
    """

    def __init__(self, dbfile, basepos, bytecount):
        """
        :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
            to read from.
        :param basepos: the base position of the bytes in the given file.
        :param bytecount: the number of bytes to use for the bit array.
        """
        self._dbfile = dbfile
        self._basepos = basepos
        self._bytecount = bytecount

    def __repr__(self):
        # Refer to the private attributes that __init__ actually sets.
        return "%s(%s, %d, %d)" % (
            self.__class__.__name__,
            self._dbfile,
            self._basepos,
            self._bytecount,
        )

    def byte_count(self):
        return self._bytecount

    def _get_byte(self, n):
        return self._dbfile.get_byte(self._basepos + n)

    def _iter_bytes(self):
        # Note: seeks the shared file object, changing its current position.
        dbfile = self._dbfile
        dbfile.seek(self._basepos)
        for _ in range(self._bytecount):
            yield dbfile.read_byte()
class BitSet(BaseBitSet):
    """A DocIdSet backed by an array of bits. This can also be useful as a bit
    array (e.g. for a Bloom filter). It is much more memory efficient than a
    large built-in set of integers, but wastes memory for sparse sets.
    """

    def __init__(self, source=None, size=0):
        """
        :param source: an iterable of non-negative integers to add to this
            set.
        :param size: the initial capacity of the bit array, in bits. The
            array grows automatically when larger numbers are added.
        """
        # If the source is a non-empty list, tuple, or set, we can size the
        # array up front (max() would raise on an empty container).
        if (
            not size
            and source
            and isinstance(source, (list, tuple, set, frozenset))
        ):
            size = max(source) + 1
        bytecount = bytes_for_bits(size)
        self.bits = array("B", (0 for _ in range(bytecount)))

        if source:
            add = self.add
            for num in source:
                add(num)

    def __repr__(self):
        return f"{self.__class__.__name__}({list(self)!r})"

    def byte_count(self):
        return len(self.bits)

    def _get_byte(self, n):
        return self.bits[n]

    def _iter_bytes(self):
        return iter(self.bits)

    def _trim(self):
        """Removes trailing zero bytes from the bit array."""
        bits = self.bits
        last = len(bits) - 1
        while last >= 0 and not bits[last]:
            last -= 1
        del bits[last + 1 :]

    def _resize(self, tosize):
        """Grows (with zero bytes) or shrinks the byte array so it holds
        exactly enough bytes for ``tosize`` bits.
        """
        curlength = len(self.bits)
        newlength = bytes_for_bits(tosize)
        if newlength > curlength:
            self.bits.extend((0,) * (newlength - curlength))
        elif newlength < curlength:
            del self.bits[newlength:]

    def _zero_extra_bits(self, size):
        """Clears every bit that represents a number ``>= size``."""
        bits = self.bits
        full, spill = divmod(max(size, 0), 8)
        if spill and full < len(bits):
            # Keep only the low ``spill`` bits of the partially-used byte.
            bits[full] &= (1 << spill) - 1
            full += 1
        for i in range(full, len(bits)):
            bits[i] = 0

    def _logic(self, obj, op, other):
        """Combines ``obj``'s bytes in place with ``other``'s bytes using the
        binary function ``op`` (the shorter array is padded with zero bytes),
        trims trailing zero bytes, and returns ``obj``.
        """
        objbits = obj.bits
        for i, (byte1, byte2) in enumerate(
            zip_longest(objbits, other.bits, fillvalue=0)
        ):
            value = op(byte1, byte2) & 0xFF
            if i >= len(objbits):
                objbits.append(value)
            else:
                objbits[i] = value

        obj._trim()
        return obj

    def to_disk(self, dbfile):
        """Writes the bit array to ``dbfile`` and returns its length in
        bytes.
        """
        dbfile.write_array(self.bits)
        return len(self.bits)

    @classmethod
    def from_bytes(cls, bs):
        b = cls()
        b.bits = array("B", bs)
        return b

    @classmethod
    def from_disk(cls, dbfile, bytecount):
        return cls.from_bytes(dbfile.read_array("B", bytecount))

    def copy(self):
        b = self.__class__()
        b.bits = array("B", iter(self.bits))
        return b

    def clear(self):
        for i in range(len(self.bits)):
            self.bits[i] = 0

    def add(self, i):
        if i < 0:
            # A negative bucket would silently set a bit counted from the end.
            raise ValueError("BitSet can only store non-negative integers")
        bucket = i >> 3
        if bucket >= len(self.bits):
            self._resize(i + 1)
        self.bits[bucket] |= 1 << (i & 7)

    def discard(self, i):
        bucket = i >> 3
        # Like set.discard, removing a number that isn't stored is a no-op.
        if 0 <= bucket < len(self.bits):
            self.bits[bucket] &= ~(1 << (i & 7))

    def _resize_to_other(self, other):
        """Pre-grows the array when ``other`` is a sized container, so that
        adding its members doesn't resize repeatedly.
        """
        if isinstance(other, (list, tuple, set, frozenset)) and other:
            maxbit = max(other)
            if maxbit >= len(self.bits) * 8:
                self._resize(maxbit + 1)

    def update(self, iterable):
        self._resize_to_other(iterable)
        DocIdSet.update(self, iterable)

    def intersection_update(self, other):
        if isinstance(other, BitSet):
            return self._logic(self, operator.__and__, other)
        # Safe to discard while iterating: __iter__ reads each byte before
        # yielding its bits, and only already-yielded bits are discarded.
        discard = self.discard
        for n in self:
            if n not in other:
                discard(n)

    def difference_update(self, other):
        if isinstance(other, BitSet):
            return self._logic(self, lambda x, y: x & ~y, other)
        discard = self.discard
        for n in other:
            discard(n)

    def invert_update(self, size):
        """Updates the set in-place to contain the numbers in ``[0, size)``
        that were not in this set.
        """
        # Make the array exactly big enough for ``size`` bits, so numbers past
        # the old end of the array get set and numbers >= size are dropped.
        self._resize(size)
        bits = self.bits
        for i, byte in enumerate(bits):
            bits[i] = ~byte & 0xFF
        self._zero_extra_bits(size)

    def union(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__or__, other)
        b = self.copy()
        b.update(other)
        return b

    def intersection(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), operator.__and__, other)
        return BitSet(source=(n for n in self if n in other))

    def difference(self, other):
        if isinstance(other, BitSet):
            return self._logic(self.copy(), lambda x, y: x & ~y, other)
        return BitSet(source=(n for n in self if n not in other))
class SortedIntSet(DocIdSet):
    """A DocIdSet backed by a sorted array of integers."""

    def __init__(self, source=None, typecode="I"):
        """
        :param source: an iterable of integers to add to this set. Duplicates
            are stored only once.
        :param typecode: the :mod:`array` typecode used to store the numbers.
        """
        if source:
            self.data = array(typecode, sorted(set(source)))
        else:
            self.data = array(typecode)
        self.typecode = typecode

    def copy(self):
        # Carry the typecode over so the copy behaves like the original.
        sis = SortedIntSet(typecode=self.typecode)
        sis.data = array(self.typecode, self.data)
        return sis

    def size(self):
        """Returns the size of the underlying array in bytes."""
        return len(self.data) * self.data.itemsize

    def __repr__(self):
        return f"{self.__class__.__name__}({self.data!r})"

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data)

    def __nonzero__(self):
        return bool(self.data)

    __bool__ = __nonzero__

    def __contains__(self, i):
        data = self.data
        if not data or i < data[0] or i > data[-1]:
            return False

        pos = bisect_left(data, i)
        if pos == len(data):
            return False
        return data[pos] == i

    def add(self, i):
        data = self.data
        if not data or i > data[-1]:
            # Fast path: adding numbers in ascending order just appends.
            data.append(i)
        else:
            # i <= data[-1], so pos is always a valid index.
            pos = bisect_left(data, i)
            if data[pos] != i:
                data.insert(pos, i)

    def discard(self, i):
        data = self.data
        pos = bisect_left(data, i)
        # pos == len(data) when i is greater than every member.
        if pos < len(data) and data[pos] == i:
            data.pop(pos)

    def clear(self):
        self.data = array(self.typecode)

    def intersection_update(self, other):
        self.data = array(self.typecode, (num for num in self if num in other))

    def difference_update(self, other):
        self.data = array(self.typecode, (num for num in self if num not in other))

    def intersection(self, other):
        return SortedIntSet(
            (num for num in self if num in other), typecode=self.typecode
        )

    def difference(self, other):
        return SortedIntSet(
            (num for num in self if num not in other), typecode=self.typecode
        )

    def first(self):
        """Returns the lowest integer in the set, or None if it is empty."""
        return self.data[0] if self.data else None

    def last(self):
        """Returns the highest integer in the set, or None if it is empty."""
        return self.data[-1] if self.data else None

    def before(self, i):
        data = self.data
        pos = bisect_left(data, i)
        if pos < 1:
            return None
        else:
            return data[pos - 1]

    def after(self, i):
        data = self.data
        if not data or i >= data[-1]:
            return None
        elif i < data[0]:
            return data[0]

        pos = bisect_right(data, i)
        return data[pos]
class ReverseIdSet(DocIdSet):
    """
    Wraps a DocIdSet object and reverses its semantics, so docs in the wrapped
    set are not in this set, and vice-versa.
    """

    def __init__(self, idset, limit):
        """
        :param idset: the DocIdSet object to wrap.
        :param limit: the highest possible ID plus one.
        """
        self.idset = idset
        self.limit = limit

    def __len__(self):
        return self.limit - len(self.idset)

    def __contains__(self, i):
        # Only IDs in [0, limit) can be members, matching __iter__.
        return 0 <= i < self.limit and i not in self.idset

    def __iter__(self):
        # Walk the wrapped set's iterator in lockstep with the ID range
        # (presumably it yields IDs in ascending order), skipping its members.
        ids = iter(self.idset)
        try:
            nx = next(ids)
        except StopIteration:
            nx = -1

        for i in range(self.limit):
            if i == nx:
                try:
                    nx = next(ids)
                except StopIteration:
                    nx = -1
            else:
                yield i

    def copy(self):
        return ReverseIdSet(self.idset.copy(), self.limit)

    def add(self, n):
        self.idset.discard(n)

    def discard(self, n):
        self.idset.add(n)

    def first(self):
        for i in self:
            return i

    def last(self):
        """Returns the highest ID in ``[0, limit)`` not in the wrapped set, or
        None if there is none.
        """
        # Probe downward from the top; in the common case the highest ID is
        # absent from the wrapped set and this returns on the first check.
        idset = self.idset
        for i in range(self.limit - 1, -1, -1):
            if i not in idset:
                return i
        return None


ROARING_CUTOFF = 1 << 12


class RoaringIdSet(DocIdSet):
    """
    Separates IDs into ranges of 2^16 bits, and stores each range in the most
    efficient type of doc set, either a BitSet (if the range has >= 2^12 IDs)
    or a sorted ID set of 16-bit shorts.
    """

    cutoff = 2**12

    def __init__(self, source=None):
        self.idsets = []
        if source:
            self.update(source)

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __contains__(self, n):
        if n < 0:
            return False
        bucket = n >> 16
        if bucket >= len(self.idsets):
            return False
        return (n - (bucket << 16)) in self.idsets[bucket]

    def __iter__(self):
        for i, idset in enumerate(self.idsets):
            floor = i << 16
            for n in idset:
                yield floor + n

    def copy(self):
        r = RoaringIdSet()
        r.idsets = [idset.copy() for idset in self.idsets]
        return r

    def _find(self, n):
        """Returns ``(bucket, floor, idset)`` for ``n``, creating empty
        buckets up to and including ``n``'s bucket as needed.
        """
        bucket = n >> 16
        # The first ID stored in this bucket.
        floor = bucket << 16
        if bucket >= len(self.idsets):
            self.idsets.extend(
                [
                    SortedIntSet(typecode="H")
                    for _ in range(len(self.idsets), bucket + 1)
                ]
            )
        idset = self.idsets[bucket]
        return bucket, floor, idset

    def add(self, n):
        if n < 0:
            raise ValueError("RoaringIdSet can only store non-negative integers")
        bucket, floor, idset = self._find(n)
        oldlen = len(idset)
        idset.add(n - floor)
        # Switch to a bit set once the bucket becomes dense.
        if oldlen <= ROARING_CUTOFF < len(idset):
            self.idsets[bucket] = BitSet(idset)

    def discard(self, n):
        bucket = n >> 16
        # Discarding an absent ID is a no-op; don't create buckets for it.
        if n < 0 or bucket >= len(self.idsets):
            return
        floor = bucket << 16
        idset = self.idsets[bucket]
        oldlen = len(idset)
        idset.discard(n - floor)
        # Switch back to a sorted array once the bucket becomes sparse.
        if oldlen > ROARING_CUTOFF >= len(idset):
            self.idsets[bucket] = SortedIntSet(idset, typecode="H")
class MultiIdSet(DocIdSet):
    """Wraps multiple SERIAL sub-DocIdSet objects and presents them as an
    aggregated, read-only set.
    """

    def __init__(self, idsets, offsets):
        """
        :param idsets: a list of DocIdSet objects.
        :param offsets: a list of offsets corresponding to the DocIdSet
            objects in ``idsets``.
        """
        assert len(idsets) == len(offsets)
        self.idsets = idsets
        self.offsets = offsets

    def _document_set(self, n):
        # The sub-set holding global number ``n`` is the last one whose
        # offset is <= n (offsets are assumed ascending, since the sub-sets
        # are serial). Numbers below the first offset map to sub-set 0.
        return max(bisect_right(self.offsets, n) - 1, 0)

    def _set_and_docnum(self, n):
        setnum = self._document_set(n)
        offset = self.offsets[setnum]
        return self.idsets[setnum], n - offset

    def __len__(self):
        return sum(len(idset) for idset in self.idsets)

    def __iter__(self):
        for idset, offset in zip(self.idsets, self.offsets):
            for docnum in idset:
                yield docnum + offset

    def __contains__(self, item):
        if not self.idsets:
            return False
        idset, n = self._set_and_docnum(item)
        # A negative local number means ``item`` precedes the first sub-set.
        return n >= 0 and n in idset