|
<html><head><title>dlib C++ Library - dnn_metric_learning_ex.cpp</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt |
|
</font><font color='#009900'>/* |
|
This is an example illustrating the use of the deep learning tools from the |
|
dlib C++ Library. In it, we will show how to use the loss_metric layer to do |
|
metric learning. |
|
|
|
The main reason you might want to use this kind of algorithm is that you |
|
would like to use a k-nearest neighbor classifier or similar algorithm, but |
|
you don't know a good way to calculate the distance between two things. A |
|
popular example would be face recognition. There are a whole lot of papers |
|
that train some kind of deep metric learning algorithm that embeds face |
|
images in some vector space where images of the same person are close to each |
|
other and images of different people are far apart. Then in that vector |
|
space it's very easy to do face recognition with some kind of k-nearest |
|
neighbor classifier. |
|
|
|
To keep this example as simple as possible we won't do face recognition. |
|
Instead, we will create a very simple network and use it to learn a mapping |
|
from 8D vectors to 2D vectors such that vectors with the same class labels |
|
are near each other. If you want to see a more complex example that learns |
|
the kind of network you would use for something like face recognition, read |
|
the <a href="dnn_metric_learning_on_images_ex.cpp.html">dnn_metric_learning_on_images_ex.cpp</a> example. |
|
|
|
You should also have read the examples that introduce the dlib DNN API before |
|
continuing. These are <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a> and <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a>. |
|
*/</font> |
|
|
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dnn.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
|
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; |
|
|
|
|
|
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#009900'>// The API for doing metric learning is very similar to the API for |
|
</font> <font color='#009900'>// multi-class classification. In fact, the inputs are the same, a bunch of |
|
</font> <font color='#009900'>// labeled objects. So here we create our dataset. We make up some simple |
|
</font> <font color='#009900'>// vectors and label them with the integers 1,2,3,4. The specific values of |
|
</font> <font color='#009900'>// the integer labels don't matter. |
|
</font> std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>></font> samples; |
|
std::vector<font color='#5555FF'><</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font> labels; |
|
|
|
<font color='#009900'>// class 1 training vectors |
|
</font> samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// class 2 training vectors |
|
</font> samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; |
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// class 3 training vectors |
|
</font> samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; |
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// class 4 training vectors |
|
</font> samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font>,<font color='#979000'>0</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; |
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><b>{</b><font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><b>}</b><font face='Lucida Console'>)</font>; labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; |
|
|
|
|
|
<font color='#009900'>// Make a network that simply learns a linear mapping from 8D vectors to 2D |
|
</font> <font color='#009900'>// vectors. |
|
</font> <font color='#0000FF'>using</font> net_type <font color='#5555FF'>=</font> loss_metric<font color='#5555FF'><</font>fc<font color='#5555FF'><</font><font color='#979000'>2</font>,input<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font><font color='#5555FF'>></font>; |
|
net_type net; |
|
dnn_trainer<font color='#5555FF'><</font>net_type<font color='#5555FF'>></font> <font color='#BB00BB'>trainer</font><font face='Lucida Console'>(</font>net<font face='Lucida Console'>)</font>; |
|
trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font><font color='#979000'>0.1</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// It should be emphasized that it's really important that each mini-batch contain |
|
</font> <font color='#009900'>// multiple instances of each class of object. This is because the metric learning |
|
</font> <font color='#009900'>// algorithm needs to consider pairs of objects that should be close as well as pairs |
|
</font> <font color='#009900'>// of objects that should be far apart during each training step. Here we just keep |
|
</font> <font color='#009900'>// training on the same small batch so this constraint is trivially satisfied. |
|
</font> <font color='#0000FF'>while</font><font face='Lucida Console'>(</font>trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font> |
|
trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Wait for training threads to stop |
|
</font> trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>done training</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
|
|
<font color='#009900'>// Run all the samples through the network to get their 2D vector embeddings. |
|
</font> std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font>,<font color='#979000'>0</font>,<font color='#979000'>1</font><font color='#5555FF'>></font><font color='#5555FF'>></font> embedded <font color='#5555FF'>=</font> <font color='#BB00BB'>net</font><font face='Lucida Console'>(</font>samples<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Print the embedding for each sample to the screen. If you look at the |
|
</font> <font color='#009900'>// outputs carefully you should notice that they are grouped together in 2D |
|
</font> <font color='#009900'>// space according to their label. |
|
</font> <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> embedded.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>label: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> labels[i] <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\t</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#BB00BB'>trans</font><font face='Lucida Console'>(</font>embedded[i]<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Now, check if the embedding puts things with the same labels near each other and |
|
</font> <font color='#009900'>// things with different labels far apart. |
|
</font> <font color='#0000FF'><u>int</u></font> num_right <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'><u>int</u></font> num_wrong <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>; i <font color='#5555FF'><</font> embedded.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j <font color='#5555FF'>=</font> i<font color='#5555FF'>+</font><font color='#979000'>1</font>; j <font color='#5555FF'><</font> embedded.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>labels[i] <font color='#5555FF'>=</font><font color='#5555FF'>=</font> labels[j]<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// The loss_metric layer will cause things with the same label to be less |
|
</font> <font color='#009900'>// than net.loss_details().get_distance_threshold() distance from each |
|
</font> <font color='#009900'>// other. So we can use that distance value as our testing threshold for |
|
</font> <font color='#009900'>// "being near to each other". |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>length</font><font face='Lucida Console'>(</font>embedded[i]<font color='#5555FF'>-</font>embedded[j]<font face='Lucida Console'>)</font> <font color='#5555FF'><</font> net.<font color='#BB00BB'>loss_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_distance_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right; |
|
<font color='#0000FF'>else</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>length</font><font face='Lucida Console'>(</font>embedded[i]<font color='#5555FF'>-</font>embedded[j]<font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> net.<font color='#BB00BB'>loss_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>get_distance_threshold</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_right; |
|
<font color='#0000FF'>else</font> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_wrong; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>num_right: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> num_right <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>num_wrong: </font>"<font color='#5555FF'><</font><font color='#5555FF'><</font> num_wrong <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
|
|
|
|
</pre></body></html> |