|
<html><head><title>dlib C++ Library - approximate_linear_models_abstract.h</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// Copyright (C) 2015 Davis E. King (davis@dlib.net) |
|
</font><font color='#009900'>// License: Boost Software License See LICENSE.txt for the full license. |
|
</font><font color='#0000FF'>#undef</font> DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ |
|
<font color='#0000FF'>#ifdef</font> DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='../matrix.h.html'>../matrix.h</a>" |
|
|
|
<font color='#0000FF'>namespace</font> dlib |
|
<b>{</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>struct</font> <b><a name='example_feature_extractor'></a>example_feature_extractor</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
WHAT THIS OBJECT REPRESENTS |
|
This object defines the interface a feature extractor must implement if it |
|
is to be used with the process_sample and policy objects defined at the |
|
bottom of this file. Moreover, it is meant to represent the core part |
|
of a model used in a reinforcement learning algorithm. |
|
|
|
In particular, this object models a Q(state,action) function where |
|
Q(state,action) == dot(w, PSI(state,action)) |
|
where PSI(state,action) is a feature vector and w is a parameter |
|
vector. |
|
|
|
Therefore, a feature extractor defines how the PSI(state,action) feature vector is |
|
calculated. It also defines the types used to represent the state and |
|
action objects. |
|
|
|
|
|
THREAD SAFETY |
|
Instances of this object are required to be threadsafe, that is, it should |
|
be safe for multiple threads to make concurrent calls to the member |
|
functions of this object. |
|
!*/</font> |
|
|
|
<font color='#009900'>// The state and actions can be any types so long as you provide typedefs for them. |
|
</font> <font color='#0000FF'>typedef</font> T state_type; |
|
<font color='#0000FF'>typedef</font> U action_type; |
|
<font color='#009900'>// Setting force_last_weight_to_1 to true says that the last element in the weight vector w must be 1. This |
|
</font> <font color='#009900'>// can be useful for including a prior into your model. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>static</font> <font color='#0000FF'><u>bool</u></font> force_last_weight_to_1 <font color='#5555FF'>=</font> <font color='#979000'>false</font>; |
|
|
|
<b><a name='example_feature_extractor'></a>example_feature_extractor</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- this object is properly initialized. |
|
!*/</font> |
|
|
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> <b><a name='num_features'></a>num_features</b><font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- returns the dimensionality of the PSI() feature vector. |
|
!*/</font> |
|
|
|
action_type <b><a name='find_best_action'></a>find_best_action</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> state_type<font color='#5555FF'>&</font> state, |
|
<font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>&</font> w |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- returns the action A that maximizes Q(state,A) = dot(w,PSI(state,A)). |
|
That is, this function finds the best action to take in the given state |
|
when our model is parameterized by the given weight vector w. |
|
!*/</font> |
|
|
|
<font color='#0000FF'><u>void</u></font> <b><a name='get_features'></a>get_features</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> state_type<font color='#5555FF'>&</font> state, |
|
<font color='#0000FF'>const</font> action_type<font color='#5555FF'>&</font> action, |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>&</font> feats |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- #feats.size() == num_features() |
|
- #feats == PSI(state,action) |
|
!*/</font> |
|
|
|
<b>}</b>; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> feature_extractor |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'>struct</font> <b><a name='process_sample'></a>process_sample</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
REQUIREMENTS ON feature_extractor |
|
feature_extractor should implement the example_feature_extractor interface |
|
defined at the top of this file. |
|
|
|
WHAT THIS OBJECT REPRESENTS |
|
This object holds a training sample for a reinforcement learning algorithm. |
|
In particular, it should be a sample from some process where the process |
|
was in state this->state, then took the action this->action, which resulted in |
|
receiving this->reward and ending up in the state this->next_state. |
|
!*/</font> |
|
|
|
<font color='#0000FF'>typedef</font> feature_extractor feature_extractor_type; |
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::state_type state_type; |
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::action_type action_type; |
|
|
|
<b><a name='process_sample'></a>process_sample</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><b>{</b><b>}</b> |
|
|
|
<b><a name='process_sample'></a>process_sample</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> state_type<font color='#5555FF'>&</font> s, |
|
<font color='#0000FF'>const</font> action_type<font color='#5555FF'>&</font> a, |
|
<font color='#0000FF'>const</font> state_type<font color='#5555FF'>&</font> n, |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font><font color='#5555FF'>&</font> r |
|
<font face='Lucida Console'>)</font> : state<font face='Lucida Console'>(</font>s<font face='Lucida Console'>)</font>, action<font face='Lucida Console'>(</font>a<font face='Lucida Console'>)</font>, next_state<font face='Lucida Console'>(</font>n<font face='Lucida Console'>)</font>, reward<font face='Lucida Console'>(</font>r<font face='Lucida Console'>)</font> <b>{</b><b>}</b> |
|
|
|
state_type state; |
|
action_type action; |
|
state_type next_state; |
|
<font color='#0000FF'><u>double</u></font> reward; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> <font color='#0000FF'>typename</font> feature_extractor <font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> process_sample<font color='#5555FF'><</font>feature_extractor<font color='#5555FF'>></font><font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> <font color='#0000FF'>typename</font> feature_extractor <font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b> <font face='Lucida Console'>(</font>process_sample<font color='#5555FF'><</font>feature_extractor<font color='#5555FF'>></font><font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
provides serialization support. |
|
!*/</font> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> |
|
<font color='#0000FF'>typename</font> feature_extractor |
|
<font color='#5555FF'>></font> |
|
<font color='#0000FF'>class</font> <b><a name='policy'></a>policy</b> |
|
<b>{</b> |
|
<font color='#009900'>/*! |
|
REQUIREMENTS ON feature_extractor |
|
feature_extractor should implement the example_feature_extractor interface |
|
defined at the top of this file. |
|
|
|
WHAT THIS OBJECT REPRESENTS |
|
This is a policy based on the supplied feature_extractor model. In |
|
particular, it maps from feature_extractor::state_type to the best action |
|
to take in that state. |
|
!*/</font> |
|
|
|
<font color='#0000FF'>public</font>: |
|
|
|
<font color='#0000FF'>typedef</font> feature_extractor feature_extractor_type; |
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::state_type state_type; |
|
<font color='#0000FF'>typedef</font> <font color='#0000FF'>typename</font> feature_extractor::action_type action_type; |
|
|
|
|
|
<b><a name='policy'></a>policy</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- #get_feature_extractor() == feature_extractor() |
|
(i.e. it will have its default value) |
|
- #get_weights().size() == #get_feature_extractor().num_features() |
|
- #get_weights() == 0 |
|
!*/</font> |
|
|
|
<b><a name='policy'></a>policy</b> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>&</font> weights, |
|
<font color='#0000FF'>const</font> feature_extractor<font color='#5555FF'>&</font> fe |
|
<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
requires |
|
- fe.num_features() == weights.size() |
|
ensures |
|
- #get_feature_extractor() == fe |
|
- #get_weights() == weights |
|
!*/</font> |
|
|
|
action_type <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> state_type<font color='#5555FF'>&</font> state |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- returns get_feature_extractor().find_best_action(state,get_weights()) |
|
!*/</font> |
|
|
|
<font color='#0000FF'>const</font> feature_extractor<font color='#5555FF'>&</font> <b><a name='get_feature_extractor'></a>get_feature_extractor</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- returns the feature extractor used by this object |
|
!*/</font> |
|
|
|
<font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>&</font> <b><a name='get_weights'></a>get_weights</b> <font face='Lucida Console'>(</font> |
|
<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>; |
|
<font color='#009900'>/*! |
|
ensures |
|
- returns the parameter vector (w) associated with this object. The length |
|
of the vector is get_feature_extractor().num_features(). |
|
!*/</font> |
|
|
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> <font color='#0000FF'>typename</font> feature_extractor <font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='serialize'></a>serialize</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> policy<font color='#5555FF'><</font>feature_extractor<font color='#5555FF'>></font><font color='#5555FF'>&</font> item, std::ostream<font color='#5555FF'>&</font> out<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font> <font color='#0000FF'>typename</font> feature_extractor <font color='#5555FF'>></font> |
|
<font color='#0000FF'><u>void</u></font> <b><a name='deserialize'></a>deserialize</b><font face='Lucida Console'>(</font>policy<font color='#5555FF'><</font>feature_extractor<font color='#5555FF'>></font><font color='#5555FF'>&</font> item, std::istream<font color='#5555FF'>&</font> in<font face='Lucida Console'>)</font>; |
|
<font color='#009900'>/*! |
|
provides serialization support. |
|
!*/</font> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
|
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// DLIB_APPROXIMATE_LINEAR_MODELS_ABSTRACT_Hh_ |
|
</font> |
|
|
|
</pre></body></html> |