| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #ifndef FLANN_KDTREE_INDEX_H_ |
| #define FLANN_KDTREE_INDEX_H_ |
|
|
| #include <algorithm> |
| #include <map> |
| #include <cassert> |
| #include <cstring> |
| #include <stdarg.h> |
| #include <cmath> |
| #include <random> |
|
|
| #include "FLANN/general.h" |
| #include "FLANN/algorithms/nn_index.h" |
| #include "FLANN/util/dynamic_bitset.h" |
| #include "FLANN/util/matrix.h" |
| #include "FLANN/util/result_set.h" |
| #include "FLANN/util/heap.h" |
| #include "FLANN/util/allocator.h" |
| #include "FLANN/util/random.h" |
| #include "FLANN/util/saving.h" |
|
|
|
|
| namespace flann |
| { |
|
|
| struct KDTreeIndexParams : public IndexParams |
| { |
| KDTreeIndexParams(int trees = 4) |
| { |
| (*this)["algorithm"] = FLANN_INDEX_KDTREE; |
| (*this)["trees"] = trees; |
| } |
| }; |
|
|
|
|
| |
| |
| |
| |
| |
| |
| template <typename Distance> |
| class KDTreeIndex : public NNIndex<Distance> |
| { |
| public: |
| typedef typename Distance::ElementType ElementType; |
| typedef typename Distance::ResultType DistanceType; |
|
|
| typedef NNIndex<Distance> BaseClass; |
|
|
| typedef bool needs_kdtree_distance; |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| KDTreeIndex(const IndexParams& params = KDTreeIndexParams(), Distance d = Distance() ) : |
| BaseClass(params, d), mean_(NULL), var_(NULL) |
| { |
| trees_ = get_param(index_params_,"trees",4); |
| } |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| KDTreeIndex(const Matrix<ElementType>& dataset, const IndexParams& params = KDTreeIndexParams(), |
| Distance d = Distance() ) : BaseClass(params,d ), mean_(NULL), var_(NULL) |
| { |
| trees_ = get_param(index_params_,"trees",4); |
|
|
| setDataset(dataset); |
| } |
|
|
| KDTreeIndex(const KDTreeIndex& other) : BaseClass(other), |
| trees_(other.trees_) |
| { |
| tree_roots_.resize(other.tree_roots_.size()); |
| for (size_t i=0;i<tree_roots_.size();++i) { |
| copyTree(tree_roots_[i], other.tree_roots_[i]); |
| } |
| } |
|
|
| KDTreeIndex& operator=(KDTreeIndex other) |
| { |
| this->swap(other); |
| return *this; |
| } |
|
|
| |
| |
| |
| virtual ~KDTreeIndex() |
| { |
| freeIndex(); |
| } |
|
|
| BaseClass* clone() const |
| { |
| return new KDTreeIndex(*this); |
| } |
|
|
| using BaseClass::buildIndex; |
|
|
| void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2) |
| { |
| assert(points.cols==veclen_); |
|
|
| size_t old_size = size_; |
| extendDataset(points); |
|
|
| if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) { |
| buildIndex(); |
| } |
| else { |
| for (size_t i=old_size;i<size_;++i) { |
| for (int j = 0; j < trees_; j++) { |
| addPointToTree(tree_roots_[j], i); |
| } |
| } |
| } |
| } |
|
|
| flann_algorithm_t getType() const |
| { |
| return FLANN_INDEX_KDTREE; |
| } |
|
|
|
|
| template<typename Archive> |
| void serialize(Archive& ar) |
| { |
| ar.setObject(this); |
|
|
| ar & *static_cast<NNIndex<Distance>*>(this); |
|
|
| ar & trees_; |
|
|
| if (Archive::is_loading::value) { |
| tree_roots_.resize(trees_); |
| } |
| for (size_t i=0;i<tree_roots_.size();++i) { |
| if (Archive::is_loading::value) { |
| tree_roots_[i] = new(pool_) Node(); |
| } |
| ar & *tree_roots_[i]; |
| } |
|
|
| if (Archive::is_loading::value) { |
| index_params_["algorithm"] = getType(); |
| index_params_["trees"] = trees_; |
| } |
| } |
|
|
|
|
| void saveIndex(FILE* stream) |
| { |
| serialization::SaveArchive sa(stream); |
| sa & *this; |
| } |
|
|
|
|
| void loadIndex(FILE* stream) |
| { |
| freeIndex(); |
| serialization::LoadArchive la(stream); |
| la & *this; |
| } |
|
|
| |
| |
| |
| |
| int usedMemory() const |
| { |
| return int(pool_.usedMemory+pool_.wastedMemory+size_*sizeof(int)); |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const |
| { |
| int maxChecks = searchParams.checks; |
| float epsError = 1+searchParams.eps; |
|
|
| if (maxChecks==FLANN_CHECKS_UNLIMITED) { |
| if (removed_) { |
| getExactNeighbors<true>(result, vec, epsError); |
| } |
| else { |
| getExactNeighbors<false>(result, vec, epsError); |
| } |
| } |
| else { |
| if (removed_) { |
| getNeighbors<true>(result, vec, maxChecks, epsError); |
| } |
| else { |
| getNeighbors<false>(result, vec, maxChecks, epsError); |
| } |
| } |
| } |
|
|
| protected: |
|
|
| |
| |
| |
| void buildIndexImpl() |
| { |
| |
| std::vector<int> ind(size_); |
| for (size_t i = 0; i < size_; ++i) { |
| ind[i] = int(i); |
| } |
|
|
| mean_ = new DistanceType[veclen_]; |
| var_ = new DistanceType[veclen_]; |
|
|
| std::default_random_engine generator; |
|
|
| tree_roots_.resize(trees_); |
| |
| for (int i = 0; i < trees_; i++) { |
| |
| std::shuffle(ind.begin(), ind.end(), generator); |
| tree_roots_[i] = divideTree(&ind[0], int(size_) ); |
| } |
| delete[] mean_; |
| delete[] var_; |
| } |
|
|
| void freeIndex() |
| { |
| for (size_t i=0;i<tree_roots_.size();++i) { |
| |
| if (tree_roots_[i]!=NULL) tree_roots_[i]->~Node(); |
| } |
| pool_.free(); |
| } |
|
|
|
|
| private: |
|
|
| |
| struct Node |
| { |
| |
| |
| |
| int divfeat; |
| |
| |
| |
| DistanceType divval; |
| |
| |
| |
| ElementType* point; |
| |
| |
| |
| Node* child1, *child2; |
| Node(){ |
| child1 = NULL; |
| child2 = NULL; |
| } |
| ~Node() { |
| if (child1 != NULL) { child1->~Node(); child1 = NULL; } |
|
|
| if (child2 != NULL) { child2->~Node(); child2 = NULL; } |
| } |
|
|
| private: |
| template<typename Archive> |
| void serialize(Archive& ar) |
| { |
| typedef KDTreeIndex<Distance> Index; |
| Index* obj = static_cast<Index*>(ar.getObject()); |
|
|
| ar & divfeat; |
| ar & divval; |
|
|
| bool leaf_node = false; |
| if (Archive::is_saving::value) { |
| leaf_node = ((child1==NULL) && (child2==NULL)); |
| } |
| ar & leaf_node; |
|
|
| if (leaf_node) { |
| if (Archive::is_loading::value) { |
| point = obj->points_[divfeat]; |
| } |
| } |
|
|
| if (!leaf_node) { |
| if (Archive::is_loading::value) { |
| child1 = new(obj->pool_) Node(); |
| child2 = new(obj->pool_) Node(); |
| } |
| ar & *child1; |
| ar & *child2; |
| } |
| } |
| friend struct serialization::access; |
| }; |
| typedef Node* NodePtr; |
| typedef BranchStruct<NodePtr, DistanceType> BranchSt; |
| typedef BranchSt* Branch; |
|
|
|
|
| void copyTree(NodePtr& dst, const NodePtr& src) |
| { |
| dst = new(pool_) Node(); |
| dst->divfeat = src->divfeat; |
| dst->divval = src->divval; |
| if (src->child1==NULL && src->child2==NULL) { |
| dst->point = points_[dst->divfeat]; |
| dst->child1 = NULL; |
| dst->child2 = NULL; |
| } |
| else { |
| copyTree(dst->child1, src->child1); |
| copyTree(dst->child2, src->child2); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| NodePtr divideTree(int* ind, int count) |
| { |
| NodePtr node = new(pool_) Node(); |
|
|
| |
| if (count == 1) { |
| node->child1 = node->child2 = NULL; |
| node->divfeat = *ind; |
| node->point = points_[*ind]; |
| } |
| else { |
| int idx; |
| int cutfeat; |
| DistanceType cutval; |
| meanSplit(ind, count, idx, cutfeat, cutval); |
|
|
| node->divfeat = cutfeat; |
| node->divval = cutval; |
| node->child1 = divideTree(ind, idx); |
| node->child2 = divideTree(ind+idx, count-idx); |
| } |
|
|
| return node; |
| } |
|
|
|
|
| |
| |
| |
| |
| |
| void meanSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval) |
| { |
| memset(mean_,0,veclen_*sizeof(DistanceType)); |
| memset(var_,0,veclen_*sizeof(DistanceType)); |
|
|
| |
| |
| |
| int cnt = std::min((int)SAMPLE_MEAN+1, count); |
| for (int j = 0; j < cnt; ++j) { |
| ElementType* v = points_[ind[j]]; |
| for (size_t k=0; k<veclen_; ++k) { |
| mean_[k] += v[k]; |
| } |
| } |
| DistanceType div_factor = DistanceType(1)/cnt; |
| for (size_t k=0; k<veclen_; ++k) { |
| mean_[k] *= div_factor; |
| } |
|
|
| |
| for (int j = 0; j < cnt; ++j) { |
| ElementType* v = points_[ind[j]]; |
| for (size_t k=0; k<veclen_; ++k) { |
| DistanceType dist = v[k] - mean_[k]; |
| var_[k] += dist * dist; |
| } |
| } |
| |
| cutfeat = selectDivision(var_); |
| cutval = mean_[cutfeat]; |
|
|
| int lim1, lim2; |
| planeSplit(ind, count, cutfeat, cutval, lim1, lim2); |
|
|
| if (lim1>count/2) index = lim1; |
| else if (lim2<count/2) index = lim2; |
| else index = count/2; |
|
|
| |
| |
| |
| if ((lim1==count)||(lim2==0)) index = count/2; |
| } |
|
|
|
|
| |
| |
| |
| |
| int selectDivision(DistanceType* v) |
| { |
| int num = 0; |
| size_t topind[RAND_DIM]; |
|
|
| |
| for (size_t i = 0; i < veclen_; ++i) { |
| if ((num < RAND_DIM)||(v[i] > v[topind[num-1]])) { |
| |
| if (num < RAND_DIM) { |
| topind[num++] = i; |
| } |
| else { |
| topind[num-1] = i; |
| } |
| |
| int j = num - 1; |
| while (j > 0 && v[topind[j]] > v[topind[j-1]]) { |
| std::swap(topind[j], topind[j-1]); |
| --j; |
| } |
| } |
| } |
| |
| int rnd = rand_int(num); |
| return (int)topind[rnd]; |
| } |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2) |
| { |
| |
| int left = 0; |
| int right = count-1; |
| for (;; ) { |
| while (left<=right && points_[ind[left]][cutfeat]<cutval) ++left; |
| while (left<=right && points_[ind[right]][cutfeat]>=cutval) --right; |
| if (left>right) break; |
| std::swap(ind[left], ind[right]); ++left; --right; |
| } |
| lim1 = left; |
| right = count-1; |
| for (;; ) { |
| while (left<=right && points_[ind[left]][cutfeat]<=cutval) ++left; |
| while (left<=right && points_[ind[right]][cutfeat]>cutval) --right; |
| if (left>right) break; |
| std::swap(ind[left], ind[right]); ++left; --right; |
| } |
| lim2 = left; |
| } |
|
|
| |
| |
| |
| |
| template<bool with_removed> |
| void getExactNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, float epsError) const |
| { |
| |
|
|
| if (trees_ > 1) { |
| fprintf(stderr,"It doesn't make any sense to use more than one tree for exact search"); |
| } |
| if (trees_>0) { |
| searchLevelExact<with_removed>(result, vec, tree_roots_[0], 0.0, epsError); |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| template<bool with_removed> |
| void getNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, int maxCheck, float epsError) const |
| { |
| int i; |
| BranchSt branch; |
|
|
| int checkCount = 0; |
| Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_); |
| DynamicBitset checked(size_); |
|
|
| |
| for (i = 0; i < trees_; ++i) { |
| searchLevel<with_removed>(result, vec, tree_roots_[i], 0, checkCount, maxCheck, epsError, heap, checked); |
| } |
|
|
| |
| while ( heap->popMin(branch) && (checkCount < maxCheck || !result.full() )) { |
| searchLevel<with_removed>(result, vec, branch.node, branch.mindist, checkCount, maxCheck, epsError, heap, checked); |
| } |
|
|
| delete heap; |
|
|
| } |
|
|
| |
| |
| |
| |
| |
| template<bool with_removed> |
| void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, NodePtr node, DistanceType mindist, int& checkCount, int maxCheck, |
| float epsError, Heap<BranchSt>* heap, DynamicBitset& checked) const |
| { |
| if (result_set.worstDist()<mindist) { |
| |
| return; |
| } |
|
|
| |
| if ((node->child1 == NULL)&&(node->child2 == NULL)) { |
| int index = node->divfeat; |
| if (with_removed) { |
| if (removed_points_.test(index)) return; |
| } |
| |
| if ( checked.test(index) || ((checkCount>=maxCheck)&& result_set.full()) ) return; |
| checked.set(index); |
| checkCount++; |
|
|
| DistanceType dist = distance_(node->point, vec, veclen_); |
| result_set.addPoint(dist,index); |
| return; |
| } |
|
|
| |
| ElementType val = vec[node->divfeat]; |
| DistanceType diff = val - node->divval; |
| NodePtr bestChild = (diff < 0) ? node->child1 : node->child2; |
| NodePtr otherChild = (diff < 0) ? node->child2 : node->child1; |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat); |
| |
| if ((new_distsq*epsError < result_set.worstDist())|| !result_set.full()) { |
| heap->insert( BranchSt(otherChild, new_distsq) ); |
| } |
|
|
| |
| searchLevel<with_removed>(result_set, vec, bestChild, mindist, checkCount, maxCheck, epsError, heap, checked); |
| } |
|
|
| |
| |
| |
| template<bool with_removed> |
| void searchLevelExact(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindist, const float epsError) const |
| { |
| |
| if ((node->child1 == NULL)&&(node->child2 == NULL)) { |
| int index = node->divfeat; |
| if (with_removed) { |
| if (removed_points_.test(index)) return; |
| } |
| DistanceType dist = distance_(node->point, vec, veclen_); |
| result_set.addPoint(dist,index); |
|
|
| return; |
| } |
|
|
| |
| ElementType val = vec[node->divfeat]; |
| DistanceType diff = val - node->divval; |
| NodePtr bestChild = (diff < 0) ? node->child1 : node->child2; |
| NodePtr otherChild = (diff < 0) ? node->child2 : node->child1; |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| DistanceType new_distsq = mindist + distance_.accum_dist(val, node->divval, node->divfeat); |
|
|
| |
| searchLevelExact<with_removed>(result_set, vec, bestChild, mindist, epsError); |
|
|
| if (mindist*epsError<=result_set.worstDist()) { |
| searchLevelExact<with_removed>(result_set, vec, otherChild, new_distsq, epsError); |
| } |
| } |
|
|
| void addPointToTree(NodePtr node, int ind) |
| { |
| ElementType* point = points_[ind]; |
|
|
| if ((node->child1==NULL) && (node->child2==NULL)) { |
| ElementType* leaf_point = node->point; |
| ElementType max_span = 0; |
| size_t div_feat = 0; |
| for (size_t i=0;i<veclen_;++i) { |
| ElementType span = std::abs(point[i]-leaf_point[i]); |
| if (span > max_span) { |
| max_span = span; |
| div_feat = i; |
| } |
| } |
| NodePtr left = new(pool_) Node(); |
| left->child1 = left->child2 = NULL; |
| NodePtr right = new(pool_) Node(); |
| right->child1 = right->child2 = NULL; |
|
|
| if (point[div_feat]<leaf_point[div_feat]) { |
| left->divfeat = ind; |
| left->point = point; |
| right->divfeat = node->divfeat; |
| right->point = node->point; |
| } |
| else { |
| left->divfeat = node->divfeat; |
| left->point = node->point; |
| right->divfeat = ind; |
| right->point = point; |
| } |
| node->divfeat = div_feat; |
| node->divval = (point[div_feat]+leaf_point[div_feat])/2; |
| node->child1 = left; |
| node->child2 = right; |
| } |
| else { |
| if (point[node->divfeat]<node->divval) { |
| addPointToTree(node->child1,ind); |
| } |
| else { |
| addPointToTree(node->child2,ind); |
| } |
| } |
| } |
| private: |
| void swap(KDTreeIndex& other) |
| { |
| BaseClass::swap(other); |
| std::swap(trees_, other.trees_); |
| std::swap(tree_roots_, other.tree_roots_); |
| std::swap(pool_, other.pool_); |
| } |
|
|
| private: |
|
|
enum
{
    /**
     * Number of top-valued dimensions among which selectDivision() picks
     * the split dimension.
     */
    RAND_DIM = 5,
    /**
     * Bound on the sample used by meanSplit() to estimate mean and
     * variance: at most SAMPLE_MEAN+1 points are examined per split.
     */
    SAMPLE_MEAN = 100
};
|
|
|
|
| |
| |
| |
| int trees_; |
|
|
| DistanceType* mean_; |
| DistanceType* var_; |
|
|
| |
| |
| |
| std::vector<NodePtr> tree_roots_; |
|
|
| |
| |
| |
| |
| |
| |
| |
| PooledAllocator pool_; |
|
|
| USING_BASECLASS_SYMBOLS |
| }; |
|
|
| } |
|
|
| #endif |
|
|