Source code for whoosh.filedb.filetables

# Copyright 2009 Matt Chaput. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
#    1. Redistributions of source code must retain the above copyright notice,
#       this list of conditions and the following disclaimer.
#
#    2. Redistributions in binary form must reproduce the above copyright
#       notice, this list of conditions and the following disclaimer in the
#       documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY MATT CHAPUT ``AS IS'' AND ANY EXPRESS OR
# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
# EVENT SHALL MATT CHAPUT OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
# OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# The views and conclusions contained in the software and documentation are
# those of the authors and should not be interpreted as representing official
# policies, either expressed or implied, of Matt Chaput.

"""This module defines writer and reader classes for a fast, immutable
on-disk key-value database format. The current format is based heavily on
D. J. Bernstein's CDB format (http://cr.yp.to/cdb.html).
"""

import os
import struct
import sys
from binascii import crc32
from hashlib import md5  # type: ignore @UnresolvedImport

from whoosh.system import _INT_SIZE, emptybytes
from whoosh.util.numlists import GrowableArray

# Exceptions


class FileFormatError(Exception):
    """Raised when a file does not carry the expected format tag.

    :class:`HashReader` raises this when the bytes at the start of the file
    do not match the ``magic`` value it was asked to look for.
    """


# Hash functions


def cdb_hash(key):
    """Return the 32-bit CDB hash of ``key``.

    The hash starts at 5381 and, for every byte, multiplies the running value
    by 33 (``h + (h << 5)``), truncates it to 32 bits and XORs in the byte.

    :param key: the key to hash. Normally a ``bytes`` object; a text string
        (or any iterable of single characters) is also accepted, in which
        case each character's ordinal is used.
    :returns: an integer in the range ``0 .. 0xFFFFFFFF``.
    """
    # Iterating over bytes yields ints on Python 3, so calling ord() on the
    # items (as needed for text) would raise TypeError. Only convert when the
    # items are characters.
    if isinstance(key, (bytes, bytearray)):
        codes = key
    else:
        codes = map(ord, key)

    h = 5381
    for code in codes:
        h = ((h + (h << 5)) & 0xFFFFFFFF) ^ code
    return h


def md5_hash(key):
    """Return the low 32 bits of the MD5 digest of ``key`` as an integer."""
    # The ``usedforsecurity`` flag is only passed on Python 3.9 and later
    if sys.version_info >= (3, 9):
        digest = md5(key, usedforsecurity=False)
    else:
        digest = md5(key)
    return int(digest.hexdigest(), 16) & 0xFFFFFFFF


def crc_hash(key):
    """Return the CRC-32 checksum of ``key`` masked to 32 bits."""
    checksum = crc32(key)
    return checksum & 0xFFFFFFFF


# Hash functions indexed by the "hashtype" integer stored in the file header:
# 0 = MD5, 1 = CRC32, 2 = CDB hash. The order must not change, since the index
# is what HashWriter writes to disk and HashReader uses to pick the function.
_hash_functions = (md5_hash, crc_hash, cdb_hash)


# Structs

# Two signed 32-bit ints (network byte order) before the key/value pair giving
# the length of the key and value
_lengths = struct.Struct("!ii")
# A pointer in a hash table, giving the hash value (unsigned 32-bit int) and
# the key position (signed 64-bit int)
_pointer = struct.Struct("!Iq")
# A pointer in the hash table directory, giving the position (signed 64-bit
# int) and number of slots (signed 32-bit int)
_dir_entry = struct.Struct("!qi")

# Size in bytes of the whole directory: one entry for each of the 256 tables
_directory_size = 256 * _dir_entry.size


# Basic hash file


class HashWriter:
    """Implements a fast on-disk key-value store. This hash uses a two-level
    hashing scheme, where a key is hashed, the low eight bits of the hash value
    are used to index into one of 256 hash tables. This is basically the CDB
    algorithm, but unlike CDB this object writes all data serially (it doesn't
    seek backwards to overwrite information at the end).

    Also unlike CDB, this format uses 64-bit file pointers, so the file length
    is essentially unlimited. However, each key and value must be less than
    2 GB in length.
    """

    def __init__(self, dbfile, magic=b"HSH3", hashtype=0):
        """
        :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
            to write to.
        :param magic: the format tag bytes to write at the start of the file.
        :param hashtype: an integer indicating which hashing algorithm to use.
            Possible values are 0 (MD5), 1 (CRC32), or 2 (CDB hash).
        """

        self.dbfile = dbfile
        self.hashtype = hashtype
        self.hashfn = _hash_functions[self.hashtype]
        # A place for subclasses to put extra metadata
        self.extras = {}

        self.startoffset = dbfile.tell()
        # Write format tag
        dbfile.write(magic)
        # Write hash type
        dbfile.write_byte(self.hashtype)
        # Unused future expansion bits
        dbfile.write_int(0)
        dbfile.write_int(0)

        # 256 lists of hashed keys and positions
        self.buckets = [[] for _ in range(256)]
        # List to remember the positions of the hash tables
        self.directory = []

    def tell(self):
        """Returns the current position in the underlying file."""
        return self.dbfile.tell()

    def add(self, key, value):
        """Adds a key/value pair to the file. Note that keys DO NOT need to be
        unique. You can store multiple values under the same key and retrieve
        them using :meth:`HashReader.all`.

        :param key: the key, as a ``bytes`` object.
        :param value: the value, as a ``bytes`` object.
        :raises TypeError: if ``key`` or ``value`` is not ``bytes``.
        """

        # Validate explicitly instead of using ``assert``: assertions are
        # removed when Python runs with -O, and the reader side
        # (HashReader.ranges_for_key) raises TypeError for the same mistake.
        if not isinstance(key, bytes):
            raise TypeError(f"Key {key!r} should be bytes")
        if not isinstance(value, bytes):
            raise TypeError(f"Value {value!r} should be bytes")

        dbfile = self.dbfile
        pos = dbfile.tell()
        dbfile.write(_lengths.pack(len(key), len(value)))
        dbfile.write(key)
        dbfile.write(value)

        # Get hash value for the key
        h = self.hashfn(key)
        # Add hash and on-disk position to appropriate bucket
        self.buckets[h & 255].append((h, pos))

    def add_all(self, items):
        """Convenience method to add a sequence of ``(key, value)`` pairs. This
        is the same as calling :meth:`HashWriter.add` on each pair in the
        sequence.
        """

        add = self.add
        for key, value in items:
            add(key, value)

    def _write_hashes(self):
        # Writes 256 hash tables containing pointers to the key/value pairs

        dbfile = self.dbfile
        # Represent an empty slot in the hash table using 0,0 (no key can
        # start at position 0 because of the header)
        null = (0, 0)

        for entries in self.buckets:
            # Start position of this bucket's hash table
            pos = dbfile.tell()
            # Remember the start position and the number of slots
            numslots = 2 * len(entries)
            self.directory.append((pos, numslots))

            # Create the empty hash table
            hashtable = [null] * numslots
            # For each (hash value, key position) tuple in the bucket
            for hashval, position in entries:
                # Bitshift and wrap to get the slot for this entry
                slot = (hashval >> 8) % numslots
                # If the slot is taken, keep going until we find an empty slot
                while hashtable[slot] != null:
                    slot = (slot + 1) % numslots
                # Insert the entry into the hashtable
                hashtable[slot] = (hashval, position)

            # Write the hash table for this bucket to disk
            for hashval, position in hashtable:
                dbfile.write(_pointer.pack(hashval, position))

    def _write_directory(self):
        # Writes a directory of pointers to the 256 hash tables

        dbfile = self.dbfile
        for position, numslots in self.directory:
            dbfile.write(_dir_entry.pack(position, numslots))

    def _write_extras(self):
        # Writes the ``extras`` dictionary as a pickle; subclasses override
        # this to append additional data after it
        self.dbfile.write_pickle(self.extras)

    def close(self):
        """Writes the hash tables, the directory and the extras to the end of
        the file, then closes the file.

        :returns: the file position at the end of the written data.
        """

        dbfile = self.dbfile

        # Write hash tables
        self._write_hashes()
        # Write directory of pointers to hash tables
        self._write_directory()

        expos = dbfile.tell()
        # Write extra information
        self._write_extras()
        # Write length of pickle
        dbfile.write_int(dbfile.tell() - expos)

        endpos = dbfile.tell()
        dbfile.close()
        return endpos
class HashReader:
    """Reader for the fast on-disk key-value files created by
    :class:`HashWriter`.
    """

    def __init__(self, dbfile, length=None, magic=b"HSH3", startoffset=0):
        """
        :param dbfile: a :class:`~whoosh.filedb.structfile.StructFile` object
            to read from.
        :param length: the length of the file data. This is necessary since the
            hashing information is written at the end of the file.
        :param magic: the format tag bytes to look for at the start of the
            file. If the file's format tag does not match these bytes, the
            object raises a :class:`FileFormatError` exception.
        :param startoffset: the starting point of the file data.
        """

        self.dbfile = dbfile
        self.startoffset = startoffset
        self.is_closed = False

        # Without an explicit length, measure from the start offset to the
        # end of the file
        if length is None:
            dbfile.seek(0, os.SEEK_END)
            length = dbfile.tell() - startoffset

        # Header: format tag, hash type, two unused ints
        dbfile.seek(startoffset)
        tag = dbfile.read(4)
        if tag != magic:
            raise FileFormatError(f"Unknown file header {tag!r}")
        self.hashtype = dbfile.read_byte()
        self.hashfn = _hash_functions[self.hashtype]
        dbfile.read_int()
        dbfile.read_int()
        self.startofdata = dbfile.tell()

        # The last int in the file gives the length of the extras section
        # that immediately precedes it
        extras_len_pos = startoffset + length - _INT_SIZE
        extras_pos = extras_len_pos - dbfile.get_int(extras_len_pos)
        dbfile.seek(extras_pos)
        self._read_extras()

        # The directory of 256 (position, numslots) entries sits right before
        # the extras
        dbfile.seek(extras_pos - _directory_size)
        entry_size = _dir_entry.size
        self.tables = [
            _dir_entry.unpack(dbfile.read(entry_size)) for _ in range(256)
        ]
        # The position of the first hash table is the end of the key/value
        # pairs
        self.endofdata = self.tables[0][0]

    @classmethod
    def open(cls, storage, name):
        """Convenience method to open a hash file given a
        :class:`whoosh.filedb.filestore.Storage` object and a name. This takes
        care of opening the file and passing its length to the initializer.
        """

        size = storage.file_length(name)
        handle = storage.open_file(name)
        return cls(handle, size)

    def file(self):
        """Returns the underlying file object."""
        return self.dbfile

    def _read_extras(self):
        # Loads the pickled extras dictionary at the current file position;
        # an EOFError while reading is treated as "no extras".
        # NOTE(review): this unpickles data from the file, so only files from
        # a trusted source should be opened.
        try:
            self.extras = self.dbfile.read_pickle()
        except EOFError:
            self.extras = {}

    def close(self):
        """Closes the underlying file. Raises an exception if this reader was
        already closed.
        """

        if self.is_closed:
            raise Exception(f"Tried to close {self!r} twice")
        self.dbfile.close()
        self.is_closed = True

    def key_at(self, pos):
        # Returns the key bytes at the given position

        source = self.dbfile
        size = source.get_uint(pos)
        return source.get(pos + _lengths.size, size)

    def key_and_range_at(self, pos):
        # Returns a (keybytes, datapos, datalen) tuple for the key at the given
        # position, or None if the position is past the key/value pairs

        header_size = _lengths.size
        if pos >= self.endofdata:
            return None

        source = self.dbfile
        keylen, datalen = _lengths.unpack(source.get(pos, header_size))
        keystart = pos + header_size
        return source.get(keystart, keylen), keystart + keylen, datalen

    def _ranges(self, pos=None, eod=None):
        # Yields a series of (keypos, keylength, datapos, datalength) tuples
        # for the key/value pairs in the file

        source = self.dbfile
        if not pos:
            pos = self.startofdata
        if not eod:
            eod = self.endofdata
        header_size = _lengths.size

        while pos < eod:
            keylen, datalen = _lengths.unpack(source.get(pos, header_size))
            keypos = pos + header_size
            datapos = keypos + keylen
            yield (keypos, keylen, datapos, datalen)
            # The next pair starts right after this pair's value
            pos = datapos + datalen

    def __getitem__(self, key):
        missing = object()
        found = next(iter(self.all(key)), missing)
        if found is missing:
            raise KeyError(key)
        return found

    def __iter__(self):
        source = self.dbfile
        for keypos, keylen, datapos, datalen in self._ranges():
            pair = (source.get(keypos, keylen), source.get(datapos, datalen))
            yield pair

    def __contains__(self, key):
        return any(True for _ in self.ranges_for_key(key))

    def keys(self):
        """Yields every key in the file, in the order they were written."""
        source = self.dbfile
        for keypos, keylen, _datapos, _datalen in self._ranges():
            yield source.get(keypos, keylen)

    def values(self):
        """Yields every value in the file, in the order they were written."""
        source = self.dbfile
        for _keypos, _keylen, datapos, datalen in self._ranges():
            yield source.get(datapos, datalen)

    def items(self):
        """Yields every ``(key, value)`` pair in the file, in the order they
        were written.
        """
        source = self.dbfile
        for keypos, keylen, datapos, datalen in self._ranges():
            yield (source.get(keypos, keylen), source.get(datapos, datalen))

    def get(self, key, default=None):
        """Returns the first value stored under the given key, or ``default``
        if the key is not in the file.
        """
        return next(iter(self.all(key)), default)

    def all(self, key):
        """Yields a sequence of values associated with the given key."""

        dbfile = self.dbfile
        for datapos, datalen in self.ranges_for_key(key):
            yield dbfile.get(datapos, datalen)

    def ranges_for_key(self, key):
        """Yields a sequence of ``(datapos, datalength)`` tuples associated
        with the given key.
        """

        if not isinstance(key, bytes):
            raise TypeError(f"Key {key!r} should be bytes")
        dbfile = self.dbfile

        # The low eight bits of the hash select one of the 256 tables
        keyhash = self.hashfn(key)
        tablestart, numslots = self.tables[keyhash & 255]
        # An empty table means the key cannot be in the file
        if not numslots:
            return

        ptrsize = _pointer.size
        lenssize = _lengths.size
        tableend = tablestart + (numslots * ptrsize)
        wanted_len = len(key)

        # Start at the slot the remaining hash bits point to and probe
        # forward, visiting each slot at most once
        slotpos = tablestart + (((keyhash >> 8) % numslots) * ptrsize)
        for _ in range(numslots):
            slothash, itempos = _pointer.unpack(dbfile.get(slotpos, ptrsize))
            # An empty slot ends the probe sequence
            if not itempos:
                return

            # A matching hash is only a candidate: compare the stored key
            # length first, then the key bytes themselves
            if slothash == keyhash:
                keylen, datalen = _lengths.unpack(dbfile.get(itempos, lenssize))
                if keylen == wanted_len:
                    keystart = itempos + lenssize
                    if key == dbfile.get(keystart, keylen):
                        yield (keystart + keylen, datalen)

            slotpos += ptrsize
            # Wrap around at the end of the table
            if slotpos == tableend:
                slotpos = tablestart

    def range_for_key(self, key):
        """Returns the first ``(datapos, datalength)`` tuple for the given key,
        raising ``KeyError`` if the key is not in the file.
        """
        missing = object()
        found = next(iter(self.ranges_for_key(key)), missing)
        if found is missing:
            raise KeyError(key)
        return found
# Ordered hash file
class OrderedHashWriter(HashWriter):
    """Implements an on-disk hash, but requires that keys be added in order.
    An :class:`OrderedHashReader` can then look up "nearest keys" based on
    the ordering.
    """

    def __init__(self, dbfile):
        super().__init__(dbfile)
        # Keep an array of the positions of all keys
        self.index = GrowableArray("H")
        # Keep track of the last key added
        self.lastkey = emptybytes

    def add(self, key, value):
        """Adds a key/value pair and records the position of the key. Raises
        ``ValueError`` unless ``key`` is greater than the previously added key.
        """
        previous = self.lastkey
        if key <= previous:
            raise ValueError(f"Keys must increase: {previous!r}..{key!r}")

        # Record where this pair starts before it is written
        self.index.append(self.dbfile.tell())
        super().add(key, value)
        self.lastkey = key

    def _write_extras(self):
        positions = self.index

        # Store metadata about the index array
        self.extras.update(
            indextype=positions.typecode,
            indexlen=len(positions),
        )
        # Write the extras
        super()._write_extras()
        # Write the index array
        positions.to_file(self.dbfile)
class OrderedHashReader(HashReader):
    """Reader for files written by :class:`OrderedHashWriter`. In addition to
    the lookups of :class:`HashReader`, it uses the stored array of key
    positions to find the nearest key equal to or greater than a given key.
    """

    def closest_key(self, key):
        """Returns the closest key equal to or greater than the given key. If
        there is no key in the file equal to or greater than the given key,
        returns None.
        """

        found = self.closest_key_pos(key)
        return None if found is None else self.key_at(found)

    def ranges_from(self, key):
        """Yields a series of ``(keypos, keylen, datapos, datalen)`` tuples
        for the ordered series of keys equal or greater than the given key.
        """

        start = self.closest_key_pos(key)
        if start is not None:
            yield from self._ranges(pos=start)

    def keys_from(self, key):
        """Yields an ordered series of keys equal to or greater than the given
        key.
        """

        source = self.dbfile
        for keypos, keylen, _datapos, _datalen in self.ranges_from(key):
            yield source.get(keypos, keylen)

    def items_from(self, key):
        """Yields an ordered series of ``(key, value)`` tuples for keys equal
        to or greater than the given key.
        """

        source = self.dbfile
        for keypos, keylen, datapos, datalen in self.ranges_from(key):
            yield (source.get(keypos, keylen), source.get(datapos, datalen))

    def _read_extras(self):
        dbfile = self.dbfile

        # Read the extras
        super()._read_extras()

        # The index array starts where reading the extras left the file
        typecode = self.extras["indextype"]
        self.indexbase = dbfile.tell()
        self.indexlen = self.extras["indexlen"]
        self.indexsize = struct.calcsize(typecode)

        # Pick the file method that reads one value of the index array
        reader_names = {
            "B": "get_byte",
            "H": "get_ushort",
            "i": "get_int",
            "I": "get_uint",
            "q": "get_long",
        }
        if typecode not in reader_names:
            raise Exception(f"Unknown index type {typecode!r}")
        self._get_pos = getattr(dbfile, reader_names[typecode])

    def closest_key_pos(self, key):
        # Given a key, return the position of that key OR the next highest key
        # if the given key does not exist
        if not isinstance(key, bytes):
            raise TypeError(f"Key {key!r} should be bytes")

        base = self.indexbase
        width = self.indexsize
        read_pos = self._get_pos
        key_at = self.key_at

        def position(i):
            # File position of the i-th key, read from the index array
            return read_pos(base + i * width)

        # Binary search for the first indexed key that is not less than the
        # given key
        low, high = 0, self.indexlen
        while low < high:
            middle = (low + high) // 2
            if key_at(position(middle)) < key:
                low = middle + 1
            else:
                high = middle

        # If we went off the end, there is no such key
        if low == self.indexlen:
            return None
        return position(low)
# Fielded Ordered hash file


class FieldedOrderedHashWriter(HashWriter):
    """Implements an on-disk hash, but writes separate position indexes for
    each field.

    Keys for a field are added between calls to :meth:`start_field` and
    :meth:`end_field`, and must be added in increasing order within the field.
    """

    def __init__(self, dbfile):
        HashWriter.__init__(self, dbfile)
        # Map field names to (startpos, indexpos, length, typecode)
        self.fieldmap = self.extras["fieldmap"] = {}

        # Keep track of the last key added
        self.lastkey = emptybytes

    def start_field(self, fieldname):
        """Begins a new field. Keys added after this call belong to
        ``fieldname`` until :meth:`end_field` is called.
        """
        self.fieldstart = self.dbfile.tell()
        self.fieldname = fieldname
        # Keep an array of the positions of all keys
        self.poses = GrowableArray("H")
        self.lastkey = emptybytes

    def add(self, key, value):
        """Adds a key/value pair to the current field. Raises ``ValueError``
        unless ``key`` is greater than the previous key of this field.
        """
        if key <= self.lastkey:
            raise ValueError(f"Keys must increase: {self.lastkey!r}..{key!r}")
        # Positions are stored relative to the start of the field
        self.poses.append(self.dbfile.tell() - self.fieldstart)
        HashWriter.add(self, key, value)
        self.lastkey = key

    def end_field(self):
        """Ends the current field: records its entry in the field map and
        writes its position index right after the field's key/value pairs.
        """
        dbfile = self.dbfile
        fieldname = self.fieldname
        poses = self.poses
        self.fieldmap[fieldname] = (
            self.fieldstart,
            dbfile.tell(),
            len(poses),
            poses.typecode,
        )
        poses.to_file(dbfile)


class FieldedOrderedHashReader(HashReader):
    """Reader for files written by :class:`FieldedOrderedHashWriter`.

    Each entry of ``fieldmap`` is a ``(startpos, indexpos, length, typecode)``
    tuple: the field's key/value pairs occupy the file from ``startpos`` up to
    ``indexpos``, where an array of ``length`` key positions (relative to
    ``startpos``) with the given typecode begins.
    """

    def __init__(self, *args, **kwargs):
        HashReader.__init__(self, *args, **kwargs)
        self.fieldmap = self.extras["fieldmap"]
        # Make a sorted list of the field names with their start and end ranges
        self.fieldlist = []
        for fieldname in sorted(self.fieldmap.keys()):
            startpos, ixpos, _ixlen, _ixtype = self.fieldmap[fieldname]
            self.fieldlist.append((fieldname, startpos, ixpos))

    def field_start(self, fieldname):
        """Returns the file position where the given field's pairs start."""
        return self.fieldmap[fieldname][0]

    def fielded_ranges(self, pos=None, eod=None):
        """Yields ``(fieldname, keypos, keylen, datapos, datalen)`` tuples for
        the key/value pairs of every field, in sorted field name order.

        :param pos: if given, pairs before this file position are skipped.
        :param eod: if given, scanning stops at this file position.
        """
        # Scan each field's own [start, end) region separately. The writer
        # stores a field's position index directly after its pairs, so a
        # single scan across the whole data area would misread index bytes as
        # key/value records.
        for fieldname, start, end in self.fieldlist:
            lo = max(start, pos) if pos else start
            hi = min(end, eod) if eod else end
            if lo >= hi:
                continue
            for keypos, keylen, datapos, datalen in self._ranges(lo, hi):
                yield fieldname, keypos, keylen, datapos, datalen

    def iter_terms(self):
        """Yields ``(fieldname, key)`` tuples for every key in the file."""
        get = self.dbfile.get
        for fieldname, keypos, keylen, _, _ in self.fielded_ranges():
            yield fieldname, get(keypos, keylen)

    def iter_term_items(self):
        """Yields ``(fieldname, key, value)`` tuples for every pair in the
        file.
        """
        get = self.dbfile.get
        for item in self.fielded_ranges():
            fieldname, keypos, keylen, datapos, datalen = item
            yield fieldname, get(keypos, keylen), get(datapos, datalen)

    def contains_term(self, fieldname, btext):
        """Returns True if the key ``btext`` exists in the given field."""
        try:
            self.range_for_term(fieldname, btext)
        except KeyError:
            return False
        return True

    def range_for_term(self, fieldname, btext):
        """Returns the ``(datapos, datalen)`` of the value stored under
        ``btext`` in the given field. Raises ``KeyError`` if the field is
        unknown or does not contain the key.
        """
        start, ixpos, _ixlen, _ixtype = self.fieldmap[fieldname]
        # The same key bytes may exist in several fields; keep only the match
        # whose value lies inside this field's region
        for datapos, datalen in self.ranges_for_key(btext):
            if start < datapos < ixpos:
                return datapos, datalen
        raise KeyError((fieldname, btext))

    def term_data(self, fieldname, btext):
        """Returns the value bytes stored under ``btext`` in the given field,
        raising ``KeyError`` if there is none.
        """
        datapos, datalen = self.range_for_term(fieldname, btext)
        return self.dbfile.get(datapos, datalen)

    def term_get(self, fieldname, btext, default=None):
        """Like :meth:`term_data`, but returns ``default`` instead of raising
        ``KeyError``.
        """
        try:
            return self.term_data(fieldname, btext)
        except KeyError:
            return default

    def closest_term_pos(self, fieldname, key):
        # Given a key, return the position of that key OR the next highest key
        # if the given key does not exist
        if not isinstance(key, bytes):
            raise TypeError(f"Key {key!r} should be bytes")

        dbfile = self.dbfile
        key_at = self.key_at
        # The third element is the NUMBER of entries in the index array
        startpos, ixpos, ixlen, ixtype = self.fieldmap[fieldname]

        if ixtype == "B":
            get_pos = dbfile.get_byte
        elif ixtype == "H":
            get_pos = dbfile.get_ushort
        elif ixtype == "i":
            get_pos = dbfile.get_int
        elif ixtype == "I":
            get_pos = dbfile.get_uint
        elif ixtype == "q":
            get_pos = dbfile.get_long
        else:
            raise Exception(f"Unknown index type {ixtype!r}")

        # Entries are addressed by their byte size, not by the entry count
        itemsize = struct.calcsize(ixtype)

        # Do a binary search of the positions in the index array
        lo = 0
        hi = ixlen
        while lo < hi:
            mid = (lo + hi) // 2
            midkey = key_at(startpos + get_pos(ixpos + mid * itemsize))
            if midkey < key:
                lo = mid + 1
            else:
                hi = mid

        # If we went off the end, return None
        if lo == ixlen:
            return None
        # Return the closest key
        return startpos + get_pos(ixpos + lo * itemsize)

    def closest_term(self, fieldname, btext):
        """Returns the closest key in the given field equal to or greater than
        ``btext``, or None if there is no such key.
        """
        pos = self.closest_term_pos(fieldname, btext)
        if pos is None:
            return None
        return self.key_at(pos)

    def term_ranges_from(self, fieldname, btext):
        """Yields ``(keypos, keylen, datapos, datalen)`` tuples for the keys of
        the given field equal to or greater than ``btext``.
        """
        pos = self.closest_term_pos(fieldname, btext)
        if pos is None:
            return

        # Stop at the start of the field's position index
        _startpos, ixpos, _ixlen, _ixtype = self.fieldmap[fieldname]
        yield from self._ranges(pos, ixpos)

    def terms_from(self, fieldname, btext):
        """Yields the keys of the given field equal to or greater than
        ``btext``.
        """
        dbfile = self.dbfile
        for keypos, keylen, _, _ in self.term_ranges_from(fieldname, btext):
            yield dbfile.get(keypos, keylen)

    def term_items_from(self, fieldname, btext):
        """Yields ``(key, value)`` tuples of the given field for keys equal to
        or greater than ``btext``.
        """
        dbfile = self.dbfile
        for item in self.term_ranges_from(fieldname, btext):
            keypos, keylen, datapos, datalen = item
            yield (dbfile.get(keypos, keylen), dbfile.get(datapos, datalen))