#pragma once #include #include "masked_image.h" class PatchDistanceMetric { public: PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {} virtual ~PatchDistanceMetric() = default; inline int patch_size() const { return m_patch_size; } virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0; static const int kDistanceScale; protected: int m_patch_size; }; class NearestNeighborField { public: NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) { // pass } NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20) : m_source(source), m_target(target), m_distance_metric(metric) { m_field = cv::Mat(m_source.size(), CV_32SC3); _randomize_field(max_retry); } NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20) : m_source(source), m_target(target), m_distance_metric(metric) { m_field = cv::Mat(m_source.size(), CV_32SC3); _initialize_field_from(other, max_retry); } const MaskedImage &source() const { return m_source; } const MaskedImage &target() const { return m_target; } inline cv::Size source_size() const { return m_source.size(); } inline cv::Size target_size() const { return m_target.size(); } inline void set_source(const MaskedImage &source) { m_source = source; } inline void set_target(const MaskedImage &target) { m_target = target; } inline int *mutable_ptr(int y, int x) { return m_field.ptr(y, x); } inline const int *ptr(int y, int x) const { return m_field.ptr(y, x); } inline int at(int y, int x, int c) const { return m_field.ptr(y, x)[c]; } inline int &at(int y, int x, int c) { return m_field.ptr(y, x)[c]; } inline void set_identity(int y, int x) { auto ptr = mutable_ptr(y, x); ptr[0] = y, ptr[1] = x, ptr[2] = 0; } void minimize(int nr_pass); private: inline int _distance(int source_y, int source_x, int target_y, int target_x) { return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x); } void _randomize_field(int max_retry = 20, bool reset = true); void _initialize_field_from(const NearestNeighborField &other, int max_retry); void _minimize_link(int y, int x, int direction); MaskedImage m_source; MaskedImage m_target; cv::Mat m_field; // { y_target, x_target, distance_scaled } const PatchDistanceMetric *m_distance_metric; }; class PatchSSDDistanceMetric : public PatchDistanceMetric { public: using PatchDistanceMetric::PatchDistanceMetric; virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const; static const int kSSDScale; }; class DebugPatchSSDDistanceMetric : public PatchDistanceMetric { public: DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {} virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const; protected: int m_width, m_height; }; class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric { public: RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight) : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) { assert(m_dy1 == 0); assert(m_dx2 == 0); m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4; } virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const; protected: double m_dx1, m_dy1, m_dx2, m_dy2; double m_scale, m_weight; }; class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric { public: RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight) : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) { } virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const; protected: cv::Mat m_ijmap; double m_width, m_height, m_weight; };