diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e1752d..1ee2a04 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # PyProbables Changelog +### Version 0.1.1: +* CuckooFilter + * Import / Export functionality + * Enforce single insertion per key + * Auto expand on insertion failure OR when called to do so (settable) + ### Version 0.1.0: * Cuckoo Filter * Added basic Cuckoo Filter code diff --git a/probables/__init__.py b/probables/__init__.py index 5e6a161..1efadcb 100644 --- a/probables/__init__.py +++ b/probables/__init__.py @@ -11,7 +11,7 @@ __maintainer__ = 'Tyler Barrus' __email__ = 'barrust@gmail.com' __license__ = 'MIT' -__version__ = '0.1.0' +__version__ = '0.1.1' __credits__ = [] __url__ = 'https://github.com/barrust/pyprobables' __bugtrack_url__ = 'https://github.com/barrust/pyprobables/issues' diff --git a/probables/blooms/countingbloom.py b/probables/blooms/countingbloom.py index 277d42d..c6c58f0 100644 --- a/probables/blooms/countingbloom.py +++ b/probables/blooms/countingbloom.py @@ -6,6 +6,7 @@ from __future__ import (unicode_literals, absolute_import, print_function, division) from . basebloom import (BaseBloom) +from .. 
constants import (UINT32_T_MAX, UINT64_T_MAX) MISMATCH_MSG = ('The parameter second must be of type CountingBloomFilter') @@ -44,8 +45,6 @@ def __init__(self, est_elements=None, false_positive_rate=None, false_positive_rate, filepath, hex_string, hash_function) - self.__uint32_t_max = 2**32 - 1 - self.__uint64_t_max = 2**64 - 1 def __str__(self): ''' correctly handle python 3 vs python2 encoding if necessary ''' @@ -110,20 +109,20 @@ def add_alt(self, hashes, num_els=1): Returns: int: Maximum number of insertions ''' - res = self.__uint32_t_max + res = UINT32_T_MAX for i in list(range(0, self.number_hashes)): k = int(hashes[i]) % self.number_bits j = self._get_element(k) tmp = j + num_els - if tmp <= self.__uint32_t_max: + if tmp <= UINT32_T_MAX: self._bloom[k] = self._get_set_element(j + num_els) else: - self._bloom[k] = self.__uint32_t_max + self._bloom[k] = UINT32_T_MAX if self._bloom[k] < res: res = self._bloom[k] self.elements_added += num_els - if self.elements_added > self.__uint64_t_max: - self.elements_added = self.__uint64_t_max + if self.elements_added > UINT64_T_MAX: + self.elements_added = UINT64_T_MAX return res def check(self, key): @@ -147,7 +146,7 @@ def check_alt(self, hashes): Returns: int: Maximum number of insertions ''' - res = self.__uint32_t_max + res = UINT32_T_MAX for i in list(range(0, self.number_hashes)): k = int(hashes[i]) % self.number_bits tmp = self._get_element(k) @@ -179,8 +178,8 @@ def remove_alt(self, hashes, num_els=1): int: Maximum number of insertions after the removal ''' tmp = self.check_alt(hashes) - if tmp == self.__uint32_t_max: # cannot remove if we have hit the max - return self.__uint32_t_max + if tmp == UINT32_T_MAX: # cannot remove if we have hit the max + return UINT32_T_MAX elif tmp == 0: return 0 diff --git a/probables/constants.py b/probables/constants.py new file mode 100644 index 0000000..29f5bb6 --- /dev/null +++ b/probables/constants.py @@ -0,0 +1,7 @@ +''' Project Constants (or basic numerical constants...) 
''' +INT32_T_MIN = -2147483648 +INT32_T_MAX = 2147483647 +INT64_T_MIN = -9223372036854775808 +INT64_T_MAX = 9223372036854775807 +UINT32_T_MAX = 2**32 - 1 +UINT64_T_MAX = 2**64 - 1 diff --git a/probables/countminsketch/countminsketch.py b/probables/countminsketch/countminsketch.py index dd1a8a8..89792f9 100644 --- a/probables/countminsketch/countminsketch.py +++ b/probables/countminsketch/countminsketch.py @@ -11,6 +11,7 @@ from .. exceptions import (InitializationError, NotSupportedError) from .. hashes import (default_fnv_1a) from .. utilities import (is_valid_file) +from .. constants import (INT32_T_MIN, INT32_T_MAX, INT64_T_MIN, INT64_T_MAX) class CountMinSketch(object): @@ -50,11 +51,6 @@ def __init__(self, width=None, depth=None, confidence=None, self.__error_rate = 0.0 self.__elements_added = 0 self.__query_method = self.__min_query - # for python2 and python3 support - self.__int32_t_min = -2147483648 - self.__int32_t_max = 2147483647 - self.__int64_t_min = -9223372036854775808 - self.__int64_t_max = 9223372036854775807 if is_valid_file(filepath): self.__load(filepath) @@ -216,13 +212,13 @@ def add_alt(self, hashes, num_els=1): for i, val in enumerate(hashes): t_bin = (val % self.__width) + (i * self.__width) self._bins[t_bin] += num_els - if self._bins[t_bin] > self.__int32_t_max: - self._bins[t_bin] = self.__int32_t_max + if self._bins[t_bin] > INT32_T_MAX: + self._bins[t_bin] = INT32_T_MAX res.append(self._bins[t_bin]) self.__elements_added += num_els - if self.__elements_added > self.__int64_t_max: - self.__elements_added = self.__int64_t_max + if self.__elements_added > INT64_T_MAX: + self.__elements_added = INT64_T_MAX return self.__query_method(sorted(res)) def remove(self, key, num_els=1): @@ -252,12 +248,12 @@ def remove_alt(self, hashes, num_els=1): for i, val in enumerate(hashes): t_bin = (val % self.__width) + (i * self.__width) self._bins[t_bin] -= num_els - if self._bins[t_bin] < self.__int32_t_min: - self._bins[t_bin] = self.__int32_t_min + if 
self._bins[t_bin] < INT32_T_MIN: + self._bins[t_bin] = INT32_T_MIN res.append(self._bins[t_bin]) self.__elements_added -= num_els - if self.__elements_added < self.__int64_t_min: - self.__elements_added = self.__int64_t_min + if self.__elements_added < INT64_T_MIN: + self.__elements_added = INT64_T_MIN return self.__query_method(sorted(res)) diff --git a/probables/cuckoo/cuckoo.py b/probables/cuckoo/cuckoo.py index ce33a69..3f2fbb5 100644 --- a/probables/cuckoo/cuckoo.py +++ b/probables/cuckoo/cuckoo.py @@ -5,8 +5,9 @@ from __future__ import (unicode_literals, absolute_import, print_function, division) +import os import random -from itertools import chain +from struct import (pack, unpack, calcsize) from .. hashes import (fnv_1a) from .. utilities import (get_x_bits) @@ -20,19 +21,31 @@ class CuckooFilter(object): capacity (int): The number of bins bucket_size (int): The number of buckets per bin max_swaps (int): The number of cuckoo swaps before stopping + expansion_rate (int): The rate at which to expand + auto_expand (bool): If the filter should automatically expand + filepath (str): The path to the file to load or None if no file Returns: CuckooFilter: A Cuckoo Filter object ''' - def __init__(self, capacity=10000, bucket_size=4, max_swaps=500): + def __init__(self, capacity=10000, bucket_size=4, max_swaps=500, + expansion_rate=2, auto_expand=True, filepath=None): ''' setup the data structure ''' self.__bucket_size = bucket_size self.__cuckoo_capacity = capacity self.__max_cuckoo_swaps = max_swaps - self.__buckets = list() - for _ in range(self.capacity): - self.__buckets.append(list()) + self.__expansion_rate = None + self.expansion_rate = expansion_rate + self.__auto_expand = None + self.auto_expand = auto_expand + self.__hash_func = fnv_1a self.__inserted_elements = 0 + if filepath is None: + self.__buckets = list() + for _ in range(self.capacity): + self.__buckets.append(list()) + else: + self.__load(filepath) def __contains__(self, key): ''' setup the `in` 
keyword ''' @@ -70,6 +83,38 @@ def bucket_size(self): Not settable ''' return self.__bucket_size + @property + def buckets(self): + ''' list(list): The buckets holding the fingerprints + + Note: + Not settable ''' + return self.__buckets + + @property + def expansion_rate(self): + ''' int: The rate of expansion when the filter grows ''' + return self.__expansion_rate + + @expansion_rate.setter + def expansion_rate(self, val): + ''' set the expansion rate (coerced to int) ''' + self.__expansion_rate = int(val) + + @property + def auto_expand(self): + ''' bool: True if the cuckoo filter will expand automatically ''' + return self.__auto_expand + + @auto_expand.setter + def auto_expand(self, val): + ''' set the auto expand flag (coerced to bool) ''' + self.__auto_expand = bool(val) + + def load_factor(self): + ''' float: How full the Cuckoo Filter is currently ''' + return self.elements_added / (self.capacity * self.bucket_size) + def add(self, key): ''' Add element key to the filter @@ -79,12 +124,71 @@ def add(self, key): CuckooFilterFullError: When element not inserted after \ maximum number of swaps or 'kicks' ''' idx_1, idx_2, fingerprint = self._generate_fingerprint_info(key) + + is_present = self._check_if_present(idx_1, idx_2, fingerprint) + if is_present is not None: # already there, nothing to do + return + finger = self._insert_fingerprint(fingerprint, idx_1, idx_2) + if finger is None: + return + elif self.__auto_expand: + self.__expand_logic(finger) + else: + raise CuckooFilterFullError('The CuckooFilter is currently full') + + def check(self, key): + ''' Check if an element is in the filter + + Args: + key (str): Element to check ''' + idx_1, idx_2, fingerprint = self._generate_fingerprint_info(key) + is_present = self._check_if_present(idx_1, idx_2, fingerprint) + if is_present is not None: + return True + return False + + def remove(self, key): + ''' Remove an element from the filter + + Args: + key (str): Element to remove ''' + idx_1, idx_2, fingerprint = 
self._generate_fingerprint_info(key) + idx = self._check_if_present(idx_1, idx_2, fingerprint) + if idx is None: + return False + self.__buckets[idx].remove(fingerprint) + self.__inserted_elements -= 1 + return True + + def export(self, filename): + ''' Export cuckoo filter to file + + Args: + filename (str): Path to file to export + ''' + with open(filename, 'wb') as filepointer: + for bucket in self.__buckets: + # write the bucket's fingerprints, then zero-pad to bucket_size + rep = len(bucket) * 'I' + filepointer.write(pack(rep, *bucket)) + leftover = self.bucket_size - len(bucket) + rep = leftover * 'I' + filepointer.write(pack(rep, *([0] * leftover))) + # now put out the required information at the end + filepointer.write(pack('II', self.bucket_size, self.max_swaps)) + + def expand(self): + ''' Expand the cuckoo filter ''' + self.__expand_logic(None) + + def _insert_fingerprint(self, fingerprint, idx_1, idx_2): + ''' insert a fingerprint ''' if self.__insert_element(fingerprint, idx_1): self.__inserted_elements += 1 - return idx_1 + return elif self.__insert_element(fingerprint, idx_2): self.__inserted_elements += 1 - return idx_2 + return # we didn't insert, so now we need to randomly select one index to use # and move things around to the other index, if possible, until we @@ -105,34 +209,41 @@ def add(self, key): if self.__insert_element(fingerprint, idx): self.__inserted_elements += 1 - return idx - raise CuckooFilterFullError('The CuckooFilter is currently full') - - def check(self, key): - ''' Check if an element is in the filter - - Args: - key (str): Element to check ''' - idx_1, idx_2, fingerprint = self._generate_fingerprint_info(key) - if fingerprint in chain(self.__buckets[idx_1], self.__buckets[idx_2]): - return True - return False - - def remove(self, key): - ''' Remove an element from the filter - - Args: - key (str): Element to remove ''' - idx_1, idx_2, fingerprint = self._generate_fingerprint_info(key) + return + + # if we got here we have an error... 
we might need to know what is left + return fingerprint + + def __load(self, filename): + ''' load a cuckoo filter from file ''' + with open(filename, 'rb') as filepointer: + offset = calcsize('II') + int_size = calcsize('I') + filepointer.seek(offset * -1, os.SEEK_END) + list_size = filepointer.tell() + mybytes = unpack('II', filepointer.read(offset)) + self.__bucket_size = mybytes[0] + self.__max_cuckoo_swaps = mybytes[1] + self.__cuckoo_capacity = list_size // int_size // self.bucket_size + self.__inserted_elements = 0 + # now pull everything in! + filepointer.seek(0, os.SEEK_SET) + self.__buckets = list() + for i in range(self.capacity): + self.__buckets.append(list()) + for _ in range(self.bucket_size): + fingerprint = unpack('I', filepointer.read(int_size))[0] + if fingerprint != 0: + self.__buckets[i].append(fingerprint) + self.__inserted_elements += 1 + + def _check_if_present(self, idx_1, idx_2, fingerprint): + ''' wrapper for checking if fingerprint is already inserted ''' if fingerprint in self.__buckets[idx_1]: - self.__buckets[idx_1].remove(fingerprint) - self.__inserted_elements -= 1 - return True + return idx_1 elif fingerprint in self.__buckets[idx_2]: - self.__buckets[idx_2].remove(fingerprint) - self.__inserted_elements -= 1 - return True - return False + return idx_2 + return None def __insert_element(self, fingerprint, idx): ''' insert element wrapper ''' @@ -141,6 +252,28 @@ def __insert_element(self, fingerprint, idx): return True return False + def __expand_logic(self, extra_fingerprint): + ''' the logic to actually expand the cuckoo filter ''' + # get all the fingerprints + fingerprints = list() + if extra_fingerprint is not None: + fingerprints.append(extra_fingerprint) + for idx in range(self.capacity): + fingerprints.extend(self.buckets[idx]) + + self.__cuckoo_capacity = self.capacity * self.expansion_rate + self.__buckets = list() + self.__inserted_elements = 0 + for _ in range(self.capacity): + self.buckets.append(list()) + + for finger 
in fingerprints: + idx_1, idx_2 = self._indicies_from_fingerprint(finger) + res = self._insert_fingerprint(finger, idx_1, idx_2) + if res is not None: # again, this *shouldn't* happen + msg = ('The CuckooFilter failed to expand') + raise CuckooFilterFullError(msg) + def _indicies_from_fingerprint(self, fingerprint): ''' Generate the possible insertion indicies from a fingerprint diff --git a/probables/hashes.py b/probables/hashes.py index a65f811..6d99c21 100644 --- a/probables/hashes.py +++ b/probables/hashes.py @@ -4,8 +4,7 @@ from hashlib import (md5, sha256) from struct import (unpack) # needed to turn digests into numbers - -UIN64_MAX = 2 ** 64 +from . constants import (UINT64_T_MAX) def default_fnv_1a(key, depth): @@ -32,11 +31,12 @@ def fnv_1a(key): Args: key (str): The element to be hashed ''' + max64mod = UINT64_T_MAX + 1 hval = 14695981039346656073 fnv_64_prime = 1099511628211 for t_str in key: hval = hval ^ ord(t_str) - hval = (hval * fnv_64_prime) % UIN64_MAX + hval = (hval * fnv_64_prime) % max64mod return hval diff --git a/tests/bloom_test.py b/tests/bloom_test.py index 5e3796b..a4d1e2c 100644 --- a/tests/bloom_test.py +++ b/tests/bloom_test.py @@ -490,8 +490,10 @@ def test_bfod_export(self): def test_bfod_export_hex(self): ''' test that page error is thrown correctly ''' filename = 'tmp.blm' - blm = BloomFilterOnDisk(filename, 10, 0.05) - self.assertRaises(NotSupportedError, lambda: blm.export_hex()) + def runner(): + blm = BloomFilterOnDisk(filename, 10, 0.05) + blm.export_hex() + self.assertRaises(NotSupportedError, runner) os.remove(filename) def test_bfod_export_hex_msg(self): diff --git a/tests/countminsketch_test.py b/tests/countminsketch_test.py index b7d8768..3bd5193 100644 --- a/tests/countminsketch_test.py +++ b/tests/countminsketch_test.py @@ -7,6 +7,8 @@ CountMeanSketch, CountMeanMinSketch) from probables.exceptions import (InitializationError, NotSupportedError) from . 
utilities import(calc_file_md5, different_hash) +from probables.constants import (INT32_T_MIN, INT32_T_MAX, INT64_T_MAX, + INT64_T_MIN) class TestCountMinSketch(unittest.TestCase): @@ -237,20 +239,20 @@ def test_cms_different_hash(self): def test_cms_min_val(self): ''' test when we come to the bottom of the 32 bit int (stop overflow) ''' - too_large = 2 ** 31 + too_large = INT64_T_MAX + 5 cms = CountMinSketch(width=1000, depth=5) - cms.remove('this is a test', too_large + 1) - self.assertEqual(cms.check('this is a test'), -too_large) - self.assertEqual(cms.elements_added, -too_large - 1) + cms.remove('this is a test', too_large) + self.assertEqual(cms.check('this is a test'), INT32_T_MIN) + self.assertEqual(cms.elements_added, INT64_T_MIN) def test_cms_max_val(self): ''' test when we come to the top of the 32 bit int (stop overflow) ''' - too_large = 2 ** 31 + too_large = INT64_T_MAX + 5 cms = CountMinSketch(width=1000, depth=5) cms.add('this is a test', too_large) - self.assertEqual(cms.check('this is a test'), too_large - 1) - self.assertEqual(cms.elements_added, too_large) + self.assertEqual(cms.check('this is a test'), INT32_T_MAX) + self.assertEqual(cms.elements_added, INT64_T_MAX) def test_cms_clear(self): ''' test the clear functionality ''' @@ -262,23 +264,6 @@ def test_cms_clear(self): self.assertEqual(cms.elements_added, 0) self.assertEqual(cms.check('this is a test'), 0) - # def test_cms_bad_query(self): - # ''' test a bad query ''' - # cms = CountMinSketch(width=1000, depth=5) - # self.assertEqual(cms.add('this is a test', 100), 100) - # self.assertRaises(NotSupportedError, - # lambda: cms.check('this is a test', 'unknown')) - - # def test_cms_bad_query_msg(self): - # ''' test a bad query ''' - # cms = CountMinSketch(width=1000, depth=5) - # self.assertEqual(cms.add('this is a test', 100), 100) - # try: - # cms.check('this is a test', 'unknown') - # except NotSupportedError as ex: - # msg = "`check`: Invalid query type" - # self.assertEqual(str(ex), 
msg) - def test_cms_str(self): ''' test the string representation of the count-min sketch ''' cms = CountMinSketch(width=1000, depth=5) diff --git a/tests/cuckoo_test.py b/tests/cuckoo_test.py index a75c52d..ab67cb7 100644 --- a/tests/cuckoo_test.py +++ b/tests/cuckoo_test.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- ''' Unittest class ''' from __future__ import (unicode_literals, absolute_import, print_function) +import os import unittest from probables import (CuckooFilter, CuckooFilterFullError) -# from . utilities import(calc_file_md5, different_hash) +from . utilities import(calc_file_md5) class TestCuckooFilter(unittest.TestCase): ''' base Cuckoo Filter test ''' @@ -15,13 +16,18 @@ def test_cuckoo_filter_default(self): self.assertEqual(10000, cko.capacity) self.assertEqual(4, cko.bucket_size) self.assertEqual(500, cko.max_swaps) + self.assertEqual(2, cko.expansion_rate) + self.assertEqual(True, cko.auto_expand) def test_cuckoo_filter_diff(self): ''' test cuckoo filter non-standard properties ''' - cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=5) + cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=5, + expansion_rate=4, auto_expand=False) self.assertEqual(100, cko.capacity) self.assertEqual(2, cko.bucket_size) self.assertEqual(5, cko.max_swaps) + self.assertEqual(4, cko.expansion_rate) + self.assertEqual(False, cko.auto_expand) def test_cuckoo_filter_add(self): ''' test adding to the cuckoo filter ''' @@ -77,7 +83,8 @@ def test_cuckoo_filter_lots(self): def test_cuckoo_filter_full(self): ''' test inserting until cuckoo filter is full ''' def runner(): - cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=100) + cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=100, + auto_expand=False) for i in range(175): cko.add(str(i)) self.assertRaises(CuckooFilterFullError, runner) @@ -85,11 +92,14 @@ def runner(): def test_cuckoo_full_msg(self): ''' test exception message for full cuckoo filter ''' try: - cko = CuckooFilter(capacity=100, 
bucket_size=2, max_swaps=100) + cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=100, + auto_expand=False) for i in range(175): cko.add(str(i)) except CuckooFilterFullError as ex: self.assertEqual(str(ex), 'The CuckooFilter is currently full') + else: + self.assertEqual(True, False) def test_cuckoo_idx(self): ''' test that the indexing works correctly for cuckoo filter swap ''' @@ -123,3 +133,79 @@ def test_cuckoo_filter_in(self): self.assertEqual('this is yet another test' in cko, True) self.assertEqual('this is not another test' in cko, False) self.assertEqual('this is not a test' in cko, False) + + def test_cuckoo_filter_dup_add(self): + ''' test adding same item multiple times cuckoo filter ''' + cko = CuckooFilter() + cko.add('this is a test') + cko.add('this is another test') + cko.add('this is yet another test') + self.assertEqual(cko.elements_added, 3) + cko.add('this is a test') + cko.add('this is another test') + cko.add('this is yet another test') + self.assertEqual(cko.elements_added, 3) + + def test_cuckoo_filter_l_fact(self): + ''' test the load factor of the cuckoo filter ''' + cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=10) + self.assertEqual(cko.load_factor(), 0.0) + for i in range(50): + cko.add(str(i)) + self.assertEqual(cko.load_factor(), 0.25) + for i in range(50): + cko.add(str(i + 50)) + self.assertEqual(cko.load_factor(), 0.50) + + def test_cuckoo_filter_export(self): + ''' test exporting a cuckoo filter ''' + filename = './test.cko' + md5sum = '49b947ddf364d27934570a6b33076b93' + cko = CuckooFilter() + for i in range(1000): + cko.add(str(i)) + cko.export(filename) + md5_out = calc_file_md5(filename) + self.assertEqual(md5sum, md5_out) + os.remove(filename) + + def test_cuckoo_filter_load(self): + ''' test loading a saved cuckoo filter ''' + filename = './test.cko' + md5sum = '49b947ddf364d27934570a6b33076b93' + cko = CuckooFilter() + for i in range(1000): + cko.add(str(i)) + cko.export(filename) + md5_out = 
calc_file_md5(filename) + self.assertEqual(md5sum, md5_out) + + ckf = CuckooFilter(filepath='./test.cko') + for i in range(1000): + self.assertTrue(ckf.check(str(i))) + + self.assertEqual(10000, ckf.capacity) + self.assertEqual(4, ckf.bucket_size) + self.assertEqual(500, ckf.max_swaps) + self.assertEqual(0.025, ckf.load_factor()) + os.remove(filename) + + def test_cuckoo_filter_expand_els(self): + ''' test out the expansion of the cuckoo filter ''' + cko = CuckooFilter() + for i in range(200): + cko.add(str(i)) + cko.expand() + for i in range(200): + self.assertTrue(cko.check(str(i))) + self.assertEqual(20000, cko.capacity) + + def test_cuckoo_filter_auto_expand(self): + ''' test that the cuckoo filter expands automatically once full ''' + cko = CuckooFilter(capacity=100, bucket_size=2, max_swaps=100) + for i in range(375): # this would fail if it doesn't expand + cko.add(str(i)) + self.assertEqual(400, cko.capacity) + self.assertEqual(375, cko.elements_added) + for i in range(375): + self.assertTrue(cko.check(str(i)))