Aging_MouthReplace / dlibs /dlib /conditioning_class /conditioning_class_kernel_2.h
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
// Copyright (C) 2003 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_
#define DLIB_CONDITIONING_CLASS_KERNEl_2_
#include "conditioning_class_kernel_abstract.h"
#include "../assert.h"
#include "../algs.h"
namespace dlib
{
template <
    unsigned long alphabet_size
    >
class conditioning_class_kernel_2
{
    /*!
        INITIAL VALUE
            total == 1
            symbols == pointer to array of alphabet_size data structs
            for all i except i == alphabet_size-1: symbols[i].count == 0
            symbols[alphabet_size-1].count == 1
            left_count fields satisfy the CONVENTION below.  Concretely,
            symbols[i].left_count == 1 for every node i that has
            symbols[alphabet_size-1] somewhere in its left subtree, and
            symbols[i].left_count == 0 for every other node (including
            symbols[alphabet_size-1] itself, which is always a leaf since its
            would-be children 2*(alphabet_size-1)+1 and +2 are out of range)

        CONVENTION
            symbols == pointer to array of alphabet_size data structs
            get_total() == total
            get_count(symbol) == symbols[symbol].count

            symbols is organized as an implicit binary tree with symbols[0]
            as the root.  The left subchild of symbols[i] is symbols[i*2+1]
            and the right subchild is symbols[i*2+2].  The parent of
            symbols[i] (for i != 0) is symbols[(i-1)/2].  Consequently every
            left subchild has an odd index and every right subchild an even
            index.

            symbols[i].left_count == the sum of the counts of all the
            symbols in the left subtree of symbols[i]

            The cumulative ranges reported by get_range() and searched by
            get_symbol() follow the in-order traversal of this tree (left
            subtree, node, right subtree), NOT increasing symbol index.

            The sum of (count + left_count) over the right spine of the tree
            (symbols[0], symbols[2], symbols[6], ...) equals total.

            symbols[alphabet_size-1].count >= 1 at all times (renormalization
            in increment_count() never halves it below 1).

            get_memory_usage() == global_state.memory_usage
    !*/

public:

    class global_state_type
    {
    public:
        global_state_type () : memory_usage(0) {}
    private:
        // bytes charged by every object constructed with this state
        unsigned long memory_usage;

        friend class conditioning_class_kernel_2<alphabet_size>;
    };

    conditioning_class_kernel_2 (
        global_state_type& global_state_
    );

    ~conditioning_class_kernel_2 (
    );

    void clear(
    );

    bool increment_count (
        unsigned long symbol,
        unsigned short amount = 1
    );

    unsigned long get_count (
        unsigned long symbol
    ) const;

    inline unsigned long get_total (
    ) const;

    unsigned long get_range (
        unsigned long symbol,
        unsigned long& low_count,
        unsigned long& high_count,
        unsigned long& total_count
    ) const;

    void get_symbol (
        unsigned long target,
        unsigned long& symbol,
        unsigned long& low_count,
        unsigned long& high_count
    ) const;

    unsigned long get_memory_usage (
    ) const;

    global_state_type& get_global_state (
    );

    static unsigned long get_alphabet_size (
    );

private:

    // restricted functions
    conditioning_class_kernel_2(conditioning_class_kernel_2<alphabet_size>&);        // copy constructor
    conditioning_class_kernel_2& operator=(conditioning_class_kernel_2<alphabet_size>&);    // assignment operator

    // data members

    // sum of the counts of all symbols
    unsigned short total;

    // one node of the implicit binary tree described in CONVENTION
    struct data
    {
        unsigned short count;       // count of this node's own symbol
        unsigned short left_count;  // sum of counts in this node's left subtree
    };

    data* symbols;
    global_state_type& global_state;
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
conditioning_class_kernel_2<alphabet_size>::
conditioning_class_kernel_2 (
    global_state_type& global_state_
) :
    total(1),
    symbols(new data[alphabet_size]),
    global_state(global_state_)
{
    /*
        Allocates the symbol tree and puts it into the INITIAL VALUE state,
        then charges this object's footprint to the shared global_state.
    */
    COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );

    // clear() performs exactly the INITIAL VALUE initialization: every count
    // zero except symbols[alphabet_size-1].count == 1, with the matching
    // left_count updates along that node's path to the root.  Reuse it
    // instead of keeping a second copy of the same tree-walking code here.
    clear();

    // account for the symbol array plus this object itself; the destructor
    // subtracts the same amount
    global_state.memory_usage += sizeof(data)*alphabet_size +
                                 sizeof(conditioning_class_kernel_2);
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
conditioning_class_kernel_2<alphabet_size>::
~conditioning_class_kernel_2 (
)
{
    // Return exactly the number of bytes the constructor charged to the
    // shared counter, then release the symbol tree.
    const unsigned long bytes_owned = sizeof(data)*alphabet_size +
                                      sizeof(conditioning_class_kernel_2);
    global_state.memory_usage -= bytes_owned;

    delete [] symbols;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
void conditioning_class_kernel_2<alphabet_size>::
clear(
)
{
    total = 1;

    // every node except the final one starts out empty
    for (unsigned long idx = 0; idx + 1 < alphabet_size; ++idx)
    {
        symbols[idx].count = 0;
        symbols[idx].left_count = 0;
    }

    // the final symbol always starts with a count of one
    data& last = symbols[alphabet_size-1];
    last.count = 1;
    last.left_count = 0;

    // Walk from the final node up to the root.  Whenever the node being left
    // is a left child (odd index), that single unit of count now lives in
    // the parent's left subtree, so the parent's left_count grows by one.
    for (unsigned long node = alphabet_size-1; node != 0; )
    {
        const bool is_left_child = (node & 0x1) != 0;
        node = (node - 1) >> 1;
        if (is_left_child)
            ++symbols[node].left_count;
    }
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
unsigned long conditioning_class_kernel_2<alphabet_size>::
get_memory_usage(
) const
{
    // The counter lives in the shared global_state, so this is the total
    // charged by every object constructed with that same state, not just
    // by this one.
    const global_state_type& shared = global_state;
    return shared.memory_usage;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
typename conditioning_class_kernel_2<alphabet_size>::global_state_type& conditioning_class_kernel_2<alphabet_size>::
get_global_state(
)
{
    // hand back the very same state object that was given to the constructor
    global_state_type& shared = global_state;
    return shared;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
bool conditioning_class_kernel_2<alphabet_size>::
increment_count (
    unsigned long symbol,
    unsigned short amount
)
{
    /*
        Adds amount to the count of symbol.  If that would push total to
        65536 or beyond, every count is first halved (the final symbol's count
        is only halved when it exceeds 1) and all left_count fields and total
        are rebuilt.
        NOTE(review): this relies on callers keeping amount small enough that
        the halved total plus amount still fits in an unsigned short.
    */
    const unsigned long limit = 65536;
    if (static_cast<unsigned long>(total) + static_cast<unsigned long>(amount) >= limit)
    {
        // Visit nodes in increasing index order.  A parent always has a
        // smaller index than its children, so each node's left_count is
        // reset before any of its descendants add their halved counts to it.
        const unsigned long last = alphabet_size-1;
        for (unsigned long node = 0; node <= last; ++node)
        {
            data& d = symbols[node];
            if (node != last)
            {
                d.count >>= 1;
                d.left_count = 0;
            }
            else if (d.count > 1)
            {
                // the final symbol's count never drops below 1
                d.count >>= 1;
            }

            // push this node's count into the left_count of every ancestor
            // whose left subtree contains it (entered via an odd index)
            unsigned long child = node;
            while (child != 0)
            {
                const unsigned long parent = (child-1)>>1;
                if (child & 0x1)
                    symbols[parent].left_count += d.count;
                child = parent;
            }
        }

        // total == sum of (count + left_count) along the right spine
        total = 0;
        for (unsigned long node = 0; node < alphabet_size; node = 2*node + 2)
            total += symbols[node].count + symbols[node].left_count;
    }

    // apply the increment itself
    symbols[symbol].count += amount;
    total += amount;

    // every ancestor reached from a left child gains amount in its left_count
    for (unsigned long node = symbol; node != 0; )
    {
        const bool from_left = (node & 0x1) != 0;
        node = (node-1)>>1;
        if (from_left)
            symbols[node].left_count += amount;
    }

    return true;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
unsigned long conditioning_class_kernel_2<alphabet_size>::
get_count (
    unsigned long symbol
) const
{
    // a symbol's own count is stored directly in its tree node
    const data& entry = symbols[symbol];
    return entry.count;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
unsigned long conditioning_class_kernel_2<alphabet_size>::
get_alphabet_size (
)
{
    // the alphabet size is fixed at compile time by the template parameter
    const unsigned long size = alphabet_size;
    return size;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
unsigned long conditioning_class_kernel_2<alphabet_size>::
get_total (
) const
{
    // total is stored as an unsigned short; widen it for the caller
    return static_cast<unsigned long>(total);
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
unsigned long conditioning_class_kernel_2<alphabet_size>::
get_range (
    unsigned long symbol,
    unsigned long& low_count,
    unsigned long& high_count,
    unsigned long& total_count
) const
{
    /*
        Returns the count of symbol.  When that count is non-zero, also
        reports the symbol's cumulative range [low_count, high_count) in the
        tree's in-order ordering, along with the current total.  When it is
        zero, the output parameters are left untouched.
    */
    const unsigned long own = symbols[symbol].count;
    if (own == 0)
        return 0;

    total_count = total;

    // cumulative count of every symbol up to and including this one: start
    // with the node's own subtree-left part, then add each ancestor (and its
    // left subtree) that we reach by climbing out of a right child
    unsigned long cumulative = symbols[symbol].left_count + own;
    for (unsigned long node = symbol; node != 0; )
    {
        const bool is_right_child = (node & 0x1) == 0;
        node = (node-1)>>1;
        if (is_right_child)
            cumulative += symbols[node].count + symbols[node].left_count;
    }

    low_count = cumulative - own;
    high_count = cumulative;
    return own;
}
// ----------------------------------------------------------------------------------------
template <
    unsigned long alphabet_size
    >
void conditioning_class_kernel_2<alphabet_size>::
get_symbol (
    unsigned long target,
    unsigned long& symbol,
    unsigned long& low_count,
    unsigned long& high_count
) const
{
    /*
        Descends from the root to find the symbol whose cumulative range
        [low_count, high_count) contains target, using the in-order ordering
        of the tree.  Comparisons cast target to unsigned short, matching the
        node fields; target is presumably required to be < get_total().
    */
    unsigned long node = 0;
    unsigned long preceding = 0;   // sum of counts ordered before the current position
    for (;;)
    {
        const data& d = symbols[node];

        // target lies inside the left subtree: descend left
        if (static_cast<unsigned short>(target) < d.left_count)
        {
            node = 2*node + 1;
            continue;
        }

        // skip past the left subtree
        target -= d.left_count;
        preceding += d.left_count;

        // target lies within this node's own count: found it
        if (static_cast<unsigned short>(target) < d.count)
            break;

        // otherwise skip this node too and descend right
        target -= d.count;
        preceding += d.count;
        node = 2*node + 2;
    }

    symbol = node;
    high_count = preceding + symbols[node].count;
    low_count = preceding;
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CONDITIONING_CLASS_KERNEl_2_