|
<html><head><title>dlib C++ Library - dnn_instance_segmentation_train_ex.cpp</title></head><body bgcolor='white'><pre> |
|
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt |
|
</font><font color='#009900'>/* |
|
This example shows how to train an instance segmentation net using the PASCAL VOC2012 |
|
dataset. For an introduction to what segmentation is, see the accompanying header file |
|
dnn_instance_segmentation_ex.h. |
|
|
|
Instructions on how to run the example: |
|
1. Download the PASCAL VOC2012 data, and untar it somewhere. |
|
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar |
|
2. Build the dnn_instance_segmentation_train_ex example program. |
|
3. Run: |
|
./dnn_instance_segmentation_train_ex /path/to/VOC2012 |
|
4. Wait while the network is being trained. |
|
5. Build the dnn_instance_segmentation_ex example program. |
|
6. Run: |
|
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images |
|
|
|
It would be a good idea to become familiar with dlib's DNN tooling before reading this |
|
example. So you should read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a>, <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a>, |
|
and <a href="dnn_semantic_segmentation_train_ex.cpp.html">dnn_semantic_segmentation_train_ex.cpp</a> before reading this example program. |
|
*/</font> |
|
|
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='dnn_instance_segmentation_ex.h.html'>dnn_instance_segmentation_ex.h</a>" |
|
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='pascal_voc_2012.h.html'>pascal_voc_2012.h</a>" |
|
|
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iostream<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>iterator<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>thread<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#if</font> __cplusplus <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font face='Lucida Console'>(</font>defined<font face='Lucida Console'>(</font>_MSVC_LANG<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> _MSVC_LANG <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>#include</font> <font color='#5555FF'><</font>execution<font color='#5555FF'>></font> |
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// __cplusplus >= 201703L |
|
</font> |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std; |
|
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#009900'>// A single training sample for detection. A mini-batch comprises many of these. |
|
</font><font color='#0000FF'>struct</font> <b><a name='det_training_sample'></a>det_training_sample</b> |
|
<b>{</b> |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; |
|
std::vector<font color='#5555FF'><</font>dlib::mmod_rect<font color='#5555FF'>></font> mmod_rects; |
|
<b>}</b>; |
|
|
|
<font color='#009900'>// A single training sample for segmentation. A mini-batch comprises many of these. |
|
</font><font color='#0000FF'>struct</font> <b><a name='seg_training_sample'></a>seg_training_sample</b> |
|
<b>{</b> |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> label_image; <font color='#009900'>// The ground-truth label of each pixel. (+1 or -1) |
|
</font><b>}</b>; |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'><u>bool</u></font> <b><a name='is_instance_pixel'></a>is_instance_pixel</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dlib::rgb_pixel<font color='#5555FF'>&</font> rgb_label<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rgb_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> dlib::<font color='#BB00BB'>rgb_pixel</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0</font>, <font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>false</font>; <font color='#009900'>// Background |
|
</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rgb_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> dlib::<font color='#BB00BB'>rgb_pixel</font><font face='Lucida Console'>(</font><font color='#979000'>224</font>, <font color='#979000'>224</font>, <font color='#979000'>192</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>false</font>; <font color='#009900'>// The cream-colored `void' label is used in border regions and to mask difficult objects |
|
</font> |
|
<font color='#0000FF'>return</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Provide hash function for dlib::rgb_pixel |
|
</font><font color='#0000FF'>namespace</font> std <b>{</b> |
|
<font color='#0000FF'>template</font> <font color='#5555FF'><</font><font color='#5555FF'>></font> |
|
<font color='#0000FF'>struct</font> <b><a name='hash'></a>hash</b><font color='#5555FF'><</font>dlib::rgb_pixel<font color='#5555FF'>></font> |
|
<b>{</b> |
|
std::<font color='#0000FF'><u>size_t</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dlib::rgb_pixel<font color='#5555FF'>&</font> p<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font face='Lucida Console'>(</font><font color='#0000FF'>static_cast</font><font color='#5555FF'><</font>uint32_t<font color='#5555FF'>></font><font face='Lucida Console'>(</font>p.red<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#979000'>16</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>|</font> <font face='Lucida Console'>(</font><font color='#0000FF'>static_cast</font><font color='#5555FF'><</font>uint32_t<font color='#5555FF'>></font><font face='Lucida Console'>(</font>p.green<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> <font color='#979000'>8</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>|</font> <font face='Lucida Console'>(</font><font color='#0000FF'>static_cast</font><font color='#5555FF'><</font>uint32_t<font color='#5555FF'>></font><font face='Lucida Console'>(</font>p.blue<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>struct</font> <b><a name='truth_instance'></a>truth_instance</b> |
|
<b>{</b> |
|
dlib::rgb_pixel rgb_label; |
|
dlib::mmod_rect mmod_rect; |
|
<b>}</b>; |
|
|
|
std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font> <b><a name='rgb_label_images_to_truth_instances'></a>rgb_label_images_to_truth_instances</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> dlib::matrix<font color='#5555FF'><</font>dlib::rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> instance_label_image, |
|
<font color='#0000FF'>const</font> dlib::matrix<font color='#5555FF'><</font>dlib::rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> class_label_image |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::unordered_map<font color='#5555FF'><</font>dlib::rgb_pixel, mmod_rect<font color='#5555FF'>></font> result_map; |
|
|
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>instance_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> class_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>instance_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> class_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> nr <font color='#5555FF'>=</font> instance_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> nc <font color='#5555FF'>=</font> instance_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> r <font color='#5555FF'>=</font> <font color='#979000'>0</font>; r <font color='#5555FF'><</font> nr; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>r<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> c <font color='#5555FF'>=</font> <font color='#979000'>0</font>; c <font color='#5555FF'><</font> nc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>c<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> rgb_instance_label <font color='#5555FF'>=</font> <font color='#BB00BB'>instance_label_image</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font><font color='#BB00BB'>is_instance_pixel</font><font face='Lucida Console'>(</font>rgb_instance_label<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>continue</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> rgb_class_label <font color='#5555FF'>=</font> <font color='#BB00BB'>class_label_image</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> Voc2012class<font color='#5555FF'>&</font> voc2012_class <font color='#5555FF'>=</font> <font color='#BB00BB'>find_voc2012_class</font><font face='Lucida Console'>(</font>rgb_class_label<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> i <font color='#5555FF'>=</font> result_map.<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font>rgb_instance_label<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>i <font color='#5555FF'>=</font><font color='#5555FF'>=</font> result_map.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Encountered a new instance |
|
</font> result_map[rgb_instance_label] <font color='#5555FF'>=</font> <font color='#BB00BB'>rectangle</font><font face='Lucida Console'>(</font>c, r, c, r<font face='Lucida Console'>)</font>; |
|
result_map[rgb_instance_label].label <font color='#5555FF'>=</font> voc2012_class.classlabel; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
<font color='#009900'>// Not the first occurrence - update the rect |
|
</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> rect <font color='#5555FF'>=</font> i<font color='#5555FF'>-</font><font color='#5555FF'>></font>second.rect; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>c <font color='#5555FF'><</font> rect.<font color='#BB00BB'>left</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
rect.<font color='#BB00BB'>set_left</font><font face='Lucida Console'>(</font>c<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>c <font color='#5555FF'>></font> rect.<font color='#BB00BB'>right</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
rect.<font color='#BB00BB'>set_right</font><font face='Lucida Console'>(</font>c<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>r <font color='#5555FF'>></font> rect.<font color='#BB00BB'>bottom</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
rect.<font color='#BB00BB'>set_bottom</font><font face='Lucida Console'>(</font>r<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>i<font color='#5555FF'>-</font><font color='#5555FF'>></font>second.label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> voc2012_class.classlabel<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font> flat_result; |
|
flat_result.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font>result_map.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> i : result_map<font face='Lucida Console'>)</font> <b>{</b> |
|
flat_result.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_instance<b>{</b> |
|
i.first, i.second |
|
<b>}</b><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>return</font> flat_result; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'>struct</font> <b><a name='truth_image'></a>truth_image</b> |
|
<b>{</b> |
|
image_info info; |
|
std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font> truth_instances; |
|
<b>}</b>; |
|
|
|
std::vector<font color='#5555FF'><</font>mmod_rect<font color='#5555FF'>></font> <b><a name='extract_mmod_rects'></a>extract_mmod_rects</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_instances |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::vector<font color='#5555FF'><</font>mmod_rect<font color='#5555FF'>></font> <font color='#BB00BB'>mmod_rects</font><font face='Lucida Console'>(</font>truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
std::<font color='#BB00BB'>transform</font><font face='Lucida Console'>(</font> |
|
truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
mmod_rects.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
[]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&</font> truth<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>return</font> truth.mmod_rect; <b>}</b> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> mmod_rects; |
|
<b>}</b> |
|
|
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>mmod_rect<font color='#5555FF'>></font><font color='#5555FF'>></font> <b><a name='extract_mmod_rect_vectors'></a>extract_mmod_rect_vectors</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>mmod_rect<font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>mmod_rects</font><font face='Lucida Console'>(</font>truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> extract_mmod_rects_from_truth_image <font color='#5555FF'>=</font> []<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_image<font color='#5555FF'>&</font> truth_image<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>return</font> <font color='#BB00BB'>extract_mmod_rects</font><font face='Lucida Console'>(</font>truth_image.truth_instances<font face='Lucida Console'>)</font>; |
|
<b>}</b>; |
|
|
|
std::<font color='#BB00BB'>transform</font><font face='Lucida Console'>(</font> |
|
truth_images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
truth_images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
mmod_rects.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
extract_mmod_rects_from_truth_image |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> mmod_rects; |
|
<b>}</b> |
|
|
|
det_bnet_type <b><a name='train_detection_network'></a>train_detection_network</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images, |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> det_minibatch_size |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> min_detector_window_overlap_iou <font color='#5555FF'>=</font> <font color='#979000'>0.65</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> target_size <font color='#5555FF'>=</font> <font color='#979000'>70</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> min_target_size <font color='#5555FF'>=</font> <font color='#979000'>30</font>; |
|
|
|
mmod_options <font color='#BB00BB'>options</font><font face='Lucida Console'>(</font> |
|
<font color='#BB00BB'>extract_mmod_rect_vectors</font><font face='Lucida Console'>(</font>truth_images<font face='Lucida Console'>)</font>, |
|
target_size, min_target_size, |
|
min_detector_window_overlap_iou |
|
<font face='Lucida Console'>)</font>; |
|
|
|
options.overlaps_ignore <font color='#5555FF'>=</font> <font color='#BB00BB'>test_box_overlap</font><font face='Lucida Console'>(</font><font color='#979000'>0.5</font>, <font color='#979000'>0.9</font><font face='Lucida Console'>)</font>; |
|
|
|
det_bnet_type <font color='#BB00BB'>det_net</font><font face='Lucida Console'>(</font>options<font face='Lucida Console'>)</font>; |
|
|
|
det_net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>set_num_filters</font><font face='Lucida Console'>(</font>options.detector_windows.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
dlib::pipe<font color='#5555FF'><</font>det_training_sample<font color='#5555FF'>></font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>data, <font color='#5555FF'>&</font>truth_images, target_size, min_target_size]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> seed<font face='Lucida Console'>)</font>; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; |
|
|
|
random_cropper cropper; |
|
cropper.<font color='#BB00BB'>set_seed</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
cropper.<font color='#BB00BB'>set_chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>350</font>, <font color='#979000'>350</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Usually you want to give the cropper whatever min sizes you passed to the |
|
</font> <font color='#009900'>// mmod_options constructor, or very slightly smaller sizes, which is what we do here. |
|
</font> cropper.<font color='#BB00BB'>set_min_object_size</font><font face='Lucida Console'>(</font>target_size <font color='#5555FF'>-</font> <font color='#979000'>2</font>, min_target_size <font color='#5555FF'>-</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font>; |
|
cropper.<font color='#BB00BB'>set_max_rotation_degrees</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; |
|
|
|
det_training_sample temp; |
|
|
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Pick a random input image. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_index <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_image <font color='#5555FF'>=</font> truth_images[random_index]; |
|
|
|
<font color='#009900'>// Load the input image. |
|
</font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, truth_image.info.image_filename<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Get a random crop of the input. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> mmod_rects <font color='#5555FF'>=</font> <font color='#BB00BB'>extract_mmod_rects</font><font face='Lucida Console'>(</font>truth_image.truth_instances<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>cropper</font><font face='Lucida Console'>(</font>input_image, mmod_rects, temp.input_image, temp.mmod_rects<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>disturb_colors</font><font face='Lucida Console'>(</font>temp.input_image, rnd<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Push the result to be used by the trainer. |
|
</font> data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b>; |
|
std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> stop_data_loaders <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b>; |
|
|
|
dnn_trainer<font color='#5555FF'><</font>det_bnet_type<font color='#5555FF'>></font> <font color='#BB00BB'>det_trainer</font><font face='Lucida Console'>(</font>det_net, <font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>try</font> |
|
<b>{</b> |
|
det_trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
det_trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>; |
|
det_trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>pascal_voc2012_det_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
det_trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>5000</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Output training parameters. |
|
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> det_trainer <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> samples; |
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>mmod_rect<font color='#5555FF'>></font><font color='#5555FF'>></font> labels; |
|
|
|
<font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer. |
|
</font> <font color='#009900'>// We will run until the learning rate becomes small enough. |
|
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>det_trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// make a mini-batch |
|
</font> det_training_sample temp; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> det_minibatch_size<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
|
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.mmod_rects<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
det_trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>throw</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before |
|
</font> <font color='#009900'>// moving on. |
|
</font> <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// also wait for threaded processing to stop in the trainer. |
|
</font> det_trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
det_net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> det_net; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> <b><a name='keep_only_current_instance'></a>keep_only_current_instance</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>&</font> rgb_label_image, <font color='#0000FF'>const</font> rgb_pixel rgb_label<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> nr <font color='#5555FF'>=</font> rgb_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> nc <font color='#5555FF'>=</font> rgb_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font> <font color='#BB00BB'>result</font><font face='Lucida Console'>(</font>nr, nc<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> r <font color='#5555FF'>=</font> <font color='#979000'>0</font>; r <font color='#5555FF'><</font> nr; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>r<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> c <font color='#5555FF'>=</font> <font color='#979000'>0</font>; c <font color='#5555FF'><</font> nc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>c<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> index <font color='#5555FF'>=</font> <font color='#BB00BB'>rgb_label_image</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>index <font color='#5555FF'>=</font><font color='#5555FF'>=</font> rgb_label<font face='Lucida Console'>)</font> |
|
<font color='#BB00BB'>result</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#5555FF'>+</font><font color='#979000'>1</font>; |
|
<font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>index <font color='#5555FF'>=</font><font color='#5555FF'>=</font> dlib::<font color='#BB00BB'>rgb_pixel</font><font face='Lucida Console'>(</font><font color='#979000'>224</font>, <font color='#979000'>224</font>, <font color='#979000'>192</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#BB00BB'>result</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>else</font> |
|
<font color='#BB00BB'>result</font><font face='Lucida Console'>(</font>r, c<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>return</font> result; |
|
<b>}</b> |
|
|
|
seg_bnet_type <b><a name='train_segmentation_network'></a>train_segmentation_network</b><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images, |
|
<font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> seg_minibatch_size, |
|
<font color='#0000FF'>const</font> std::string<font color='#5555FF'>&</font> classlabel |
|
<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
seg_bnet_type seg_net; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>; |
|
|
|
<font color='#0000FF'>const</font> std::string synchronization_file_name |
|
<font color='#5555FF'>=</font> "<font color='#CC0000'>pascal_voc2012_seg_trainer_state_file</font>" |
|
<font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>classlabel.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> ? "<font color='#CC0000'></font>" : <font face='Lucida Console'>(</font>"<font color='#CC0000'>_</font>" <font color='#5555FF'>+</font> classlabel<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<font color='#5555FF'>+</font> "<font color='#CC0000'>.dat</font>"; |
|
|
|
dnn_trainer<font color='#5555FF'><</font>seg_bnet_type<font color='#5555FF'>></font> <font color='#BB00BB'>seg_trainer</font><font face='Lucida Console'>(</font>seg_net, <font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
seg_trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
seg_trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>; |
|
seg_trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>synchronization_file_name, std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
seg_trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>2000</font><font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>seg_net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Output training parameters. |
|
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> seg_trainer <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font><font color='#5555FF'>></font> samples; |
|
std::vector<font color='#5555FF'><</font>matrix<font color='#5555FF'><</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>></font><font color='#5555FF'>></font> labels; |
|
|
|
<font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops. It's |
|
</font> <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy. Using multiple |
|
</font> <font color='#009900'>// thread for this kind of data preparation helps us do that. Each thread puts the |
|
</font> <font color='#009900'>// crops into the data queue. |
|
</font> dlib::pipe<font color='#5555FF'><</font>seg_training_sample<font color='#5555FF'>></font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>data, <font color='#5555FF'>&</font>truth_images]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> seed<font face='Lucida Console'>)</font>; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> input_image; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> rgb_label_image; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> rgb_label_chip; |
|
seg_training_sample temp; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Pick a random input image. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_index <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_image <font color='#5555FF'>=</font> truth_images[random_index]; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> image_truths <font color='#5555FF'>=</font> truth_image.truth_instances; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>image_truths.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>const</font> image_info<font color='#5555FF'>&</font> info <font color='#5555FF'>=</font> truth_image.info; |
|
|
|
<font color='#009900'>// Load the input image. |
|
</font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, info.image_filename<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Load the ground-truth (RGB) instance labels. |
|
</font> <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, info.instance_label_filename<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Pick a random training instance. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_instance <font color='#5555FF'>=</font> image_truths[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> image_truths.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>]; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_rect <font color='#5555FF'>=</font> truth_instance.mmod_rect.rect; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> cropping_rect <font color='#5555FF'>=</font> <font color='#BB00BB'>get_cropping_rect</font><font face='Lucida Console'>(</font>truth_rect<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Pick a random crop around the instance. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> max_x_translate_amount <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>truth_rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> max_y_translate_amount <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>truth_rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_translate <font color='#5555FF'>=</font> <font color='#BB00BB'>point</font><font face='Lucida Console'>(</font> |
|
rnd.<font color='#BB00BB'>get_integer_in_range</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font>max_x_translate_amount, max_x_translate_amount <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>, |
|
rnd.<font color='#BB00BB'>get_integer_in_range</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font>max_y_translate_amount, max_y_translate_amount <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> rectangle <font color='#BB00BB'>random_rect</font><font face='Lucida Console'>(</font> |
|
cropping_rect.<font color='#BB00BB'>left</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
cropping_rect.<font color='#BB00BB'>top</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>y</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
cropping_rect.<font color='#BB00BB'>right</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
cropping_rect.<font color='#BB00BB'>bottom</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>y</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>random_rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font>seg_dim, seg_dim<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Crop the input image. |
|
</font> <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>input_image, chip_details, temp.input_image, <font color='#BB00BB'>interpolate_bilinear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#BB00BB'>disturb_colors</font><font face='Lucida Console'>(</font>temp.input_image, rnd<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Crop the labels correspondingly. However, note that here bilinear |
|
</font> <font color='#009900'>// interpolation would make absolutely no sense - you wouldn't say that |
|
</font> <font color='#009900'>// a bicycle is half-way between an aeroplane and a bird, would you? |
|
</font> <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>rgb_label_image, chip_details, rgb_label_chip, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Clear pixels not related to the current instance. |
|
</font> temp.label_image <font color='#5555FF'>=</font> <font color='#BB00BB'>keep_only_current_instance</font><font face='Lucida Console'>(</font>rgb_label_chip, truth_instance.rgb_label<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Push the result to be used by the trainer. |
|
</font> data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
<font color='#009900'>// TODO: use background samples as well |
|
</font> <b>}</b> |
|
<b>}</b> |
|
<b>}</b>; |
|
std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> stop_data_loaders <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<b>}</b>; |
|
|
|
<font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer. |
|
</font> <font color='#009900'>// We will run until the learning rate has dropped by a factor of 1e-4. |
|
</font> <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>seg_trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// make a mini-batch |
|
</font> seg_training_sample temp; |
|
<font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> seg_minibatch_size<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>; |
|
|
|
samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.label_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
seg_trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>throw</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before |
|
</font> <font color='#009900'>// moving on. |
|
</font> <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// also wait for threaded processing to stop in the trainer. |
|
</font> seg_trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
seg_net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> seg_net; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
<font color='#0000FF'><u>int</u></font> <b><a name='ignore_overlapped_boxes'></a>ignore_overlapped_boxes</b><font face='Lucida Console'>(</font> |
|
std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_instances, |
|
<font color='#0000FF'>const</font> test_box_overlap<font color='#5555FF'>&</font> overlaps |
|
<font face='Lucida Console'>)</font> |
|
<font color='#009900'>/*! |
|
ensures |
|
- Whenever two rectangles in boxes overlap, according to overlaps(), we set the |
|
smallest box to ignore. |
|
- returns the number of newly ignored boxes. |
|
!*/</font> |
|
<b>{</b> |
|
<font color='#0000FF'><u>int</u></font> num_ignored <font color='#5555FF'>=</font> <font color='#979000'>0</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>, end <font color='#5555FF'>=</font> truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; i <font color='#5555FF'><</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font> box_i <font color='#5555FF'>=</font> truth_instances[i].mmod_rect; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>box_i.ignore<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>continue</font>; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j <font color='#5555FF'>=</font> i<font color='#5555FF'>+</font><font color='#979000'>1</font>; j <font color='#5555FF'><</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>auto</font><font color='#5555FF'>&</font> box_j <font color='#5555FF'>=</font> truth_instances[j].mmod_rect; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>box_j.ignore<font face='Lucida Console'>)</font> |
|
<font color='#0000FF'>continue</font>; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>overlaps</font><font face='Lucida Console'>(</font>box_i, box_j<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_ignored; |
|
<font color='#0000FF'>if</font><font face='Lucida Console'>(</font>box_i.rect.<font color='#BB00BB'>area</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> box_j.rect.<font color='#BB00BB'>area</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
box_i.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<font color='#0000FF'>else</font> |
|
box_j.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>return</font> num_ignored; |
|
<b>}</b> |
|
|
|
std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font> <b><a name='load_truth_instances'></a>load_truth_instances</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> image_info<font color='#5555FF'>&</font> info<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> instance_label_image; |
|
matrix<font color='#5555FF'><</font>rgb_pixel<font color='#5555FF'>></font> class_label_image; |
|
|
|
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>instance_label_image, info.instance_label_filename<font face='Lucida Console'>)</font>; |
|
<font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>class_label_image, info.class_label_filename<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> <font color='#BB00BB'>rgb_label_images_to_truth_instances</font><font face='Lucida Console'>(</font>instance_label_image, class_label_image<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font><font color='#5555FF'>></font> <b><a name='load_all_truth_instances'></a>load_all_truth_instances</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>image_info<font color='#5555FF'>></font><font color='#5555FF'>&</font> listing<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
std::vector<font color='#5555FF'><</font>std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font><font color='#5555FF'>></font> <font color='#BB00BB'>truth_instances</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
std::<font color='#BB00BB'>transform</font><font face='Lucida Console'>(</font> |
|
<font color='#0000FF'>#if</font> __cplusplus <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font face='Lucida Console'>(</font>defined<font face='Lucida Console'>(</font>_MSVC_LANG<font face='Lucida Console'>)</font> <font color='#5555FF'>&</font><font color='#5555FF'>&</font> _MSVC_LANG <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L<font face='Lucida Console'>)</font> |
|
std::execution::par, |
|
<font color='#0000FF'>#endif</font> <font color='#009900'>// __cplusplus >= 201703L |
|
</font> listing.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
listing.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |
|
load_truth_instances |
|
<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>return</font> truth_instances; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// ---------------------------------------------------------------------------------------- |
|
</font> |
|
// Returns the images from truth_images that contain at least one truth instance whose |

// mmod_rect.label is listed in desired_classlabels. Every returned image keeps its info, |

// but only the truth instances carrying a desired label. The input is not modified. |

std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font> <b><a name='filter_based_on_classlabel'></a>filter_based_on_classlabel</b><font face='Lucida Console'>(</font> |

    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images, |

    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>std::string<font color='#5555FF'>></font><font color='#5555FF'>&</font> desired_classlabels |

<font face='Lucida Console'>)</font> |

<b>{</b> |

    std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font> result; |



    // Predicate: true when the instance's label appears in desired_classlabels (linear search). |

    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> represents_desired_class <font color='#5555FF'>=</font> [<font color='#5555FF'>&</font>desired_classlabels]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&</font> truth_instance<font face='Lucida Console'>)</font> <b>{</b> |

        <font color='#0000FF'>return</font> std::<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font> |

            desired_classlabels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

            desired_classlabels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

            truth_instance.mmod_rect.label |

        <font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> desired_classlabels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; |

    <b>}</b>; |



    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> input : truth_images<font face='Lucida Console'>)</font> |

    <b>{</b> |

        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> has_desired_class <font color='#5555FF'>=</font> std::<font color='#BB00BB'>any_of</font><font face='Lucida Console'>(</font> |

            input.truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

            input.truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

            represents_desired_class |

        <font face='Lucida Console'>)</font>; |



        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>has_desired_class<font face='Lucida Console'>)</font> <b>{</b> |



            <font color='#009900'>// NB: This keeps only MMOD rects belonging to any of the desired classes. |

</font>            <font color='#009900'>// A reasonable alternative could be to keep all rects, but mark those |

</font>            <font color='#009900'>// belonging to other classes to be ignored during training. |

</font>            std::vector<font color='#5555FF'><</font>truth_instance<font color='#5555FF'>></font> temp; |

            std::<font color='#BB00BB'>copy_if</font><font face='Lucida Console'>(</font> |

                input.truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

                input.truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, |

                std::<font color='#BB00BB'>back_inserter</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>, |

                represents_desired_class |

            <font face='Lucida Console'>)</font>; |



            result.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_image<b>{</b> input.info, temp <b>}</b><font face='Lucida Console'>)</font>; |

        <b>}</b> |

    <b>}</b> |



    <font color='#0000FF'>return</font> result; |

<b>}</b> |
|
|
|
<font color='#009900'>// Ignore truth boxes that overlap too much, are too small, or have a large aspect ratio. |

// Works in place on truth_images. The size and aspect-ratio checks below only set |

// mmod_rect.ignore to true; boxes that are already flagged are skipped by them. |

</font><font color='#0000FF'><u>void</u></font> <b><a name='ignore_some_truth_boxes'></a>ignore_some_truth_boxes</b><font face='Lucida Console'>(</font>std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images<font face='Lucida Console'>)</font> |

<b>{</b> |

    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font> i : truth_images<font face='Lucida Console'>)</font> |

    <b>{</b> |

        <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_instances <font color='#5555FF'>=</font> i.truth_instances; |



        // NOTE(review): ignore_overlapped_boxes is defined outside this section; presumably it |

        // flags overlapping boxes as ignored instead of erasing them - confirm. |

        <font color='#BB00BB'>ignore_overlapped_boxes</font><font face='Lucida Console'>(</font>truth_instances, <font color='#BB00BB'>test_box_overlap</font><font face='Lucida Console'>(</font><font color='#979000'>0.90</font>, <font color='#979000'>0.95</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |



        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth : truth_instances<font face='Lucida Console'>)</font> |

        <b>{</b> |

            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>truth.mmod_rect.ignore<font face='Lucida Console'>)</font> |

                <font color='#0000FF'>continue</font>; |



            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> rect <font color='#5555FF'>=</font> truth.mmod_rect.rect; |



            // Too small: the box is ignored only when BOTH its width and its height are below |

            // the minimum; a box that is short in just one dimension passes this check. |

            constexpr <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> min_width  <font color='#5555FF'>=</font> <font color='#979000'>35</font>; |

            constexpr <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> min_height <font color='#5555FF'>=</font> <font color='#979000'>35</font>; |

            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> min_width <font color='#5555FF'>&</font><font color='#5555FF'>&</font> rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font> min_height<font face='Lucida Console'>)</font> |

            <b>{</b> |

                truth.mmod_rect.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |

                <font color='#0000FF'>continue</font>; |

            <b>}</b> |



            // Too elongated: wider than 3.0 times the height, or taller than 1.5 times the width. |

            constexpr <font color='#0000FF'><u>double</u></font> max_aspect_ratio_width_to_height <font color='#5555FF'>=</font> <font color='#979000'>3.0</font>; |

            constexpr <font color='#0000FF'><u>double</u></font> max_aspect_ratio_height_to_width <font color='#5555FF'>=</font> <font color='#979000'>1.5</font>; |

            <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> aspect_ratio_width_to_height <font color='#5555FF'>=</font> rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'><</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>></font><font face='Lucida Console'>(</font>rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |

            <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> aspect_ratio_height_to_width <font color='#5555FF'>=</font> <font color='#979000'>1.0</font> <font color='#5555FF'>/</font> aspect_ratio_width_to_height; |

            <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> is_aspect_ratio_too_large |

                <font color='#5555FF'>=</font> aspect_ratio_width_to_height <font color='#5555FF'>></font> max_aspect_ratio_width_to_height |

                <font color='#5555FF'>|</font><font color='#5555FF'>|</font> aspect_ratio_height_to_width <font color='#5555FF'>></font> max_aspect_ratio_height_to_width; |



            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>is_aspect_ratio_too_large<font face='Lucida Console'>)</font> |

                truth.mmod_rect.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |

        <b>}</b> |

    <b>}</b> |

<b>}</b> |
|
|
|
<font color='#009900'>// Filter images that have no (non-ignored) truth |

// Returns copies of the images that have at least one truth instance whose |

// mmod_rect.ignore is false. An image without any truth instances is dropped as well, |

// because std::all_of is true for an empty range. The input is not modified. |

</font>std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font> <b><a name='filter_images_with_no_truth'></a>filter_images_with_no_truth</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font><font color='#5555FF'>&</font> truth_images<font face='Lucida Console'>)</font> |

<b>{</b> |

    std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font> result; |



    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_image : truth_images<font face='Lucida Console'>)</font> |

    <b>{</b> |

        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> ignored <font color='#5555FF'>=</font> []<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&</font> truth<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>return</font> truth.mmod_rect.ignore; <b>}</b>; |

        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> truth_instances <font color='#5555FF'>=</font> truth_image.truth_instances; |

        // Keep the image unless every one of its truth instances is flagged as ignored. |

        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>std::<font color='#BB00BB'>all_of</font><font face='Lucida Console'>(</font>truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, ignored<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |

            result.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_image<font face='Lucida Console'>)</font>; |

    <b>}</b> |



    <font color='#0000FF'>return</font> result; |

<b>}</b> |
|
|
|
<font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font> |
|
<b>{</b> |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'><</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>To run this program you need a copy of the PASCAL VOC2012 dataset.</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>You call this program like this: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ...</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>\nSCANNING PASCAL VOC2012 DATASET\n</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_pascal_voc2012_train_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in entire dataset: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Didn't find the VOC2012 dataset. </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>return</font> <font color='#979000'>1</font>; |
|
<b>}</b> |
|
|
|
<font color='#009900'>// mini-batches smaller than the default can be used with GPUs having less memory |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> det_minibatch_size <font color='#5555FF'>=</font> argc <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>3</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>35</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> seg_minibatch_size <font color='#5555FF'>=</font> argc <font color='#5555FF'>></font><font color='#5555FF'>=</font> <font color='#979000'>4</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>3</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>100</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>det mini-batch size: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> det_minibatch_size <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>seg mini-batch size: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> seg_minibatch_size <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
std::vector<font color='#5555FF'><</font>std::string<font color='#5555FF'>></font> desired_classlabels; |
|
|
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> arg <font color='#5555FF'>=</font> <font color='#979000'>4</font>; arg <font color='#5555FF'><</font> argc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>arg<font face='Lucida Console'>)</font> |
|
desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>argv[arg]<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>desired_classlabels.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>bicycle</font>"<font face='Lucida Console'>)</font>; |
|
desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>car</font>"<font face='Lucida Console'>)</font>; |
|
desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>cat</font>"<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>desired classlabels:</font>"; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> desired_classlabel : desired_classlabels<font face='Lucida Console'>)</font> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> desired_classlabel; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#009900'>// extract the MMOD rects |
|
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Extracting all truth instances...</font>"; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> truth_instances <font color='#5555FF'>=</font> <font color='#BB00BB'>load_all_truth_instances</font><font face='Lucida Console'>(</font>listing<font face='Lucida Console'>)</font>; |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'> Done!</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>; |
|
|
|
std::vector<font color='#5555FF'><</font>truth_image<font color='#5555FF'>></font> original_truth_images; |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>, end <font color='#5555FF'>=</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; i <font color='#5555FF'><</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
original_truth_images.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_image<b>{</b> |
|
listing[i], truth_instances[i] |
|
<b>}</b><font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>auto</font> truth_images_filtered_by_class <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_based_on_classlabel</font><font face='Lucida Console'>(</font>original_truth_images, desired_classlabels<font face='Lucida Console'>)</font>; |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in dataset filtered by class: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> truth_images_filtered_by_class.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#BB00BB'>ignore_some_truth_boxes</font><font face='Lucida Console'>(</font>truth_images_filtered_by_class<font face='Lucida Console'>)</font>; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> truth_images <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_images_with_no_truth</font><font face='Lucida Console'>(</font>truth_images_filtered_by_class<font face='Lucida Console'>)</font>; |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>images in dataset after ignoring some truth boxes: </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
|
|
<font color='#009900'>// First train an object detector network (loss_mmod). |
|
</font> cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Training detector network:</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> det_net <font color='#5555FF'>=</font> <font color='#BB00BB'>train_detection_network</font><font face='Lucida Console'>(</font>truth_images, det_minibatch_size<font face='Lucida Console'>)</font>; |
|
|
|
<font color='#009900'>// Then train mask predictors (segmentation). |
|
</font> std::map<font color='#5555FF'><</font>std::string, seg_bnet_type<font color='#5555FF'>></font> seg_nets_by_class; |
|
|
|
<font color='#009900'>// This flag controls if a separate mask predictor is trained for each class. |
|
</font> <font color='#009900'>// Note that it would also be possible to train a separate mask predictor for |
|
</font> <font color='#009900'>// class groups, each containing somehow similar classes -- for example, one |
|
</font> <font color='#009900'>// mask predictor for cars and buses, another for cats and dogs, and so on. |
|
</font> constexpr <font color='#0000FF'><u>bool</u></font> separate_seg_net_for_each_class <font color='#5555FF'>=</font> <font color='#979000'>true</font>; |
|
|
|
<font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>separate_seg_net_for_each_class<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&</font> classlabel : desired_classlabels<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
<font color='#009900'>// Consider only the truth images belonging to this class. |
|
</font> <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> class_images <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_based_on_classlabel</font><font face='Lucida Console'>(</font>truth_images, <b>{</b> classlabel <b>}</b><font face='Lucida Console'>)</font>; |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> endl <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Training segmentation network for class </font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> classlabel <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>:</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
seg_nets_by_class[classlabel] <font color='#5555FF'>=</font> <font color='#BB00BB'>train_segmentation_network</font><font face='Lucida Console'>(</font>class_images, seg_minibatch_size, classlabel<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
<b>}</b> |
|
<font color='#0000FF'>else</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Training a single segmentation network:</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
seg_nets_by_class["<font color='#CC0000'></font>"] <font color='#5555FF'>=</font> <font color='#BB00BB'>train_segmentation_network</font><font face='Lucida Console'>(</font>truth_images, seg_minibatch_size, "<font color='#CC0000'></font>"<font face='Lucida Console'>)</font>; |
|
<b>}</b> |
|
|
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> "<font color='#CC0000'>Saving networks</font>" <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>instance_segmentation_net_filename<font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> det_net <font color='#5555FF'><</font><font color='#5555FF'><</font> seg_nets_by_class; |
|
<b>}</b> |
|
|
|
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&</font> e<font face='Lucida Console'>)</font> |
|
<b>{</b> |
|
cout <font color='#5555FF'><</font><font color='#5555FF'><</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'><</font><font color='#5555FF'><</font> endl; |
|
<b>}</b> |
|
|
|
|
|
</pre></body></html> |