AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LSPI_Hh_
#define DLIB_LSPI_Hh_
#include "lspi_abstract.h"
#include "approximate_linear_models.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename feature_extractor
    >
class lspi
{
    /*!
        WHAT THIS OBJECT IS
            A trainer that turns a collection of samples (each exposing .state,
            .action, .reward and .next_state) into a policy<feature_extractor>.
            train() repeatedly accumulates a linear system A*w = b from the samples
            and re-solves it for the weight vector w until w stops changing by more
            than get_epsilon() or get_max_iterations() passes have been made.

            The formal interface specification lives in lspi_abstract.h.
    !*/

public:
    typedef feature_extractor feature_extractor_type;
    typedef typename feature_extractor::state_type state_type;
    typedef typename feature_extractor::action_type action_type;

    // Creates a trainer that uses a copy of fe_ and the default settings
    // established by init().
    explicit lspi(
        const feature_extractor& fe_
    ) : fe(fe_)
    {
        init();
    }

    // Creates a trainer that uses a default constructed feature_extractor and the
    // default settings established by init().
    lspi(
    )
    {
        init();
    }

    // Returns the discount applied to the next-state features inside train().
    double get_discount (
    ) const { return discount; }

    // Sets the discount applied to the next-state features inside train().
    // Requires 0 < value <= 1 (checked by DLIB_ASSERT).
    void set_discount (
        double value
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 < value && value <= 1,
            "\t void lspi::set_discount(value)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t value: " << value
        );
        discount = value;
    }

    // Returns the feature extractor this trainer was constructed with.
    const feature_extractor& get_feature_extractor (
    ) const { return fe; }

    // Makes train() print the iteration number and weight change to std::cout
    // after every pass.
    void be_verbose (
    )
    {
        verbose = true;
    }

    // Stops train() from printing progress information.
    void be_quiet (
    )
    {
        verbose = false;
    }

    // Sets the convergence threshold: train() stops once the length of the change
    // in the weight vector between two passes is no longer greater than eps_.
    // Requires eps_ > 0 (checked by DLIB_ASSERT).
    void set_epsilon (
        double eps_
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(eps_ > 0,
            "\t void lspi::set_epsilon(eps_)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t eps_: " << eps_
        );
        eps = eps_;
    }

    // Returns the convergence threshold used by train().
    double get_epsilon (
    ) const
    {
        return eps;
    }

    // Sets the multiplier of the identity matrix that A is initialized with at the
    // start of every pass in train().
    // Requires lambda_ >= 0 (checked by DLIB_ASSERT).
    void set_lambda (
        double lambda_
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(lambda_ >= 0,
            "\t void lspi::set_lambda(lambda_)"
            << "\n\t invalid inputs were given to this function"
            << "\n\t lambda_: " << lambda_
        );
        lambda = lambda_;
    }

    // Returns the multiplier of the identity matrix used to initialize A in train().
    double get_lambda (
    ) const
    {
        return lambda;
    }

    // Sets the upper bound on the number of passes train() makes over the samples.
    // Note that train() always performs at least one pass.
    void set_max_iterations (
        unsigned long max_iter
    ) { max_iterations = max_iter; }

    // Returns the upper bound on the number of passes train() makes.
    // Declared const, like every other accessor of this class, so it can be called
    // on a const lspi object.
    unsigned long get_max_iterations (
    ) const { return max_iterations; }

    // Fits a weight vector to the given samples and returns it, together with a copy
    // of the feature extractor, as a policy<feature_extractor>.
    //
    // samples must be indexable with operator[], report its length via size(), and
    // hold elements exposing .state, .action, .reward and .next_state.
    // Requires samples.size() > 0 (checked by DLIB_ASSERT).
    template <typename vector_type>
    policy<feature_extractor> train (
        const vector_type& samples
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(samples.size() > 0,
            "\t policy lspi::train(samples)"
            << "\n\t invalid inputs were given to this function"
        );

        // The weights start at zero, so the first pass picks next-state actions
        // using an all zero weight vector.
        matrix<double,0,1> w(fe.num_features());
        w = 0;
        matrix<double,0,1> prev_w, b, f1, f2;
        matrix<double> A;

        double change;
        unsigned long iter = 0;
        do
        {
            // Reset the accumulators for this pass.  A begins as lambda*I.
            A = identity_matrix<double>(fe.num_features())*lambda;
            b = 0;

            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                // f1: features of the observed (state, action) pair.
                fe.get_features(samples[i].state, samples[i].action, f1);
                // f2: features of the next state paired with the action the
                // current weights w select for it.
                fe.get_features(samples[i].next_state,
                                fe.find_best_action(samples[i].next_state,w),
                                f2);
                A += f1*trans(f1 - discount*f2);
                b += f1*samples[i].reward;
            }

            prev_w = w;
            if (feature_extractor::force_last_weight_to_1)
            {
                // Solve only for the leading weights: the last column of A is moved
                // to the right hand side and 1.0 is appended as the final weight.
                w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0));
            }
            else
            {
                w = pinv(A)*b;
            }

            change = length(w-prev_w);
            ++iter;

            if (verbose)
                std::cout << "iteration: " << iter << "\tchange: " << change << std::endl;

        } while(change > eps && iter < max_iterations);

        return policy<feature_extractor>(w,fe);
    }

private:

    // Assigns the default value of every tunable setting.  Called by both
    // constructors.
    void init()
    {
        lambda = 0.01;
        discount = 0.8;
        eps = 0.01;
        verbose = false;
        max_iterations = 100;
    }

    double lambda;                  // multiplier of the identity that seeds A in train()
    double discount;                // factor applied to the next-state features in train()
    double eps;                     // train() stops once the weight change is <= eps
    bool verbose;                   // when true, train() prints progress to std::cout
    unsigned long max_iterations;   // upper bound on the number of passes in train()
    feature_extractor fe;           // maps (state, action) pairs to feature vectors
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_LSPI_Hh_