Aging_MouthReplace / dlibs /dlib /conditioning_class /conditioning_class_kernel_3.h
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
// Copyright (C) 2004 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_
#define DLIB_CONDITIONING_CLASS_KERNEl_3_
#include "conditioning_class_kernel_abstract.h"
#include "../assert.h"
#include "../algs.h"
namespace dlib
{
template <unsigned long alphabet_size>
class conditioning_class_kernel_3
{
    /*!
        INITIAL VALUE
            - total == 1
            - counts == pointer to an array of alphabet_size data structs
            - counts[0].count == 1 and counts[0].symbol == alphabet_size-1
            - counts[i].count == 0 for every i != 0
            - counts[alphabet_size-1].present == true
            - counts[i].present == false for every i != alphabet_size-1

        CONVENTION
            - counts == pointer to an array of alphabet_size data structs
            - get_total() == total
            - get_count(symbol) == counts[x].count where counts[x].symbol == symbol
            - LOW_COUNT(symbol) == sum of counts[0].count through counts[x-1].count
              where counts[x].symbol == symbol
              (so LOW_COUNT(symbol) == 0 when counts[0].symbol == symbol)
            - if (counts[i].count == 0) then counts[i].symbol is an undefined value
            - if (symbol has a nonzero count) then counts[symbol].present == true
            - get_memory_usage() == global_state.memory_usage

        NOTE
            The present flag is looked up by symbol value (counts[symbol].present),
            while count and symbol describe whatever entry currently sits in
            that slot of the array.
    !*/

public:

    // Byte counter shared by the conditioning classes constructed with it.
    class global_state_type
    {
    public:
        global_state_type () : memory_usage(0) {}
    private:
        unsigned long memory_usage;
        friend class conditioning_class_kernel_3<alphabet_size>;
    };

    // construction / destruction / reset
    conditioning_class_kernel_3 (global_state_type& global_state_);
    ~conditioning_class_kernel_3 ();
    void clear ();

    // adds amount to the count of symbol
    bool increment_count (unsigned long symbol, unsigned short amount = 1);

    // count queries
    unsigned long get_count (unsigned long symbol) const;
    unsigned long get_total () const;

    // cumulative range of symbol; the return value is the count of symbol
    unsigned long get_range (
        unsigned long symbol,
        unsigned long& low_count,
        unsigned long& high_count,
        unsigned long& total_count
    ) const;

    // symbol whose cumulative range contains target
    void get_symbol (
        unsigned long target,
        unsigned long& symbol,
        unsigned long& low_count,
        unsigned long& high_count
    ) const;

    // bookkeeping accessors
    unsigned long get_memory_usage () const;
    global_state_type& get_global_state ();
    static unsigned long get_alphabet_size ();

private:

    // restricted functions (declared but not meant to be used)
    conditioning_class_kernel_3(conditioning_class_kernel_3<alphabet_size>&);        // copy constructor
    conditioning_class_kernel_3& operator=(conditioning_class_kernel_3<alphabet_size>&);    // assignment operator

    // one slot of the counts array
    struct data
    {
        unsigned short count;    // occurrences recorded for the entry in this slot
        unsigned short symbol;   // symbol stored in this slot (meaningless if count == 0)
        bool present;            // indexed by symbol value, not by slot position
    };

    // data members
    unsigned short total;
    data* counts;
    global_state_type& global_state;
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// member function definitions
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
conditioning_class_kernel_3<alphabet_size>::
conditioning_class_kernel_3 (global_state_type& global_state_) :
    total(1),
    counts(new data[alphabet_size]),
    global_state(global_state_)
{
    // symbols are stored in 16 bit fields, so the alphabet has to fit in them
    COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );

    // slot 0 starts out holding the last symbol of the alphabet with a count of 1
    counts[0].count = 1;
    counts[0].symbol = static_cast<unsigned short>(alphabet_size-1);
    counts[0].present = false;

    // every remaining slot starts out empty
    for (unsigned long slot = 1; slot < alphabet_size; ++slot)
    {
        counts[slot].count = 0;
        counts[slot].present = false;
    }

    // the present flags are indexed by symbol, so flag the last symbol
    counts[alphabet_size-1].present = true;

    // account for this object and its array in the shared memory counter
    global_state.memory_usage += sizeof(conditioning_class_kernel_3) +
                                 sizeof(data)*alphabet_size;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
conditioning_class_kernel_3<alphabet_size>::
~conditioning_class_kernel_3 ()
{
    // give back the bytes that the constructor added to the shared counter
    global_state.memory_usage -= sizeof(conditioning_class_kernel_3) +
                                 sizeof(data)*alphabet_size;

    // release the slot array
    delete [] counts;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
void conditioning_class_kernel_3<alphabet_size>::
clear ()
{
    // back to the initial value: only the last symbol has a count, and it is 1
    total = 1;

    counts[0].count = 1;
    counts[0].symbol = static_cast<unsigned short>(alphabet_size-1);
    counts[0].present = false;

    // wipe every other slot
    for (unsigned long slot = 1; slot < alphabet_size; ++slot)
    {
        counts[slot].count = 0;
        counts[slot].present = false;
    }

    // the present flags are indexed by symbol, so flag the last symbol
    counts[alphabet_size-1].present = true;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
typename conditioning_class_kernel_3<alphabet_size>::global_state_type& conditioning_class_kernel_3<alphabet_size>::
get_global_state ()
{
    // hands out the state object that was supplied to the constructor
    return this->global_state;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
unsigned long conditioning_class_kernel_3<alphabet_size>::
get_memory_usage () const
{
    // the byte count lives in the shared global state, not in this object
    const global_state_type& state = global_state;
    return state.memory_usage;
}
// ----------------------------------------------------------------------------------------
// Adds amount to the count stored for symbol and to the running total.
//
// Steps, in order:
//   1. If total + amount would reach 65536, every stored count is halved first
//      (the loop below explains how counts of exactly 1 are treated).
//   2. The slot holding symbol is located (or, when symbol is not currently
//      present, the first zero-count slot is claimed) and its contents are
//      exchanged with swap_spot.
//   3. total is increased by amount.
//
// symbol is used directly as an index (counts[symbol].present) with no range
// check here, so callers must pass symbol < alphabet_size.
// The function always returns true.
//
// NOTE(review): only a single halving pass is performed and both total and the
// per-slot counts are 16 bit values; presumably the abstract interface bounds
// amount so that these sums cannot wrap -- confirm against
// conditioning_class_kernel_abstract.h.
template <
unsigned long alphabet_size
>
bool conditioning_class_kernel_3<alphabet_size>::
increment_count (
unsigned long symbol,
unsigned short amount
)
{
// if we are going over a total of 65535 then scale down all counts by 2
if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536)
{
// total is rebuilt from scratch while the slots are rescaled
total = 0;
data* start = counts;
data* end = counts+alphabet_size;
while (start != end)
{
if (start->count == 1)
{
if (start->symbol == alphabet_size-1)
{
// this symbol must never be zero so we will leave its count at 1
++total;
}
else
{
// any other symbol with a count of 1 drops to zero, so it is
// also flagged as no longer present
start->count = 0;
counts[start->symbol].present = false;
}
}
else
{
// halve the count (rounding down); slots that were already at
// zero stay at zero and add nothing to total
start->count >>= 1;
total += start->count;
}
++start;
}
}
// start walks the slots from the front.  swap_spot trails behind it and
// points at the earliest slot with the smallest count among the slots
// visited before start (both point at slot 0 at the beginning).
data* start = counts;
data* swap_spot = counts;
if (counts[symbol].present)
{
// the symbol already has a slot: find it.
// There is no bounds check; the loop relies on the present flag meaning
// that a matching slot with a nonzero count exists.
while (true)
{
// the count test skips zero-count slots, whose symbol field may hold
// a leftover value
if (start->symbol == symbol && start->count!=0)
{
// the new count is computed before the exchange so the result is
// still right when start and swap_spot are the same slot
unsigned short temp = start->count + amount;
// exchange the two slots: the entry from swap_spot moves back to
// start, and symbol moves forward with its increased count
start->symbol = swap_spot->symbol;
start->count = swap_spot->count;
swap_spot->symbol = static_cast<unsigned short>(symbol);
swap_spot->count = temp;
break;
}
// strict less-than keeps the earliest slot among equal counts
if ( (start->count) < (swap_spot->count))
{
swap_spot = start;
}
++start;
}
}
else
{
// the symbol has no slot yet: flag it and claim the first zero-count slot.
// NOTE(review): there is no bounds check; presumably a zero-count slot
// always exists while symbol is not present -- confirm.
counts[symbol].present = true;
while (true)
{
if (start->count == 0)
{
// the entry from swap_spot moves into the free slot and symbol
// takes over swap_spot with a count of amount.  When start and
// swap_spot are the same slot it is simply overwritten.
start->symbol = swap_spot->symbol;
start->count = swap_spot->count;
swap_spot->symbol = static_cast<unsigned short>(symbol);
swap_spot->count = amount;
break;
}
// strict less-than keeps the earliest slot among equal counts
if ((start->count) < (swap_spot->count))
{
swap_spot = start;
}
++start;
}
}
total += amount;
return true;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
unsigned long conditioning_class_kernel_3<alphabet_size>::
get_count (unsigned long symbol) const
{
    // a symbol that is not flagged as present has a count of zero
    if (!counts[symbol].present)
        return 0;

    // otherwise walk the slots from the front until the symbol shows up
    unsigned long slot = 0;
    while (counts[slot].symbol != symbol)
        ++slot;

    return counts[slot].count;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
unsigned long conditioning_class_kernel_3<alphabet_size>::
get_alphabet_size ()
{
    // the alphabet size is simply the template argument
    const unsigned long size = alphabet_size;
    return size;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
unsigned long conditioning_class_kernel_3<alphabet_size>::
get_total () const
{
    // total is kept as a 16 bit value; widen it for the caller
    return static_cast<unsigned long>(total);
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
unsigned long conditioning_class_kernel_3<alphabet_size>::
get_range (
    unsigned long symbol,
    unsigned long& low_count,
    unsigned long& high_count,
    unsigned long& total_count
) const
{
    // nothing is written to the output arguments for a symbol that is not present
    if (!counts[symbol].present)
        return 0;

    total_count = total;

    // add up the counts of every slot that sits in front of the symbol's slot
    unsigned long below = 0;
    const data* entry = counts;
    for (; entry->symbol != symbol; ++entry)
        below += entry->count;

    // the symbol's range starts where the preceding slots end
    low_count = below;
    high_count = below + entry->count;
    return entry->count;
}
// ----------------------------------------------------------------------------------------
template <unsigned long alphabet_size>
void conditioning_class_kernel_3<alphabet_size>::
get_symbol (
    unsigned long target,
    unsigned long& symbol,
    unsigned long& low_count,
    unsigned long& high_count
) const
{
    // upper is the cumulative count of slots 0 through slot (inclusive)
    unsigned long slot = 0;
    unsigned long upper = counts[0].count;

    // advance until the cumulative count moves past target
    while (upper <= target)
    {
        ++slot;
        upper += counts[slot].count;
    }

    // report the range covered by the slot we stopped on, and its symbol
    low_count = upper - counts[slot].count;
    high_count = upper;
    symbol = static_cast<unsigned long>(counts[slot].symbol);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_CONDITIONING_CLASS_KERNEl_3_