| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| #include "dnn_semantic_segmentation_ex.h" |
|
|
| #include <iostream> |
| #include <dlib/data_io.h> |
| #include <dlib/gui_widgets.h> |
|
|
| using namespace std; |
| using namespace dlib; |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| const Voc2012class& find_voc2012_class(const uint16_t& index_label) |
| { |
| return find_voc2012_class( |
| [&index_label](const Voc2012class& voc2012class) |
| { |
| return index_label == voc2012class.index; |
| } |
| ); |
| } |
|
|
| |
| inline rgb_pixel index_label_to_rgb_label(uint16_t index_label) |
| { |
| return find_voc2012_class(index_label).rgb_label; |
| } |
|
|
| |
| |
| void index_label_image_to_rgb_label_image( |
| const matrix<uint16_t>& index_label_image, |
| matrix<rgb_pixel>& rgb_label_image |
| ) |
| { |
| const long nr = index_label_image.nr(); |
| const long nc = index_label_image.nc(); |
|
|
| rgb_label_image.set_size(nr, nc); |
|
|
| for (long r = 0; r < nr; ++r) |
| { |
| for (long c = 0; c < nc; ++c) |
| { |
| rgb_label_image(r, c) = index_label_to_rgb_label(index_label_image(r, c)); |
| } |
| } |
| } |
|
|
| |
| std::string get_most_prominent_non_background_classlabel(const matrix<uint16_t>& index_label_image) |
| { |
| const long nr = index_label_image.nr(); |
| const long nc = index_label_image.nc(); |
|
|
| std::vector<unsigned int> counters(class_count); |
|
|
| for (long r = 0; r < nr; ++r) |
| { |
| for (long c = 0; c < nc; ++c) |
| { |
| const uint16_t label = index_label_image(r, c); |
| ++counters[label]; |
| } |
| } |
|
|
| const auto max_element = std::max_element(counters.begin() + 1, counters.end()); |
| const uint16_t most_prominent_index_label = max_element - counters.begin(); |
|
|
| return find_voc2012_class(most_prominent_index_label).classlabel; |
| } |
|
|
| |
|
|
| int main(int argc, char** argv) try |
| { |
| if (argc != 2) |
| { |
| cout << "You call this program like this: " << endl; |
| cout << "./dnn_semantic_segmentation_train_ex /path/to/images" << endl; |
| cout << endl; |
| cout << "You will also need a trained '" << semantic_segmentation_net_filename << "' file." << endl; |
| cout << "You can either train it yourself (see example program" << endl; |
| cout << "dnn_semantic_segmentation_train_ex), or download a" << endl; |
| cout << "copy from here: http://dlib.net/files/" << semantic_segmentation_net_filename << endl; |
| return 1; |
| } |
|
|
| |
| anet_type net; |
| deserialize(semantic_segmentation_net_filename) >> net; |
|
|
| |
| image_window win; |
|
|
| matrix<rgb_pixel> input_image; |
| matrix<uint16_t> index_label_image; |
| matrix<rgb_pixel> rgb_label_image; |
|
|
| |
| const std::vector<file> files = dlib::get_files_in_directory_tree(argv[1], |
| dlib::match_endings(".jpeg .jpg .png")); |
|
|
| cout << "Found " << files.size() << " images, processing..." << endl; |
|
|
| for (const file& file : files) |
| { |
| |
| load_image(input_image, file.full_name()); |
|
|
| |
| |
| |
| const matrix<uint16_t> temp = net(input_image); |
|
|
| |
| const chip_details chip_details( |
| centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()), |
| chip_dims(input_image.nr(), input_image.nc()) |
| ); |
| extract_image_chip(temp, chip_details, index_label_image, interpolate_nearest_neighbor()); |
|
|
| |
| index_label_image_to_rgb_label_image(index_label_image, rgb_label_image); |
|
|
| |
| win.set_image(join_rows(input_image, rgb_label_image)); |
|
|
| |
| const std::string classlabel = get_most_prominent_non_background_classlabel(index_label_image); |
|
|
| cout << file.name() << " : " << classlabel << " - hit enter to process the next image"; |
| cin.get(); |
| } |
| } |
| catch(std::exception& e) |
| { |
| cout << e.what() << endl; |
| } |
|
|
|
|