|
#ifndef UTIL_PROBING_HASH_TABLE_H |
|
#define UTIL_PROBING_HASH_TABLE_H |
|
|
|
#include "util/exception.hh" |
|
#include "util/scoped.hh" |
|
|
|
#include <algorithm> |
|
#include <cstddef> |
|
#include <functional> |
|
#include <vector> |
|
|
|
#include <assert.h> |
|
#include <stdint.h> |
|
|
|
namespace util { |
|
|
|
|
|
/* Exception type raised by the probing hash table in this header when it has
 * no room left for another entry.  It adds nothing beyond its base class; the
 * distinct type lets callers catch the "table is full" condition separately.
 */
class ProbingSizeException : public Exception {
  public:
    // Nothing to set up: the message is streamed in at the throw site.
    ProbingSizeException() throw() {}

    // Empty, non-throwing destructor.
    ~ProbingSizeException() throw() {}
};
|
|
|
|
|
/* Hash functor that hands its argument back unchanged. */
struct IdentityHash {
  // Takes the value by copy and returns that copy as the hash.
  template <class T> T operator()(T value) const {
    return value;
  }
};
|
|
|
template <class EntryT, class HashT, class EqualT> class AutoProbing; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class ProbingHashTable { |
|
public: |
|
typedef EntryT Entry; |
|
typedef typename Entry::Key Key; |
|
typedef const Entry *ConstIterator; |
|
typedef Entry *MutableIterator; |
|
typedef HashT Hash; |
|
typedef EqualT Equal; |
|
|
|
static uint64_t Size(uint64_t entries, float multiplier) { |
|
uint64_t buckets = std::max(entries + 1, static_cast<uint64_t>(multiplier * static_cast<float>(entries))); |
|
return buckets * sizeof(Entry); |
|
} |
|
|
|
|
|
ProbingHashTable() : entries_(0) |
|
#ifdef DEBUG |
|
, initialized_(false) |
|
#endif |
|
{} |
|
|
|
ProbingHashTable(void *start, std::size_t allocated, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) |
|
: begin_(reinterpret_cast<MutableIterator>(start)), |
|
buckets_(allocated / sizeof(Entry)), |
|
end_(begin_ + buckets_), |
|
invalid_(invalid), |
|
hash_(hash_func), |
|
equal_(equal_func), |
|
entries_(0) |
|
#ifdef DEBUG |
|
, initialized_(true) |
|
#endif |
|
{} |
|
|
|
void Relocate(void *new_base) { |
|
begin_ = reinterpret_cast<MutableIterator>(new_base); |
|
end_ = begin_ + buckets_; |
|
} |
|
|
|
template <class T> MutableIterator Insert(const T &t) { |
|
#ifdef DEBUG |
|
assert(initialized_); |
|
#endif |
|
UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); |
|
return UncheckedInsert(t); |
|
} |
|
|
|
|
|
template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { |
|
#ifdef DEBUG |
|
assert(initialized_); |
|
#endif |
|
for (MutableIterator i = Ideal(t);;) { |
|
Key got(i->GetKey()); |
|
if (equal_(got, t.GetKey())) { out = i; return true; } |
|
if (equal_(got, invalid_)) { |
|
UTIL_THROW_IF(++entries_ >= buckets_, ProbingSizeException, "Hash table with " << buckets_ << " buckets is full."); |
|
*i = t; |
|
out = i; |
|
return false; |
|
} |
|
if (++i == end_) i = begin_; |
|
} |
|
} |
|
|
|
void FinishedInserting() {} |
|
|
|
|
|
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { |
|
#ifdef DEBUG |
|
assert(initialized_); |
|
#endif |
|
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { |
|
Key got(i->GetKey()); |
|
if (equal_(got, key)) { out = i; return true; } |
|
if (equal_(got, invalid_)) return false; |
|
if (++i == end_) i = begin_; |
|
} |
|
} |
|
|
|
|
|
template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { |
|
for (MutableIterator i(begin_ + (hash_(key) % buckets_));;) { |
|
Key got(i->GetKey()); |
|
if (equal_(got, key)) { return i; } |
|
assert(!equal_(got, invalid_)); |
|
if (++i == end_) i = begin_; |
|
} |
|
} |
|
|
|
|
|
template <class Key> bool Find(const Key key, ConstIterator &out) const { |
|
#ifdef DEBUG |
|
assert(initialized_); |
|
#endif |
|
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { |
|
Key got(i->GetKey()); |
|
if (equal_(got, key)) { out = i; return true; } |
|
if (equal_(got, invalid_)) return false; |
|
if (++i == end_) i = begin_; |
|
} |
|
} |
|
|
|
|
|
template <class Key> ConstIterator MustFind(const Key key) const { |
|
for (ConstIterator i(begin_ + (hash_(key) % buckets_));;) { |
|
Key got(i->GetKey()); |
|
if (equal_(got, key)) { return i; } |
|
assert(!equal_(got, invalid_)); |
|
if (++i == end_) i = begin_; |
|
} |
|
} |
|
|
|
void Clear() { |
|
Entry invalid; |
|
invalid.SetKey(invalid_); |
|
std::fill(begin_, end_, invalid); |
|
entries_ = 0; |
|
} |
|
|
|
|
|
std::size_t SizeNoSerialization() const { |
|
return entries_; |
|
} |
|
|
|
|
|
std::size_t DoubleTo() const { |
|
return buckets_ * 2 * sizeof(Entry); |
|
} |
|
|
|
|
|
|
|
|
|
void Double(void *new_base, bool clear_new = true) { |
|
begin_ = static_cast<MutableIterator>(new_base); |
|
MutableIterator old_end = begin_ + buckets_; |
|
buckets_ *= 2; |
|
end_ = begin_ + buckets_; |
|
if (clear_new) { |
|
Entry invalid; |
|
invalid.SetKey(invalid_); |
|
std::fill(old_end, end_, invalid); |
|
} |
|
std::vector<Entry> rolled_over; |
|
|
|
for (MutableIterator i = begin_; i != old_end && !equal_(i->GetKey(), invalid_); ++i) { |
|
rolled_over.push_back(*i); |
|
i->SetKey(invalid_); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
Entry temp; |
|
for (MutableIterator i = begin_; i != old_end; ++i) { |
|
if (!equal_(i->GetKey(), invalid_)) { |
|
temp = *i; |
|
i->SetKey(invalid_); |
|
UncheckedInsert(temp); |
|
} |
|
} |
|
|
|
for (typename std::vector<Entry>::const_iterator i(rolled_over.begin()); i != rolled_over.end(); ++i) { |
|
UncheckedInsert(*i); |
|
} |
|
} |
|
|
|
|
|
void CheckConsistency() { |
|
MutableIterator last; |
|
for (last = end_ - 1; last >= begin_ && !equal_(last->GetKey(), invalid_); --last) {} |
|
UTIL_THROW_IF(last == begin_, ProbingSizeException, "Completely full"); |
|
MutableIterator i; |
|
|
|
for (i = begin_; !equal_(i->GetKey(), invalid_); ++i) { |
|
MutableIterator ideal = Ideal(*i); |
|
UTIL_THROW_IF(ideal > i && ideal <= last, Exception, "Inconsistency at position " << (i - begin_) << " should be at " << (ideal - begin_)); |
|
} |
|
MutableIterator pre_gap = i; |
|
for (; i != end_; ++i) { |
|
if (equal_(i->GetKey(), invalid_)) { |
|
pre_gap = i; |
|
continue; |
|
} |
|
MutableIterator ideal = Ideal(*i); |
|
UTIL_THROW_IF(ideal > i || ideal <= pre_gap, Exception, "Inconsistency at position " << (i - begin_) << " with ideal " << (ideal - begin_)); |
|
} |
|
} |
|
|
|
private: |
|
friend class AutoProbing<Entry, Hash, Equal>; |
|
|
|
template <class T> MutableIterator Ideal(const T &t) { |
|
return begin_ + (hash_(t.GetKey()) % buckets_); |
|
} |
|
|
|
template <class T> MutableIterator UncheckedInsert(const T &t) { |
|
for (MutableIterator i(Ideal(t));;) { |
|
if (equal_(i->GetKey(), invalid_)) { *i = t; return i; } |
|
if (++i == end_) { i = begin_; } |
|
} |
|
} |
|
|
|
MutableIterator begin_; |
|
std::size_t buckets_; |
|
MutableIterator end_; |
|
Key invalid_; |
|
Hash hash_; |
|
Equal equal_; |
|
std::size_t entries_; |
|
#ifdef DEBUG |
|
bool initialized_; |
|
#endif |
|
}; |
|
|
|
|
|
template <class EntryT, class HashT, class EqualT = std::equal_to<typename EntryT::Key> > class AutoProbing { |
|
private: |
|
typedef ProbingHashTable<EntryT, HashT, EqualT> Backend; |
|
public: |
|
static std::size_t MemUsage(std::size_t size, float multiplier = 1.5) { |
|
return Backend::Size(size, multiplier); |
|
} |
|
|
|
typedef EntryT Entry; |
|
typedef typename Entry::Key Key; |
|
typedef const Entry *ConstIterator; |
|
typedef Entry *MutableIterator; |
|
typedef HashT Hash; |
|
typedef EqualT Equal; |
|
|
|
AutoProbing(std::size_t initial_size = 10, const Key &invalid = Key(), const Hash &hash_func = Hash(), const Equal &equal_func = Equal()) : |
|
allocated_(Backend::Size(initial_size, 1.5)), mem_(util::MallocOrThrow(allocated_)), backend_(mem_.get(), allocated_, invalid, hash_func, equal_func) { |
|
threshold_ = initial_size * 1.2; |
|
Clear(); |
|
} |
|
|
|
|
|
template <class T> MutableIterator Insert(const T &t) { |
|
DoubleIfNeeded(); |
|
return backend_.UncheckedInsert(t); |
|
} |
|
|
|
template <class T> bool FindOrInsert(const T &t, MutableIterator &out) { |
|
DoubleIfNeeded(); |
|
return backend_.FindOrInsert(t, out); |
|
} |
|
|
|
template <class Key> bool UnsafeMutableFind(const Key key, MutableIterator &out) { |
|
return backend_.UnsafeMutableFind(key, out); |
|
} |
|
|
|
template <class Key> MutableIterator UnsafeMutableMustFind(const Key key) { |
|
return backend_.UnsafeMutableMustFind(key); |
|
} |
|
|
|
template <class Key> bool Find(const Key key, ConstIterator &out) const { |
|
return backend_.Find(key, out); |
|
} |
|
|
|
template <class Key> ConstIterator MustFind(const Key key) const { |
|
return backend_.MustFind(key); |
|
} |
|
|
|
std::size_t Size() const { |
|
return backend_.SizeNoSerialization(); |
|
} |
|
|
|
void Clear() { |
|
backend_.Clear(); |
|
} |
|
|
|
private: |
|
void DoubleIfNeeded() { |
|
if (Size() < threshold_) |
|
return; |
|
mem_.call_realloc(backend_.DoubleTo()); |
|
allocated_ = backend_.DoubleTo(); |
|
backend_.Double(mem_.get()); |
|
threshold_ *= 2; |
|
} |
|
|
|
std::size_t allocated_; |
|
util::scoped_malloc mem_; |
|
Backend backend_; |
|
std::size_t threshold_; |
|
}; |
|
|
|
} |
|
|
|
#endif |
|
|