Aging_MouthReplace / dlibs /docs /dnn_instance_segmentation_train_ex.cpp.html
AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
<html><!-- Created using the cpp_pretty_printer from the dlib C++ library. See http://dlib.net for updates. --><head><title>dlib C++ Library - dnn_instance_segmentation_train_ex.cpp</title></head><body bgcolor='white'><pre>
<font color='#009900'>// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
</font><font color='#009900'>/*
This example shows how to train an instance segmentation net using the PASCAL VOC2012
dataset. For an introduction to what segmentation is, see the accompanying header file
dnn_instance_segmentation_ex.h.
Instructions on how to run the example:
1. Download the PASCAL VOC2012 data, and untar it somewhere.
http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
2. Build the dnn_instance_segmentation_train_ex example program.
3. Run:
./dnn_instance_segmentation_train_ex /path/to/VOC2012
4. Wait while the network is being trained.
5. Build the dnn_instance_segmentation_ex example program.
6. Run:
./dnn_instance_segmentation_ex /path/to/VOC2012-or-other-images
It would be a good idea to become familiar with dlib's DNN tooling before reading this
example. So you should read <a href="dnn_introduction_ex.cpp.html">dnn_introduction_ex.cpp</a>, <a href="dnn_introduction2_ex.cpp.html">dnn_introduction2_ex.cpp</a>,
and <a href="dnn_semantic_segmentation_train_ex.cpp.html">dnn_semantic_segmentation_train_ex.cpp</a> before reading this example program.
*/</font>
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='dnn_instance_segmentation_ex.h.html'>dnn_instance_segmentation_ex.h</a>"
<font color='#0000FF'>#include</font> "<a style='text-decoration:none' href='pascal_voc_2012.h.html'>pascal_voc_2012.h</a>"
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iostream<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>data_io.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>image_transforms.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>dlib<font color='#5555FF'>/</font>dir_nav.h<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>iterator<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>thread<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#if</font> __cplusplus <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font face='Lucida Console'>(</font>defined<font face='Lucida Console'>(</font>_MSVC_LANG<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> _MSVC_LANG <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L<font face='Lucida Console'>)</font>
<font color='#0000FF'>#include</font> <font color='#5555FF'>&lt;</font>execution<font color='#5555FF'>&gt;</font>
<font color='#0000FF'>#endif</font> <font color='#009900'>// __cplusplus &gt;= 201703L
</font>
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> std;
<font color='#0000FF'>using</font> <font color='#0000FF'>namespace</font> dlib;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// A single training sample for detection. A mini-batch comprises many of these.
//   input_image - the image chip that the data loader's random_cropper wrote for this sample
//                 (the loader also runs disturb_colors on it).
//   mmod_rects  - the object boxes that the same cropper call wrote for that chip.
</font><font color='#0000FF'>struct</font> <b><a name='det_training_sample'></a>det_training_sample</b>
<b>{</b>
matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;
std::vector<font color='#5555FF'>&lt;</font>dlib::mmod_rect<font color='#5555FF'>&gt;</font> mmod_rects;
<b>}</b>;
<font color='#009900'>// A single training sample for segmentation. A mini-batch comprises many of these.
// It pairs an RGB input image with a float label image.
// NOTE(review): presumably label_image has the same dimensions as input_image; the code
// that fills this struct is outside this part of the file, so confirm there.
</font><font color='#0000FF'>struct</font> <b><a name='seg_training_sample'></a>seg_training_sample</b>
<b>{</b>
matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;
matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> label_image; <font color='#009900'>// The ground-truth label of each pixel. (+1 or -1)
</font><b>}</b>;
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Tells whether a label-image pixel belongs to an object instance.
// Two colors never do: black, which marks the background, and the cream-colored
// `void' label, which is used in border regions and to mask difficult objects.
// Returns false for those two colors and true for every other color.</font>
<font color='#0000FF'><u>bool</u></font> <b><a name='is_instance_pixel'></a>is_instance_pixel</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dlib::rgb_pixel<font color='#5555FF'>&amp;</font> rgb_label<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> dlib::rgb_pixel <font color='#BB00BB'>background</font><font face='Lucida Console'>(</font><font color='#979000'>0</font>, <font color='#979000'>0</font>, <font color='#979000'>0</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> dlib::rgb_pixel <font color='#BB00BB'>void_label</font><font face='Lucida Console'>(</font><font color='#979000'>224</font>, <font color='#979000'>224</font>, <font color='#979000'>192</font><font face='Lucida Console'>)</font>;

    <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> is_excluded <font color='#5555FF'>=</font> rgb_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> background <font color='#5555FF'>|</font><font color='#5555FF'>|</font> rgb_label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> void_label;
    <font color='#0000FF'>return</font> <font color='#5555FF'>!</font>is_excluded;
<b>}</b>
<font color='#009900'>// std::hash specialization for dlib::rgb_pixel, so that the pixel type can be used
// as the key of an unordered container. The hash places the red, green and blue
// members at bit offsets 16, 8 and 0 of one integer and ORs them together.</font>
<font color='#0000FF'>namespace</font> std <b>{</b>
    <font color='#0000FF'>template</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&gt;</font>
    <font color='#0000FF'>struct</font> <b><a name='hash'></a>hash</b><font color='#5555FF'>&lt;</font>dlib::rgb_pixel<font color='#5555FF'>&gt;</font>
    <b>{</b>
        std::<font color='#0000FF'><u>size_t</u></font> <b><a name='operator'></a>operator</b><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> dlib::rgb_pixel<font color='#5555FF'>&amp;</font> p<font face='Lucida Console'>)</font> <font color='#0000FF'>const</font>
        <b>{</b>
            <font color='#0000FF'>const</font> uint32_t red_bits <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>uint32_t<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>p.red<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#979000'>16</font>;
            <font color='#0000FF'>const</font> uint32_t green_bits <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>uint32_t<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>p.green<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> <font color='#979000'>8</font>;
            <font color='#0000FF'>const</font> uint32_t blue_bits <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font>uint32_t<font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>p.blue<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>return</font> red_bits <font color='#5555FF'>|</font> green_bits <font color='#5555FF'>|</font> blue_bits;
        <b>}</b>
    <b>}</b>;
<b>}</b>
<font color='#009900'>// One ground-truth object instance.
//   rgb_label - the color that identifies the instance in the instance label image.
//   mmod_rect - the bounding box of the instance's pixels; its label member holds the
//               class label of the instance.
</font><font color='#0000FF'>struct</font> <b><a name='truth_instance'></a>truth_instance</b>
<b>{</b>
dlib::rgb_pixel rgb_label;
dlib::mmod_rect mmod_rect;
<b>}</b>;
<font color='#009900'>// Scans two label images of equal size (this is asserted) and returns one truth_instance
// per distinct instance color found in instance_label_image. Each instance carries the
// bounding box of its pixels and the class label that find_voc2012_class reports for the
// pixels at the same positions in class_label_image. Pixels rejected by is_instance_pixel
// are skipped. The result follows the iteration order of an unordered_map, so callers
// must not rely on the order of the returned instances.</font>
std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font> <b><a name='rgb_label_images_to_truth_instances'></a>rgb_label_images_to_truth_instances</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> dlib::matrix<font color='#5555FF'>&lt;</font>dlib::rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> instance_label_image,
    <font color='#0000FF'>const</font> dlib::matrix<font color='#5555FF'>&lt;</font>dlib::rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> class_label_image
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>instance_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> class_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>instance_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> class_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> num_rows <font color='#5555FF'>=</font> instance_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> num_cols <font color='#5555FF'>=</font> instance_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Box and class label gathered so far, keyed by instance color.</font>
    std::unordered_map<font color='#5555FF'>&lt;</font>dlib::rgb_pixel, mmod_rect<font color='#5555FF'>&gt;</font> boxes_by_color;

    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> row <font color='#5555FF'>=</font> <font color='#979000'>0</font>; row <font color='#5555FF'>&lt;</font> num_rows; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>row<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> col <font color='#5555FF'>=</font> <font color='#979000'>0</font>; col <font color='#5555FF'>&lt;</font> num_cols; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>col<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> instance_color <font color='#5555FF'>=</font> <font color='#BB00BB'>instance_label_image</font><font face='Lucida Console'>(</font>row, col<font face='Lucida Console'>)</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>is_instance_pixel</font><font face='Lucida Console'>(</font>instance_color<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> class_color <font color='#5555FF'>=</font> <font color='#BB00BB'>class_label_image</font><font face='Lucida Console'>(</font>row, col<font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> Voc2012class<font color='#5555FF'>&amp;</font> pixel_class <font color='#5555FF'>=</font> <font color='#BB00BB'>find_voc2012_class</font><font face='Lucida Console'>(</font>class_color<font face='Lucida Console'>)</font>;

                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> known <font color='#5555FF'>=</font> boxes_by_color.<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font>instance_color<font face='Lucida Console'>)</font>;
                <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>known <font color='#5555FF'>!</font><font color='#5555FF'>=</font> boxes_by_color.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
                <b>{</b>
                    <font color='#009900'>// Seen before: grow the box so that it also covers this pixel. Rows are
                    // visited in increasing order, so the top edge set by the first pixel
                    // of the instance never has to move.</font>
                    <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> box <font color='#5555FF'>=</font> known<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>second.rect;
                    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>col <font color='#5555FF'>&lt;</font> box.<font color='#BB00BB'>left</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
                        box.<font color='#BB00BB'>set_left</font><font face='Lucida Console'>(</font>col<font face='Lucida Console'>)</font>;
                    <font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>col <font color='#5555FF'>&gt;</font> box.<font color='#BB00BB'>right</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
                        box.<font color='#BB00BB'>set_right</font><font face='Lucida Console'>(</font>col<font face='Lucida Console'>)</font>;
                    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>row <font color='#5555FF'>&gt;</font> box.<font color='#BB00BB'>bottom</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
                        box.<font color='#BB00BB'>set_bottom</font><font face='Lucida Console'>(</font>row<font face='Lucida Console'>)</font>;

                    <font color='#009900'>// All pixels of one instance are expected to map to the same class.</font>
                    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>known<font color='#5555FF'>-</font><font color='#5555FF'>&gt;</font>second.label <font color='#5555FF'>=</font><font color='#5555FF'>=</font> pixel_class.classlabel<font face='Lucida Console'>)</font>;
                <b>}</b>
                <font color='#0000FF'>else</font>
                <b>{</b>
                    <font color='#009900'>// First pixel of a new instance: start from a one-pixel box.</font>
                    mmod_rect<font color='#5555FF'>&amp;</font> fresh_entry <font color='#5555FF'>=</font> boxes_by_color[instance_color];
                    fresh_entry <font color='#5555FF'>=</font> <font color='#BB00BB'>rectangle</font><font face='Lucida Console'>(</font>col, row, col, row<font face='Lucida Console'>)</font>;
                    fresh_entry.label <font color='#5555FF'>=</font> pixel_class.classlabel;
                <b>}</b>
            <b>}</b>
        <b>}</b>
    <b>}</b>

    <font color='#009900'>// Flatten the map into the vector handed back to the caller.</font>
    std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font> instances;
    instances.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font>boxes_by_color.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> entry : boxes_by_color<font face='Lucida Console'>)</font>
        instances.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_instance<b>{</b> entry.first, entry.second <b>}</b><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> instances;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Ground truth for one image.
//   info            - the image_info record of the image; the detection data loader reads
//                     info.image_filename from it to load the picture.
//   truth_instances - the truth instances that go with that image.
</font><font color='#0000FF'>struct</font> <b><a name='truth_image'></a>truth_image</b>
<b>{</b>
image_info info;
std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font> truth_instances;
<b>}</b>;
<font color='#009900'>// Returns the mmod_rect member of every truth instance, in the order of the input.
// An empty input yields an empty vector.</font>
std::vector<font color='#5555FF'>&lt;</font>mmod_rect<font color='#5555FF'>&gt;</font> <b><a name='extract_mmod_rects'></a>extract_mmod_rects</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_instances
<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'>&lt;</font>mmod_rect<font color='#5555FF'>&gt;</font> rects;
    rects.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font>truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&amp;</font> instance : truth_instances<font face='Lucida Console'>)</font>
        rects.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>instance.mmod_rect<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> rects;
<b>}</b>
<font color='#009900'>// Builds one list of mmod_rects per truth image by calling extract_mmod_rects on the
// image's truth instances. Element i of the result corresponds to truth_images[i].</font>
std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>mmod_rect<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <b><a name='extract_mmod_rect_vectors'></a>extract_mmod_rect_vectors</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images
<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>mmod_rect<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> rects_per_image;
    rects_per_image.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font>truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_image<font color='#5555FF'>&amp;</font> image : truth_images<font face='Lucida Console'>)</font>
        rects_per_image.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font><font color='#BB00BB'>extract_mmod_rects</font><font face='Lucida Console'>(</font>image.truth_instances<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> rects_per_image;
<b>}</b>
<font color='#009900'>// Trains the detection network on random crops of the given images and returns it
// after calling clean() on it.
//   truth_images       - the images (loaded by file name) with their truth instances.
//   det_minibatch_size - number of crops handed to the trainer in each training step.
// Four loader threads produce the crops and push them through a dlib::pipe. They are
// stopped and joined before this function returns, and also before a std::exception
// thrown during training is passed on to the caller.
// The trainer is given the synchronization file pascal_voc2012_det_trainer_state_file.dat
// with an interval of 10 minutes, and training runs until the learning rate drops
// below 1e-4.</font>
det_bnet_type <b><a name='train_detection_network'></a>train_detection_network</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images,
    <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> det_minibatch_size
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> min_detector_window_overlap_iou <font color='#5555FF'>=</font> <font color='#979000'>0.65</font>;

    <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> target_size <font color='#5555FF'>=</font> <font color='#979000'>70</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> min_target_size <font color='#5555FF'>=</font> <font color='#979000'>30</font>;

    mmod_options <font color='#BB00BB'>options</font><font face='Lucida Console'>(</font>
        <font color='#BB00BB'>extract_mmod_rect_vectors</font><font face='Lucida Console'>(</font>truth_images<font face='Lucida Console'>)</font>,
        target_size, min_target_size,
        min_detector_window_overlap_iou
    <font face='Lucida Console'>)</font>;

    options.overlaps_ignore <font color='#5555FF'>=</font> <font color='#BB00BB'>test_box_overlap</font><font face='Lucida Console'>(</font><font color='#979000'>0.5</font>, <font color='#979000'>0.9</font><font face='Lucida Console'>)</font>;

    det_bnet_type <font color='#BB00BB'>det_net</font><font face='Lucida Console'>(</font>options<font face='Lucida Console'>)</font>;

    <font color='#009900'>// One output filter per detector window chosen by mmod_options.</font>
    det_net.<font color='#BB00BB'>subnet</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>layer_details</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>.<font color='#BB00BB'>set_num_filters</font><font face='Lucida Console'>(</font>options.detector_windows.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Queue through which the loader threads hand samples to the training loop.</font>
    dlib::pipe<font color='#5555FF'>&lt;</font>det_training_sample<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Body of a loader thread. The seed argument differs per thread so that the
    // threads do not all draw the same random sequence.</font>
    <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>truth_images, target_size, min_target_size]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
    <b>{</b>
        dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> seed<font face='Lucida Console'>)</font>;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;

        random_cropper cropper;
        <font color='#009900'>// Offset the cropper's seed per thread too, exactly like rnd above. With a plain
        // time(0), all loaders started within the same second would get identically
        // seeded croppers.</font>
        cropper.<font color='#BB00BB'>set_seed</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> seed<font face='Lucida Console'>)</font>;
        cropper.<font color='#BB00BB'>set_chip_dims</font><font face='Lucida Console'>(</font><font color='#979000'>350</font>, <font color='#979000'>350</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// Usually you want to give the cropper whatever min sizes you passed to the
        // mmod_options constructor, or very slightly smaller sizes, which is what we do here.</font>
        cropper.<font color='#BB00BB'>set_min_object_size</font><font face='Lucida Console'>(</font>target_size <font color='#5555FF'>-</font> <font color='#979000'>2</font>, min_target_size <font color='#5555FF'>-</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font>;
        cropper.<font color='#BB00BB'>set_max_rotation_degrees</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>;

        det_training_sample temp;

        <font color='#009900'>// Keep producing samples until the pipe gets disabled by stop_data_loaders.</font>
        <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// Pick a random input image.</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_index <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> truth_image <font color='#5555FF'>=</font> truth_images[random_index];

            <font color='#009900'>// Load the input image.</font>
            <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, truth_image.info.image_filename<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Get a random crop of the input.</font>
            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> mmod_rects <font color='#5555FF'>=</font> <font color='#BB00BB'>extract_mmod_rects</font><font face='Lucida Console'>(</font>truth_image.truth_instances<font face='Lucida Console'>)</font>;
            <font color='#BB00BB'>cropper</font><font face='Lucida Console'>(</font>input_image, mmod_rects, temp.input_image, temp.mmod_rects<font face='Lucida Console'>)</font>;

            <font color='#BB00BB'>disturb_colors</font><font face='Lucida Console'>(</font>temp.input_image, rnd<font face='Lucida Console'>)</font>;

            <font color='#009900'>// Push the result to be used by the trainer.</font>
            data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>;

    std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Disables the pipe, which ends the loaders' while loops, then joins all four threads.</font>
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> stop_data_loaders <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <b>{</b>
        data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <b>}</b>;

    dnn_trainer<font color='#5555FF'>&lt;</font>det_bnet_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>det_trainer</font><font face='Lucida Console'>(</font>det_net, <font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#0000FF'>try</font>
    <b>{</b>
        det_trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        det_trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
        det_trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>pascal_voc2012_det_trainer_state_file.dat</font>", std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
        det_trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>5000</font><font face='Lucida Console'>)</font>;

        <font color='#009900'>// Output training parameters.</font>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> det_trainer <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

        std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> samples;
        std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>mmod_rect<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> labels;

        <font color='#009900'>// The main training loop. Keep making mini-batches and giving them to the trainer.
        // We will run until the learning rate becomes small enough.</font>
        <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>det_trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>
        <b>{</b>
            samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

            <font color='#009900'>// make a mini-batch</font>
            det_training_sample temp;
            <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> det_minibatch_size<font face='Lucida Console'>)</font>
            <b>{</b>
                data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;

                samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.mmod_rects<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            <b>}</b>

            det_trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// The loader threads must be joined before the exception leaves this function.</font>
        <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>throw</font>;
    <b>}</b>

    <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
    // moving on.</font>
    <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// also wait for threaded processing to stop in the trainer.</font>
    det_trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    det_net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;

    <font color='#0000FF'>return</font> det_net;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Converts an RGB instance-label image into the per-pixel target for a single instance:
</font><font color='#009900'>//   +1 where the pixel equals rgb_label,
</font><font color='#009900'>//    0 where the pixel equals rgb_pixel(224, 224, 192) (presumably the PASCAL VOC border
</font><font color='#009900'>//      color, meant to be ignored during training - confirm against pascal_voc_2012.h),
</font><font color='#009900'>//   -1 everywhere else.
</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> <b><a name='keep_only_current_instance'></a>keep_only_current_instance</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> rgb_label_image, <font color='#0000FF'>const</font> rgb_pixel rgb_label<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> dlib::rgb_pixel <font color='#BB00BB'>zero_color</font><font face='Lucida Console'>(</font><font color='#979000'>224</font>, <font color='#979000'>224</font>, <font color='#979000'>192</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num_rows <font color='#5555FF'>=</font> rgb_label_image.<font color='#BB00BB'>nr</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>long</u></font> num_cols <font color='#5555FF'>=</font> rgb_label_image.<font color='#BB00BB'>nc</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>target</font><font face='Lucida Console'>(</font>num_rows, num_cols<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> row <font color='#5555FF'>=</font> <font color='#979000'>0</font>; row <font color='#5555FF'>&lt;</font> num_rows; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>row<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>long</u></font> col <font color='#5555FF'>=</font> <font color='#979000'>0</font>; col <font color='#5555FF'>&lt;</font> num_cols; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>col<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>const</font> rgb_pixel<font color='#5555FF'>&amp;</font> pixel <font color='#5555FF'>=</font> <font color='#BB00BB'>rgb_label_image</font><font face='Lucida Console'>(</font>row, col<font face='Lucida Console'>)</font>;
            <font color='#009900'>// Start from "some other instance or background" and refine from there.
</font>            <font color='#0000FF'><u>float</u></font> value <font color='#5555FF'>=</font> <font color='#5555FF'>-</font><font color='#979000'>1</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>pixel <font color='#5555FF'>=</font><font color='#5555FF'>=</font> rgb_label<font face='Lucida Console'>)</font>
                value <font color='#5555FF'>=</font> <font color='#5555FF'>+</font><font color='#979000'>1</font>;
            <font color='#0000FF'>else</font> <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>pixel <font color='#5555FF'>=</font><font color='#5555FF'>=</font> zero_color<font face='Lucida Console'>)</font>
                value <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
            <font color='#BB00BB'>target</font><font face='Lucida Console'>(</font>row, col<font face='Lucida Console'>)</font> <font color='#5555FF'>=</font> value;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>return</font> target;
<b>}</b>
<font color='#009900'>// Trains the segmentation network on random crops taken around ground-truth instances.
</font><font color='#009900'>//   truth_images       - images (with their truth instances) to sample training crops from
</font><font color='#009900'>//   seg_minibatch_size - number of crops per mini-batch given to the trainer
</font><font color='#009900'>//   classlabel         - only used here to derive the synchronization file name
</font><font color='#009900'>// Returns the network after training has finished and clean() has been called on it.
</font>seg_bnet_type <b><a name='train_segmentation_network'></a>train_segmentation_network</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images,
    <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> seg_minibatch_size,
    <font color='#0000FF'>const</font> std::string<font color='#5555FF'>&amp;</font> classlabel
<font face='Lucida Console'>)</font>
<b>{</b>
    seg_bnet_type seg_net;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> initial_learning_rate <font color='#5555FF'>=</font> <font color='#979000'>0.1</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> weight_decay <font color='#5555FF'>=</font> <font color='#979000'>0.0001</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> momentum <font color='#5555FF'>=</font> <font color='#979000'>0.9</font>;
    <font color='#009900'>// One trainer state file per class label, so that separately trained networks do not
</font>    <font color='#009900'>// share (and overwrite) the same synchronization file.
</font>    <font color='#0000FF'>const</font> std::string synchronization_file_name
        <font color='#5555FF'>=</font> "<font color='#CC0000'>pascal_voc2012_seg_trainer_state_file</font>"
        <font color='#5555FF'>+</font> <font face='Lucida Console'>(</font>classlabel.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> ? "<font color='#CC0000'></font>" : <font face='Lucida Console'>(</font>"<font color='#CC0000'>_</font>" <font color='#5555FF'>+</font> classlabel<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <font color='#5555FF'>+</font> "<font color='#CC0000'>.dat</font>";
    dnn_trainer<font color='#5555FF'>&lt;</font>seg_bnet_type<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>seg_trainer</font><font face='Lucida Console'>(</font>seg_net, <font color='#BB00BB'>sgd</font><font face='Lucida Console'>(</font>weight_decay, momentum<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    seg_trainer.<font color='#BB00BB'>be_verbose</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    seg_trainer.<font color='#BB00BB'>set_learning_rate</font><font face='Lucida Console'>(</font>initial_learning_rate<font face='Lucida Console'>)</font>;
    seg_trainer.<font color='#BB00BB'>set_synchronization_file</font><font face='Lucida Console'>(</font>synchronization_file_name, std::chrono::<font color='#BB00BB'>minutes</font><font face='Lucida Console'>(</font><font color='#979000'>10</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    seg_trainer.<font color='#BB00BB'>set_iterations_without_progress_threshold</font><font face='Lucida Console'>(</font><font color='#979000'>2000</font><font face='Lucida Console'>)</font>;
    <font color='#BB00BB'>set_all_bn_running_stats_window_sizes</font><font face='Lucida Console'>(</font>seg_net, <font color='#979000'>1000</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Output training parameters.
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> seg_trainer <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> samples;
    std::vector<font color='#5555FF'>&lt;</font>matrix<font color='#5555FF'>&lt;</font><font color='#0000FF'><u>float</u></font><font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> labels;
    <font color='#009900'>// Start a bunch of threads that read images from disk and pull out random crops.  It's
</font>    <font color='#009900'>// important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
</font>    <font color='#009900'>// threads for this kind of data preparation helps us do that.  Each thread puts the
</font>    <font color='#009900'>// crops into the data queue.
</font>    dlib::pipe<font color='#5555FF'>&lt;</font>seg_training_sample<font color='#5555FF'>&gt;</font> <font color='#BB00BB'>data</font><font face='Lucida Console'>(</font><font color='#979000'>200</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>auto</font> f <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>data, <font color='#5555FF'>&amp;</font>truth_images]<font face='Lucida Console'>(</font>time_t seed<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// Each loader thread passes a different seed offset, so the threads do not all
</font>        <font color='#009900'>// produce the same sequence of random crops.
</font>        dlib::rand <font color='#BB00BB'>rnd</font><font face='Lucida Console'>(</font><font color='#BB00BB'>time</font><font face='Lucida Console'>(</font><font color='#979000'>0</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> seed<font face='Lucida Console'>)</font>;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> input_image;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> rgb_label_image;
        matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> rgb_label_chip;
        seg_training_sample temp;
        <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>data.<font color='#BB00BB'>is_enabled</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// Pick a random input image.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_index <font color='#5555FF'>=</font> rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> truth_image <font color='#5555FF'>=</font> truth_images[random_index];
            <font color='#009900'>// Bind by reference: copying the whole instance list for every generated sample
</font>            <font color='#009900'>// is unnecessary.  truth_images outlives the loader threads, because they are
</font>            <font color='#009900'>// joined by stop_data_loaders() before this function returns.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> image_truths <font color='#5555FF'>=</font> truth_image.truth_instances;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#5555FF'>!</font>image_truths.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#0000FF'>const</font> image_info<font color='#5555FF'>&amp;</font> info <font color='#5555FF'>=</font> truth_image.info;
                <font color='#009900'>// Load the input image.
</font>                <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>input_image, info.image_filename<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Load the ground-truth (RGB) instance labels.
</font>                <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>rgb_label_image, info.instance_label_filename<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Pick a random training instance.
</font>                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> truth_instance <font color='#5555FF'>=</font> image_truths[rnd.<font color='#BB00BB'>get_random_32bit_number</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>%</font> image_truths.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>];
                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> truth_rect <font color='#5555FF'>=</font> truth_instance.mmod_rect.rect;
                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> cropping_rect <font color='#5555FF'>=</font> <font color='#BB00BB'>get_cropping_rect</font><font face='Lucida Console'>(</font>truth_rect<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Pick a random crop around the instance: shift the cropping rectangle by up
</font>                <font color='#009900'>// to a tenth of the truth box size in each direction.  (The "+ 1" presumably
</font>                <font color='#009900'>// compensates for an exclusive upper bound of get_integer_in_range - confirm
</font>                <font color='#009900'>// in the dlib::rand documentation.)
</font>                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> max_x_translate_amount <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>truth_rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> max_y_translate_amount <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>long</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>truth_rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#979000'>10.0</font><font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> random_translate <font color='#5555FF'>=</font> <font color='#BB00BB'>point</font><font face='Lucida Console'>(</font>
                    rnd.<font color='#BB00BB'>get_integer_in_range</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font>max_x_translate_amount, max_x_translate_amount <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>,
                    rnd.<font color='#BB00BB'>get_integer_in_range</font><font face='Lucida Console'>(</font><font color='#5555FF'>-</font>max_y_translate_amount, max_y_translate_amount <font color='#5555FF'>+</font> <font color='#979000'>1</font><font face='Lucida Console'>)</font>
                <font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> rectangle <font color='#BB00BB'>random_rect</font><font face='Lucida Console'>(</font>
                    cropping_rect.<font color='#BB00BB'>left</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>   <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                    cropping_rect.<font color='#BB00BB'>top</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>    <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>y</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                    cropping_rect.<font color='#BB00BB'>right</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>  <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>x</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
                    cropping_rect.<font color='#BB00BB'>bottom</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>+</font> random_translate.<font color='#BB00BB'>y</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
                <font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> chip_details <font color='#BB00BB'>chip_details</font><font face='Lucida Console'>(</font>random_rect, <font color='#BB00BB'>chip_dims</font><font face='Lucida Console'>(</font>seg_dim, seg_dim<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                <font color='#009900'>// Crop the input image.
</font>                <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>input_image, chip_details, temp.input_image, <font color='#BB00BB'>interpolate_bilinear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                <font color='#BB00BB'>disturb_colors</font><font face='Lucida Console'>(</font>temp.input_image, rnd<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Crop the labels correspondingly. However, note that here bilinear
</font>                <font color='#009900'>// interpolation would make absolutely no sense - you wouldn't say that
</font>                <font color='#009900'>// a bicycle is half-way between an aeroplane and a bird, would you?
</font>                <font color='#BB00BB'>extract_image_chip</font><font face='Lucida Console'>(</font>rgb_label_image, chip_details, rgb_label_chip, <font color='#BB00BB'>interpolate_nearest_neighbor</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                <font color='#009900'>// Clear pixels not related to the current instance.
</font>                temp.label_image <font color='#5555FF'>=</font> <font color='#BB00BB'>keep_only_current_instance</font><font face='Lucida Console'>(</font>rgb_label_chip, truth_instance.rgb_label<font face='Lucida Console'>)</font>;
                <font color='#009900'>// Push the result to be used by the trainer.
</font>                data.<font color='#BB00BB'>enqueue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
            <b>}</b>
            <font color='#0000FF'>else</font>
            <b>{</b>
                <font color='#009900'>// TODO: use background samples as well
</font>            <b>}</b>
        <b>}</b>
    <b>}</b>;
    std::thread <font color='#BB00BB'>data_loader1</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>1</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader2</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>2</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader3</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>3</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    std::thread <font color='#BB00BB'>data_loader4</font><font face='Lucida Console'>(</font>[f]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <b>{</b> <font color='#BB00BB'>f</font><font face='Lucida Console'>(</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>; <b>}</b><font face='Lucida Console'>)</font>;
    <font color='#009900'>// Disabling the pipe makes the loaders' is_enabled() check fail, which ends their loops;
</font>    <font color='#009900'>// after that every loader thread is joined.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> stop_data_loaders <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>]<font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>
    <b>{</b>
        data.<font color='#BB00BB'>disable</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader1.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader2.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader3.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        data_loader4.<font color='#BB00BB'>join</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <b>}</b>;
    <font color='#0000FF'>try</font>
    <b>{</b>
        <font color='#009900'>// The main training loop.  Keep making mini-batches and giving them to the trainer.
</font>        <font color='#009900'>// We will run until the learning rate has dropped below 1e-4.
</font>        <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>seg_trainer.<font color='#BB00BB'>get_learning_rate</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>1e</font><font color='#5555FF'>-</font><font color='#979000'>4</font><font face='Lucida Console'>)</font>
        <b>{</b>
            samples.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            labels.<font color='#BB00BB'>clear</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
            <font color='#009900'>// make a mini-batch
</font>            seg_training_sample temp;
            <font color='#0000FF'>while</font> <font face='Lucida Console'>(</font>samples.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> seg_minibatch_size<font face='Lucida Console'>)</font>
            <b>{</b>
                data.<font color='#BB00BB'>dequeue</font><font face='Lucida Console'>(</font>temp<font face='Lucida Console'>)</font>;
                samples.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.input_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                labels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>std::<font color='#BB00BB'>move</font><font face='Lucida Console'>(</font>temp.label_image<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
            <b>}</b>
            seg_trainer.<font color='#BB00BB'>train_one_step</font><font face='Lucida Console'>(</font>samples, labels<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>catch</font> <font face='Lucida Console'>(</font>std::exception<font color='#5555FF'>&amp;</font><font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// The loader threads must be joined before their std::thread objects are destroyed,
</font>        <font color='#009900'>// so stop them first and only then let the exception propagate.
</font>        <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>throw</font>;
    <b>}</b>
    <font color='#009900'>// Training done, tell threads to stop and make sure to wait for them to finish before
</font>    <font color='#009900'>// moving on.
</font>    <font color='#BB00BB'>stop_data_loaders</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#009900'>// also wait for threaded processing to stop in the trainer.
</font>    seg_trainer.<font color='#BB00BB'>get_net</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    seg_net.<font color='#BB00BB'>clean</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> seg_net;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#0000FF'><u>int</u></font> <b><a name='ignore_overlapped_boxes'></a>ignore_overlapped_boxes</b><font face='Lucida Console'>(</font>
    std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_instances,
    <font color='#0000FF'>const</font> test_box_overlap<font color='#5555FF'>&amp;</font> overlaps
<font face='Lucida Console'>)</font>
<font color='#009900'>/*!
    ensures
        - Whenever two not yet ignored rectangles in truth_instances overlap, according
          to overlaps(), we set the smallest box to ignore.
        - A box that has been set to ignore takes no part in any further comparison, so
          it is neither counted twice nor able to cause other boxes to be ignored.
        - returns the number of newly ignored boxes.
!*/</font>
<b>{</b>
    <font color='#0000FF'><u>int</u></font> num_ignored <font color='#5555FF'>=</font> <font color='#979000'>0</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>, end <font color='#5555FF'>=</font> truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; i <font color='#5555FF'>&lt;</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> box_i <font color='#5555FF'>=</font> truth_instances[i].mmod_rect;
        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>box_i.ignore<font face='Lucida Console'>)</font>
            <font color='#0000FF'>continue</font>;
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> j <font color='#5555FF'>=</font> i<font color='#5555FF'>+</font><font color='#979000'>1</font>; j <font color='#5555FF'>&lt;</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>j<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> box_j <font color='#5555FF'>=</font> truth_instances[j].mmod_rect;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>box_j.ignore<font face='Lucida Console'>)</font>
                <font color='#0000FF'>continue</font>;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font><font color='#BB00BB'>overlaps</font><font face='Lucida Console'>(</font>box_i, box_j<font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <b>{</b>
                <font color='#5555FF'>+</font><font color='#5555FF'>+</font>num_ignored;
                <font color='#0000FF'>if</font><font face='Lucida Console'>(</font>box_i.rect.<font color='#BB00BB'>area</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> box_j.rect.<font color='#BB00BB'>area</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
                <b>{</b>
                    box_i.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
                    <font color='#009900'>// box_i is ignored from here on.  Stop comparing it against the remaining
</font>                    <font color='#009900'>// boxes: continuing would count it again in num_ignored and would let an
</font>                    <font color='#009900'>// already ignored box cause other boxes to be ignored as well.
</font>                    <font color='#0000FF'>break</font>;
                <b>}</b>
                box_j.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
            <b>}</b>
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>return</font> num_ignored;
<b>}</b>
<font color='#009900'>// Loads the instance-label and class-label images referenced by info and converts the
</font><font color='#009900'>// pair into truth instances via rgb_label_images_to_truth_instances().
</font>std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font> <b><a name='load_truth_instances'></a>load_truth_instances</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> image_info<font color='#5555FF'>&amp;</font> info<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// Per-pixel instance colors.
</font>    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> instance_colors;
    <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>instance_colors, info.instance_label_filename<font face='Lucida Console'>)</font>;
    <font color='#009900'>// Per-pixel class colors.
</font>    matrix<font color='#5555FF'>&lt;</font>rgb_pixel<font color='#5555FF'>&gt;</font> class_colors;
    <font color='#BB00BB'>load_image</font><font face='Lucida Console'>(</font>class_colors, info.class_label_filename<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> <font color='#BB00BB'>rgb_label_images_to_truth_instances</font><font face='Lucida Console'>(</font>instance_colors, class_colors<font face='Lucida Console'>)</font>;
<b>}</b>
<font color='#009900'>// Calls load_truth_instances() for every entry of listing.  Element i of the result holds
</font><font color='#009900'>// the truth instances of listing[i].
</font>std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <b><a name='load_all_truth_instances'></a>load_all_truth_instances</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>image_info<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> listing<font face='Lucida Console'>)</font>
<b>{</b>
    std::vector<font color='#5555FF'>&lt;</font>std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font><font color='#5555FF'>&gt;</font> <font color='#BB00BB'>per_image_instances</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
<font color='#0000FF'>#if</font> __cplusplus <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L <font color='#5555FF'>|</font><font color='#5555FF'>|</font> <font face='Lucida Console'>(</font>defined<font face='Lucida Console'>(</font>_MSVC_LANG<font face='Lucida Console'>)</font> <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> _MSVC_LANG <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>201703</font>L<font face='Lucida Console'>)</font>
    <font color='#009900'>// C++17 or later: pass the parallel execution policy to std::transform.
</font>    std::<font color='#BB00BB'>transform</font><font face='Lucida Console'>(</font>std::execution::par, listing.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, listing.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, per_image_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, load_truth_instances<font face='Lucida Console'>)</font>;
<font color='#0000FF'>#else</font>
    <font color='#009900'>// Older language standards: plain sequential std::transform.
</font>    std::<font color='#BB00BB'>transform</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, listing.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, per_image_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, load_truth_instances<font face='Lucida Console'>)</font>;
<font color='#0000FF'>#endif</font> <font color='#009900'>// __cplusplus &gt;= 201703L
</font>    <font color='#0000FF'>return</font> per_image_instances;
<b>}</b>
<font color='#009900'>// ----------------------------------------------------------------------------------------
</font>
<font color='#009900'>// Returns the images that contain at least one truth instance whose MMOD label is listed
</font><font color='#009900'>// in desired_classlabels.  Each returned image keeps only those matching instances.
</font>std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font> <b><a name='filter_based_on_classlabel'></a>filter_based_on_classlabel</b><font face='Lucida Console'>(</font>
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images,
    <font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>std::string<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> desired_classlabels
<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// True if the label of the given instance is one of the desired class labels.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> is_desired <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>desired_classlabels]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&amp;</font> instance<font face='Lucida Console'>)</font> <b>{</b>
        <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> last <font color='#5555FF'>=</font> desired_classlabels.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>;
        <font color='#0000FF'>return</font> std::<font color='#BB00BB'>find</font><font face='Lucida Console'>(</font>desired_classlabels.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>, last, instance.mmod_rect.label<font face='Lucida Console'>)</font> <font color='#5555FF'>!</font><font color='#5555FF'>=</font> last;
    <b>}</b>;
    std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font> result;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> image : truth_images<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#009900'>// NB: This keeps only MMOD rects belonging to any of the desired classes.
</font>        <font color='#009900'>//     A reasonable alternative could be to keep all rects, but mark those
</font>        <font color='#009900'>//     belonging in other classes to be ignored during training.
</font>        std::vector<font color='#5555FF'>&lt;</font>truth_instance<font color='#5555FF'>&gt;</font> kept;
        std::<font color='#BB00BB'>copy_if</font><font face='Lucida Console'>(</font>
            image.truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            image.truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            std::<font color='#BB00BB'>back_inserter</font><font face='Lucida Console'>(</font>kept<font face='Lucida Console'>)</font>,
            is_desired
        <font face='Lucida Console'>)</font>;
        <font color='#009900'>// No instance of a desired class in this image: leave the image out entirely.
</font>        <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>kept.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
            <font color='#0000FF'>continue</font>;
        result.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_image<b>{</b> image.info, kept <b>}</b><font face='Lucida Console'>)</font>;
    <b>}</b>
    <font color='#0000FF'>return</font> result;
<b>}</b>
<font color='#009900'>// Flag truth boxes that should be skipped: boxes that overlap too much, boxes that
</font><font color='#009900'>// are too small, and boxes with a large aspect ratio.  Overlap handling is delegated
</font><font color='#009900'>// to ignore_overlapped_boxes(); the size and aspect ratio checks are done here and
</font><font color='#009900'>// only look at boxes that are not flagged as ignored yet.
</font><font color='#0000FF'><u>void</u></font> <b><a name='ignore_some_truth_boxes'></a>ignore_some_truth_boxes</b><font face='Lucida Console'>(</font>std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#009900'>// A box counts as too small only when it is below both of these limits.
</font>    constexpr <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> min_width <font color='#5555FF'>=</font> <font color='#979000'>35</font>;
    constexpr <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>long</u></font> min_height <font color='#5555FF'>=</font> <font color='#979000'>35</font>;
    <font color='#009900'>// The two limits differ: wide boxes are tolerated more than tall ones.
</font>    constexpr <font color='#0000FF'><u>double</u></font> max_aspect_ratio_width_to_height <font color='#5555FF'>=</font> <font color='#979000'>3.0</font>;
    constexpr <font color='#0000FF'><u>double</u></font> max_aspect_ratio_height_to_width <font color='#5555FF'>=</font> <font color='#979000'>1.5</font>;

    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> image : truth_images<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#BB00BB'>ignore_overlapped_boxes</font><font face='Lucida Console'>(</font>image.truth_instances, <font color='#BB00BB'>test_box_overlap</font><font face='Lucida Console'>(</font><font color='#979000'>0.90</font>, <font color='#979000'>0.95</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> instance : image.truth_instances<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> box <font color='#5555FF'>=</font> instance.mmod_rect;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>box.ignore<font face='Lucida Console'>)</font>
                <font color='#0000FF'>continue</font>;

            <font color='#0000FF'>const</font> <font color='#0000FF'><u>bool</u></font> is_too_small <font color='#5555FF'>=</font> box.rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> min_width <font color='#5555FF'>&amp;</font><font color='#5555FF'>&amp;</font> box.rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font> min_height;
            <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>is_too_small<font face='Lucida Console'>)</font>
            <b>{</b>
                box.ignore <font color='#5555FF'>=</font> <font color='#979000'>true</font>;
            <b>}</b>
            <font color='#0000FF'>else</font>
            <b>{</b>
                <font color='#009900'>// The second ratio is deliberately derived from the first one.
</font>                <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> width_to_height <font color='#5555FF'>=</font> box.rect.<font color='#BB00BB'>width</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>/</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>double</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>box.rect.<font color='#BB00BB'>height</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
                <font color='#0000FF'>const</font> <font color='#0000FF'><u>double</u></font> height_to_width <font color='#5555FF'>=</font> <font color='#979000'>1.0</font> <font color='#5555FF'>/</font> width_to_height;
                <font color='#009900'>// The box was not ignored before, so this assignment only ever turns the flag on.
</font>                box.ignore <font color='#5555FF'>=</font> width_to_height <font color='#5555FF'>&gt;</font> max_aspect_ratio_width_to_height
                    <font color='#5555FF'>|</font><font color='#5555FF'>|</font> height_to_width <font color='#5555FF'>&gt;</font> max_aspect_ratio_height_to_width;
            <b>}</b>
        <b>}</b>
    <b>}</b>
<b>}</b>
<font color='#009900'>// Return a copy of the input that contains only the images having at least one truth
</font><font color='#009900'>// instance that is not marked as ignored.  Images without any truth instances at all
</font><font color='#009900'>// are dropped as well.
</font>std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font> <b><a name='filter_images_with_no_truth'></a>filter_images_with_no_truth</b><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font><font color='#5555FF'>&amp;</font> truth_images<font face='Lucida Console'>)</font>
<b>{</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> is_used <font color='#5555FF'>=</font> []<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_instance<font color='#5555FF'>&amp;</font> instance<font face='Lucida Console'>)</font> <b>{</b> <font color='#0000FF'>return</font> <font color='#5555FF'>!</font>instance.mmod_rect.ignore; <b>}</b>;

    <font color='#009900'>// True if the image still has something to learn from.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> has_used_truth <font color='#5555FF'>=</font> [<font color='#5555FF'>&amp;</font>is_used]<font face='Lucida Console'>(</font><font color='#0000FF'>const</font> truth_image<font color='#5555FF'>&amp;</font> image<font face='Lucida Console'>)</font> <b>{</b>
        <font color='#0000FF'>return</font> std::<font color='#BB00BB'>any_of</font><font face='Lucida Console'>(</font>
            image.truth_instances.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            image.truth_instances.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
            is_used
        <font face='Lucida Console'>)</font>;
    <b>}</b>;

    std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font> kept;
    std::<font color='#BB00BB'>copy_if</font><font face='Lucida Console'>(</font>
        truth_images.<font color='#BB00BB'>begin</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        truth_images.<font color='#BB00BB'>end</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>,
        std::<font color='#BB00BB'>back_inserter</font><font face='Lucida Console'>(</font>kept<font face='Lucida Console'>)</font>,
        has_used_truth
    <font face='Lucida Console'>)</font>;
    <font color='#0000FF'>return</font> kept;
<b>}</b>
<font color='#009900'>// Program entry point.  Usage:
</font><font color='#009900'>//   ./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] ...
</font><font color='#009900'>// Trains the detector network and then the mask predictor network(s), and finally
</font><font color='#009900'>// serializes all of them to instance_segmentation_net_filename.  Returns 0 on success,
</font><font color='#009900'>// and 1 if the arguments are unusable, the dataset is not found, or a std::exception
</font><font color='#009900'>// is thrown along the way.
</font><font color='#0000FF'><u>int</u></font> <b><a name='main'></a>main</b><font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> argc, <font color='#0000FF'><u>char</u></font><font color='#5555FF'>*</font><font color='#5555FF'>*</font> argv<font face='Lucida Console'>)</font> <font color='#0000FF'>try</font>
<b>{</b>
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>argc <font color='#5555FF'>&lt;</font> <font color='#979000'>2</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>To run this program you need a copy of the PASCAL VOC2012 dataset.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>You call this program like this: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>./dnn_instance_segmentation_train_ex /path/to/VOC2012 [det-minibatch-size] [seg-minibatch-size] [class-1] [class-2] [class-3] ...</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>\nSCANNING PASCAL VOC2012 DATASET\n</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> listing <font color='#5555FF'>=</font> <font color='#BB00BB'>get_pascal_voc2012_train_listing</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>1</font>]<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in entire dataset: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Didn't find the VOC2012 dataset. </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>

    <font color='#009900'>// mini-batches smaller than the default can be used with GPUs having less memory
</font>    <font color='#009900'>// The values are parsed as signed integers first, so that zero and negative sizes
</font>    <font color='#009900'>// can be rejected instead of silently wrapping around when converted to unsigned.
</font>    <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> det_minibatch_size_arg <font color='#5555FF'>=</font> argc <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>3</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>2</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>35</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>int</u></font> seg_minibatch_size_arg <font color='#5555FF'>=</font> argc <font color='#5555FF'>&gt;</font><font color='#5555FF'>=</font> <font color='#979000'>4</font> ? std::<font color='#BB00BB'>stoi</font><font face='Lucida Console'>(</font>argv[<font color='#979000'>3</font>]<font face='Lucida Console'>)</font> : <font color='#979000'>100</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>det_minibatch_size_arg <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>0</font> <font color='#5555FF'>|</font><font color='#5555FF'>|</font> seg_minibatch_size_arg <font color='#5555FF'>&lt;</font><font color='#5555FF'>=</font> <font color='#979000'>0</font><font face='Lucida Console'>)</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>The mini-batch sizes must be positive integers.</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
    <b>}</b>
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> det_minibatch_size <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>det_minibatch_size_arg<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font> seg_minibatch_size <font color='#5555FF'>=</font> <font color='#0000FF'>static_cast</font><font color='#5555FF'>&lt;</font><font color='#0000FF'><u>unsigned</u></font> <font color='#0000FF'><u>int</u></font><font color='#5555FF'>&gt;</font><font face='Lucida Console'>(</font>seg_minibatch_size_arg<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>det mini-batch size: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> det_minibatch_size <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>seg mini-batch size: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> seg_minibatch_size <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#009900'>// Every argument from the fifth one on names a class to train for; when none is
</font>    <font color='#009900'>// given, a default set of three classes is used.
</font>    std::vector<font color='#5555FF'>&lt;</font>std::string<font color='#5555FF'>&gt;</font> desired_classlabels;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>int</u></font> arg <font color='#5555FF'>=</font> <font color='#979000'>4</font>; arg <font color='#5555FF'>&lt;</font> argc; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>arg<font face='Lucida Console'>)</font>
        desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>argv[arg]<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>desired_classlabels.<font color='#BB00BB'>empty</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>
    <b>{</b>
        desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>bicycle</font>"<font face='Lucida Console'>)</font>;
        desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>car</font>"<font face='Lucida Console'>)</font>;
        desired_classlabels.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>"<font color='#CC0000'>cat</font>"<font face='Lucida Console'>)</font>;
    <b>}</b>
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>desired classlabels:</font>";
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> desired_classlabel : desired_classlabels<font face='Lucida Console'>)</font>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> desired_classlabel;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#009900'>// extract the MMOD rects
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Extracting all truth instances...</font>";
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> truth_instances <font color='#5555FF'>=</font> <font color='#BB00BB'>load_all_truth_instances</font><font face='Lucida Console'>(</font>listing<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'> Done!</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#BB00BB'>DLIB_CASSERT</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>=</font><font color='#5555FF'>=</font> truth_instances.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;

    <font color='#009900'>// Pair each listed image with its truth instances.
</font>    std::vector<font color='#5555FF'>&lt;</font>truth_image<font color='#5555FF'>&gt;</font> original_truth_images;
    original_truth_images.<font color='#BB00BB'>reserve</font><font face='Lucida Console'>(</font>listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font><font face='Lucida Console'>)</font>;
    <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'><u>size_t</u></font> i <font color='#5555FF'>=</font> <font color='#979000'>0</font>, end <font color='#5555FF'>=</font> listing.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font>; i <font color='#5555FF'>&lt;</font> end; <font color='#5555FF'>+</font><font color='#5555FF'>+</font>i<font face='Lucida Console'>)</font>
    <b>{</b>
        original_truth_images.<font color='#BB00BB'>push_back</font><font face='Lucida Console'>(</font>truth_image<b>{</b>
            listing[i], truth_instances[i]
        <b>}</b><font face='Lucida Console'>)</font>;
    <b>}</b>

    <font color='#0000FF'>auto</font> truth_images_filtered_by_class <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_based_on_classlabel</font><font face='Lucida Console'>(</font>original_truth_images, desired_classlabels<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in dataset filtered by class: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> truth_images_filtered_by_class.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#BB00BB'>ignore_some_truth_boxes</font><font face='Lucida Console'>(</font>truth_images_filtered_by_class<font face='Lucida Console'>)</font>;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> truth_images <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_images_with_no_truth</font><font face='Lucida Console'>(</font>truth_images_filtered_by_class<font face='Lucida Console'>)</font>;
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>images in dataset after ignoring some truth boxes: </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> truth_images.<font color='#BB00BB'>size</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;

    <font color='#009900'>// First train an object detector network (loss_mmod).
</font>    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Training detector network:</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> det_net <font color='#5555FF'>=</font> <font color='#BB00BB'>train_detection_network</font><font face='Lucida Console'>(</font>truth_images, det_minibatch_size<font face='Lucida Console'>)</font>;

    <font color='#009900'>// Then train mask predictors (segmentation).
</font>    std::map<font color='#5555FF'>&lt;</font>std::string, seg_bnet_type<font color='#5555FF'>&gt;</font> seg_nets_by_class;

    <font color='#009900'>// This flag controls if a separate mask predictor is trained for each class.
</font>    <font color='#009900'>// Note that it would also be possible to train a separate mask predictor for
</font>    <font color='#009900'>// class groups, each containing somehow similar classes -- for example, one
</font>    <font color='#009900'>// mask predictor for cars and buses, another for cats and dogs, and so on.
</font>    constexpr <font color='#0000FF'><u>bool</u></font> separate_seg_net_for_each_class <font color='#5555FF'>=</font> <font color='#979000'>true</font>;

    <font color='#0000FF'>if</font> <font face='Lucida Console'>(</font>separate_seg_net_for_each_class<font face='Lucida Console'>)</font>
    <b>{</b>
        <font color='#0000FF'>for</font> <font face='Lucida Console'>(</font><font color='#0000FF'>const</font> <font color='#0000FF'>auto</font><font color='#5555FF'>&amp;</font> classlabel : desired_classlabels<font face='Lucida Console'>)</font>
        <b>{</b>
            <font color='#009900'>// Consider only the truth images belonging to this class.
</font>            <font color='#0000FF'>const</font> <font color='#0000FF'>auto</font> class_images <font color='#5555FF'>=</font> <font color='#BB00BB'>filter_based_on_classlabel</font><font face='Lucida Console'>(</font>truth_images, <b>{</b> classlabel <b>}</b><font face='Lucida Console'>)</font>;

            cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Training segmentation network for class </font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> classlabel <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>:</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
            seg_nets_by_class[classlabel] <font color='#5555FF'>=</font> <font color='#BB00BB'>train_segmentation_network</font><font face='Lucida Console'>(</font>class_images, seg_minibatch_size, classlabel<font face='Lucida Console'>)</font>;
        <b>}</b>
    <b>}</b>
    <font color='#0000FF'>else</font>
    <b>{</b>
        cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Training a single segmentation network:</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
        seg_nets_by_class["<font color='#CC0000'></font>"] <font color='#5555FF'>=</font> <font color='#BB00BB'>train_segmentation_network</font><font face='Lucida Console'>(</font>truth_images, seg_minibatch_size, "<font color='#CC0000'></font>"<font face='Lucida Console'>)</font>;
    <b>}</b>

    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> "<font color='#CC0000'>Saving networks</font>" <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#BB00BB'>serialize</font><font face='Lucida Console'>(</font>instance_segmentation_net_filename<font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> det_net <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> seg_nets_by_class;
<b>}</b>
<font color='#0000FF'>catch</font><font face='Lucida Console'>(</font><font color='#0000FF'>const</font> std::exception<font color='#5555FF'>&amp;</font> e<font face='Lucida Console'>)</font>
<b>{</b>
    cout <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> e.<font color='#BB00BB'>what</font><font face='Lucida Console'>(</font><font face='Lucida Console'>)</font> <font color='#5555FF'>&lt;</font><font color='#5555FF'>&lt;</font> endl;
    <font color='#009900'>// Without an explicit return, flowing off the end of this handler would make
</font>    <font color='#009900'>// main() return 0, i.e. report success even though the program has failed.
</font>    <font color='#0000FF'>return</font> <font color='#979000'>1</font>;
<b>}</b>
</pre></body></html>