"""
Implementation of the line segment detection module.
"""
import math
import numpy as np
import torch


class LineSegmentDetectionModule(object):
    """Module extracting line segments from junctions and line heatmaps."""

    def __init__(
        self,
        detect_thresh,
        num_samples=64,
        sampling_method="local_max",
        inlier_thresh=0.0,
        heatmap_low_thresh=0.15,
        heatmap_high_thresh=0.2,
        max_local_patch_radius=3,
        lambda_radius=2.0,
        use_candidate_suppression=False,
        nms_dist_tolerance=3.0,
        use_heatmap_refinement=False,
        heatmap_refine_cfg=None,
        use_junction_refinement=False,
        junction_refine_cfg=None,
    ):
        """
        Parameters:
            detect_thresh: The probability threshold on the mean activation along a segment (0. ~ 1.).
            num_samples: Number of sampling locations along each line segment.
            sampling_method: Sampling method for the locations ("bilinear" or "local_max").
            inlier_thresh: The minimum inlier ratio to satisfy (0. ~ 1.); 0. means no threshold.
            heatmap_low_thresh: The lowest threshold for a pixel to be considered a candidate in junction recovery.
            heatmap_high_thresh: The higher threshold used for NMS in junction recovery.
            max_local_patch_radius: The maximum patch radius considered in the local maximum search.
            lambda_radius: The lambda factor in the linear local maximum search formulation.
            use_candidate_suppression: Whether to apply candidate suppression to break long segments into short sub-segments.
            nms_dist_tolerance: The distance tolerance for NMS, deciding whether a junction lies on a candidate line.
            use_heatmap_refinement: Whether to use the heatmap refinement method.
            heatmap_refine_cfg: The config for the heatmap refinement methods.
            use_junction_refinement: Whether to use the junction refinement method.
            junction_refine_cfg: The config for the junction refinement methods.
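
        Example (a minimal sketch; `junctions` and `heatmap` stand in for the
        outputs of an upstream junction / heatmap predictor):

            >>> module = LineSegmentDetectionModule(
            ...     detect_thresh=0.5, num_samples=64, inlier_thresh=0.9
            ... )
            >>> line_map, junctions, heatmap = module.detect(junctions, heatmap)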
        """
        # Line detection parameters
        self.detect_thresh = detect_thresh

        # Line sampling parameters
        self.num_samples = num_samples
        self.sampling_method = sampling_method
        self.inlier_thresh = inlier_thresh
        self.local_patch_radius = max_local_patch_radius
        self.lambda_radius = lambda_radius

        # Junction recovery threshold parameters
        self.low_thresh = heatmap_low_thresh
        self.high_thresh = heatmap_high_thresh

        # Pre-compute the linspace sampler
        self.sampler = np.linspace(0, 1, self.num_samples)
        self.torch_sampler = torch.linspace(0, 1, self.num_samples)

        # Long line segment suppression configuration
        self.use_candidate_suppression = use_candidate_suppression
        self.nms_dist_tolerance = nms_dist_tolerance

        # Heatmap refinement configuration
        self.use_heatmap_refinement = use_heatmap_refinement
        self.heatmap_refine_cfg = heatmap_refine_cfg
        if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
            raise ValueError("[Error] Missing heatmap refinement config.")

        # Junction refinement configuration
        self.use_junction_refinement = use_junction_refinement
        self.junction_refine_cfg = junction_refine_cfg
        if self.use_junction_refinement and self.junction_refine_cfg is None:
            raise ValueError("[Error] Missing junction refinement config.")

    def convert_inputs(self, inputs, device):
        """Convert inputs to desired torch tensor."""
        if isinstance(inputs, np.ndarray):
            outputs = torch.tensor(inputs, dtype=torch.float32, device=device)
        elif isinstance(inputs, torch.Tensor):
            outputs = inputs.to(torch.float32).to(device)
        else:
            raise ValueError(
                "[Error] Inputs must be either a torch tensor or a numpy ndarray."
            )

        return outputs

    def detect(self, junctions, heatmap, device=torch.device("cpu")):
        """Main function performing line segment detection."""
        # Convert inputs to torch tensor
        junctions = self.convert_inputs(junctions, device=device)
        heatmap = self.convert_inputs(heatmap, device=device)

        # Perform the heatmap refinement
        if self.use_heatmap_refinement:
            if self.heatmap_refine_cfg["mode"] == "global":
                heatmap = self.refine_heatmap(
                    heatmap,
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"],
                )
            elif self.heatmap_refine_cfg["mode"] == "local":
                heatmap = self.refine_heatmap_local(
                    heatmap,
                    self.heatmap_refine_cfg["num_blocks"],
                    self.heatmap_refine_cfg["overlap_ratio"],
                    self.heatmap_refine_cfg["ratio"],
                    self.heatmap_refine_cfg["valid_thresh"],
                )

        # Initialize empty line map
        num_junctions = junctions.shape[0]
        line_map_pred = torch.zeros(
            [num_junctions, num_junctions], device=device, dtype=torch.int32
        )

        # Stop if there are not enough junctions
        if num_junctions < 2:
            return line_map_pred, junctions, heatmap

        # Generate the candidate map
        candidate_map = torch.triu(
            torch.ones(
                [num_junctions, num_junctions], device=device, dtype=torch.int32
            ),
            diagonal=1,
        )
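        # (torch.triu with diagonal=1 keeps each unordered junction pair
        # (i, j), i < j, exactly once, so every candidate segment is
        # evaluated a single time.)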

        # Fetch the image size
        if len(heatmap.shape) > 2:
            H, W, _ = heatmap.shape
        else:
            H, W = heatmap.shape

        # Optionally perform candidate filtering
        if self.use_candidate_suppression:
            candidate_map = self.candidate_suppression(junctions, candidate_map)

        # Fetch the candidates
        candidate_index_map = torch.where(candidate_map)
        candidate_index_map = torch.cat(
            [candidate_index_map[0][..., None], candidate_index_map[1][..., None]],
            dim=-1,
        )

        # Get the corresponding start and end junctions
        candidate_junc_start = junctions[candidate_index_map[:, 0], :]
        candidate_junc_end = junctions[candidate_index_map[:, 1], :]

        # Get the sampling locations (N x num_samples)
        sampler = self.torch_sampler.to(device)[None, ...]
        cand_samples_h = candidate_junc_start[:, 0:1] * sampler + candidate_junc_end[
            :, 0:1
        ] * (1 - sampler)
        cand_samples_w = candidate_junc_start[:, 1:2] * sampler + candidate_junc_end[
            :, 1:2
        ] * (1 - sampler)

        # Clip to image boundary
        cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
        cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

        # Local maximum search
        if self.sampling_method == "local_max":
            # Compute normalized segment lengths
            segments_length = torch.sqrt(
                torch.sum(
                    (
                        candidate_junc_start.to(torch.float32)
                        - candidate_junc_end.to(torch.float32)
                    )
                    ** 2,
                    dim=-1,
                )
            )
            normalized_seg_length = segments_length / (((H**2) + (W**2)) ** 0.5)

            # Perform local max search
            num_cand = cand_h.shape[0]
            group_size = 10000
            if num_cand > group_size:
                num_iter = math.ceil(num_cand / group_size)
                sampled_feat_lst = []
                for iter_idx in range(num_iter):
                    if iter_idx != num_iter - 1:
                        cand_h_ = cand_h[
                            iter_idx * group_size : (iter_idx + 1) * group_size, :
                        ]
                        cand_w_ = cand_w[
                            iter_idx * group_size : (iter_idx + 1) * group_size, :
                        ]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size : (iter_idx + 1) * group_size
                        ]
                    else:
                        cand_h_ = cand_h[iter_idx * group_size :, :]
                        cand_w_ = cand_w[iter_idx * group_size :, :]
                        normalized_seg_length_ = normalized_seg_length[
                            iter_idx * group_size :
                        ]
                    sampled_feat_ = self.detect_local_max(
                        heatmap, cand_h_, cand_w_, H, W, normalized_seg_length_, device
                    )
                    sampled_feat_lst.append(sampled_feat_)
                sampled_feat = torch.cat(sampled_feat_lst, dim=0)
            else:
                sampled_feat = self.detect_local_max(
                    heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
                )
        # Bilinear sampling
        elif self.sampling_method == "bilinear":
            # Perform bilinear sampling
            sampled_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
        else:
            raise ValueError("[Error] Unknown sampling method.")

        # [Simple threshold detection]
        # detection_results is a mask over all candidates
        detection_results = torch.mean(sampled_feat, dim=-1) > self.detect_thresh

        # [Inlier threshold detection]
        if self.inlier_thresh > 0.0:
            inlier_ratio = (
                torch.sum(sampled_feat > self.detect_thresh, dim=-1).to(torch.float32)
                / self.num_samples
            )
            detection_results_inlier = inlier_ratio >= self.inlier_thresh
            detection_results = detection_results * detection_results_inlier

        # Convert detection results back to line_map_pred
        detected_junc_indexes = candidate_index_map[detection_results, :]
        line_map_pred[detected_junc_indexes[:, 0], detected_junc_indexes[:, 1]] = 1
        line_map_pred[detected_junc_indexes[:, 1], detected_junc_indexes[:, 0]] = 1

        # Perform junction refinement
        if self.use_junction_refinement and len(detected_junc_indexes) > 0:
            junctions, line_map_pred = self.refine_junction_perturb(
                junctions, line_map_pred, heatmap, H, W, device
            )

        return line_map_pred, junctions, heatmap

    def refine_heatmap(self, heatmap, ratio=0.2, valid_thresh=1e-2):
        """Global heatmap refinement method."""
        # Grab the top `ratio` fraction of the valid heatmap values
        heatmap_values = heatmap[heatmap > valid_thresh]
        sorted_values = torch.sort(heatmap_values, descending=True)[0]
        top_k = math.ceil(sorted_values.shape[0] * ratio)
        # Normalize by the mean of the top values and clip to [0, 1]
        top_mean = torch.mean(sorted_values[:top_k])
        heatmap = torch.clamp(heatmap / top_mean, min=0.0, max=1.0)
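        # Worked example (hypothetical numbers): with ratio=0.2, if the
        # strongest 20% of the valid pixels average 0.5, the whole map is
        # scaled by 1 / 0.5 = 2 and clipped back into [0, 1]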
        return heatmap

    def refine_heatmap_local(
        self, heatmap, num_blocks=5, overlap_ratio=0.5, ratio=0.2, valid_thresh=2e-3
    ):
        """Local heatmap refinement method."""
        # Get the shape of the heatmap
        H, W = heatmap.shape
        increase_ratio = 1 - overlap_ratio
        h_block = round(H / (1 + (num_blocks - 1) * increase_ratio))
        w_block = round(W / (1 + (num_blocks - 1) * increase_ratio))
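        # Worked example (hypothetical size): H = 400 with num_blocks=5 and
        # overlap_ratio=0.5 gives h_block = round(400 / (1 + 4 * 0.5)) = 133;
        # consecutive blocks then start about 67 rows apart, so neighboring
        # blocks overlap by roughly half the block height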

        count_map = torch.zeros(heatmap.shape, dtype=torch.int, device=heatmap.device)
        heatmap_output = torch.zeros(
            heatmap.shape, dtype=torch.float, device=heatmap.device
        )
        # Iterate through each block
        for h_idx in range(num_blocks):
            for w_idx in range(num_blocks):
                # Fetch the heatmap
                h_start = round(h_idx * h_block * increase_ratio)
                w_start = round(w_idx * w_block * increase_ratio)
                h_end = h_start + h_block if h_idx < num_blocks - 1 else H
                w_end = w_start + w_block if w_idx < num_blocks - 1 else W

                subheatmap = heatmap[h_start:h_end, w_start:w_end]
                if subheatmap.max() > valid_thresh:
                    subheatmap = self.refine_heatmap(
                        subheatmap, ratio, valid_thresh=valid_thresh
                    )

                # Aggregate it to the final heatmap
                heatmap_output[h_start:h_end, w_start:w_end] += subheatmap
                count_map[h_start:h_end, w_start:w_end] += 1
        heatmap_output = torch.clamp(heatmap_output / count_map, max=1.0, min=0.0)

        return heatmap_output

    def candidate_suppression(self, junctions, candidate_map):
        """Suppress overlapping long lines in the candidate segments."""
        # Define the distance tolerance
        dist_tolerance = self.nms_dist_tolerance

        # Compute distance between junction pairs
        # (num_junc x 1 x 2) - (1 x num_junc x 2) => num_junc x num_junc map
        line_dist_map = (
            torch.sum(
                (torch.unsqueeze(junctions, dim=1) - junctions[None, ...]) ** 2, dim=-1
            )
            ** 0.5
        )
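        # A junction j is considered to lie on candidate segment (s, e) when
        # its projection parameter t = <j - s, dir> / |e - s| falls in [0, 1]
        # and its orthogonal distance |j - s| * sin(angle) is at most
        # dist_tolerance; e.g. with the default tolerance of 3.0, a junction
        # 2 px off the line midway along a segment suppresses that segment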

        # Fetch all the "detected lines"
        seg_indexes = torch.where(torch.triu(candidate_map, diagonal=1))
        start_point_idxs = seg_indexes[0]
        end_point_idxs = seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        # Fetch corresponding entries
        line_dists = line_dist_map[start_point_idxs, end_point_idxs]

        # Check whether they are on the line
        dir_vecs = (end_points - start_points) / torch.norm(
            end_points - start_points, dim=-1
        )[..., None]
        # Get the orthogonal distance
        cand_vecs = junctions[None, ...] - start_points.unsqueeze(dim=1)
        cand_vecs_norm = torch.norm(cand_vecs, dim=-1)
        # Check whether they are projected directly onto the segment
        proj = (
            torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None])
            / line_dists[..., None, None]
        )
        # proj is num_segs x num_junction x 1
        proj_mask = (proj >= 0) * (proj <= 1)
        cand_angles = torch.acos(
            torch.einsum("bij,bjk->bik", cand_vecs, dir_vecs[..., None])
            / cand_vecs_norm[..., None]
        )
        cand_dists = cand_vecs_norm[..., None] * torch.sin(cand_angles)
        junc_dist_mask = cand_dists <= dist_tolerance
        junc_mask = junc_dist_mask * proj_mask

        # Exclude each segment's own start and end junctions from the count
        num_segs = start_point_idxs.shape[0]
        junc_counts = torch.sum(junc_mask, dim=[1, 2])
        junc_counts -= junc_mask[..., 0][
            torch.arange(0, num_segs), start_point_idxs
        ].to(torch.int)
        junc_counts -= junc_mask[..., 0][torch.arange(0, num_segs), end_point_idxs].to(
            torch.int
        )

        # Get the invalid candidate mask
        final_mask = junc_counts > 0
        candidate_map[start_point_idxs[final_mask], end_point_idxs[final_mask]] = 0

        return candidate_map

    def refine_junction_perturb(self, junctions, line_map_pred, heatmap, H, W, device):
        """Refine the line endpoints in a similar way as in LSD."""
        # Get the config
        junction_refine_cfg = self.junction_refine_cfg

        # Fetch refinement parameters
        num_perturbs = junction_refine_cfg["num_perturbs"]
        perturb_interval = junction_refine_cfg["perturb_interval"]
        side_perturbs = (num_perturbs - 1) // 2
        # Fetch the 2D perturb mat
        perturb_vec = torch.arange(
            start=-perturb_interval * side_perturbs,
            end=perturb_interval * (side_perturbs + 1),
            step=perturb_interval,
            device=device,
        )
        # The same perturb_vec is used on all four axes, so the (w, h)
        # naming order of the grids does not affect the candidate set
        w1_grid, h1_grid, w2_grid, h2_grid = torch.meshgrid(
            perturb_vec, perturb_vec, perturb_vec, perturb_vec, indexing="ij"
        )
        perturb_tensor = torch.cat(
            [
                w1_grid[..., None],
                h1_grid[..., None],
                w2_grid[..., None],
                h2_grid[..., None],
            ],
            dim=-1,
        )
        perturb_tensor_flat = perturb_tensor.view(-1, 2, 2)
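        # Sketch of the scale (hypothetical config values): num_perturbs=9
        # with perturb_interval=0.25 yields perturb_vec = [-1.0, ..., 1.0]
        # and 9**4 = 6561 candidate (start, end) displacement pairs per
        # segment, each scored below by its mean heatmap activation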

        # Fetch the junctions and line_map
        junctions = junctions.clone()
        line_map = line_map_pred

        # Fetch all the detected lines
        detected_seg_indexes = torch.where(torch.triu(line_map, diagonal=1))
        start_point_idxs = detected_seg_indexes[0]
        end_point_idxs = detected_seg_indexes[1]
        start_points = junctions[start_point_idxs, :]
        end_points = junctions[end_point_idxs, :]

        line_segments = torch.cat(
            [start_points.unsqueeze(dim=1), end_points.unsqueeze(dim=1)], dim=1
        )

        line_segment_candidates = (
            line_segments.unsqueeze(dim=1) + perturb_tensor_flat[None, ...]
        )
        # Clip the boundaries
        line_segment_candidates[..., 0] = torch.clamp(
            line_segment_candidates[..., 0], min=0, max=H - 1
        )
        line_segment_candidates[..., 1] = torch.clamp(
            line_segment_candidates[..., 1], min=0, max=W - 1
        )

        # Iterate through all the segments
        refined_segment_lst = []
        num_segments = line_segments.shape[0]
        for idx in range(num_segments):
            segment = line_segment_candidates[idx, ...]
            # Get the corresponding start and end junctions
            candidate_junc_start = segment[:, 0, :]
            candidate_junc_end = segment[:, 1, :]

            # Get the sampling locations (N x num_samples)
            sampler = self.torch_sampler.to(device)[None, ...]
            cand_samples_h = candidate_junc_start[
                :, 0:1
            ] * sampler + candidate_junc_end[:, 0:1] * (1 - sampler)
            cand_samples_w = candidate_junc_start[
                :, 1:2
            ] * sampler + candidate_junc_end[:, 1:2] * (1 - sampler)

            # Clip to image boundary
            cand_h = torch.clamp(cand_samples_h, min=0, max=H - 1)
            cand_w = torch.clamp(cand_samples_w, min=0, max=W - 1)

            # Perform bilinear sampling
            segment_feat = self.detect_bilinear(heatmap, cand_h, cand_w, H, W, device)
            segment_results = torch.mean(segment_feat, dim=-1)
            max_idx = torch.argmax(segment_results)
            refined_segment_lst.append(segment[max_idx, ...][None, ...])

        # Concatenate back to segments
        refined_segments = torch.cat(refined_segment_lst, dim=0)

        # Convert back to junctions and line_map
        junctions_new = torch.cat(
            [refined_segments[:, 0, :], refined_segments[:, 1, :]], dim=0
        )
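        # torch.unique below merges coinciding endpoints, so segments that
        # share a refined corner end up referencing a single junction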
        junctions_new = torch.unique(junctions_new, dim=0)
        line_map_new = self.segments_to_line_map(junctions_new, refined_segments)

        return junctions_new, line_map_new

    def segments_to_line_map(self, junctions, segments):
        """Convert the list of segments to line map."""
        # Create empty line map
        device = junctions.device
        num_junctions = junctions.shape[0]
        line_map = torch.zeros(
            [num_junctions, num_junctions], device=device, dtype=torch.int32
        )

        # Iterate through every segment
        for idx in range(segments.shape[0]):
            # Get the junctions from a single segment
            seg = segments[idx, ...]
            junction1 = seg[0, :]
            junction2 = seg[1, :]

            # Get the junction indexes by matching both coordinates
            idx_junction1 = torch.where((junctions == junction1).sum(dim=1) == 2)[0]
            idx_junction2 = torch.where((junctions == junction2).sum(dim=1) == 2)[0]

            # label the corresponding entries
            line_map[idx_junction1, idx_junction2] = 1
            line_map[idx_junction2, idx_junction1] = 1

        return line_map

    def detect_bilinear(self, heatmap, cand_h, cand_w, H, W, device):
        """Detection by bilinear sampling."""
        # Get the floor locations and fractional offsets
        cand_h_floor = torch.floor(cand_h)
        cand_w_floor = torch.floor(cand_w)
        h_frac = cand_h - cand_h_floor
        w_frac = cand_w - cand_w_floor

        # Use floor + 1 (clamped to the image) rather than ceil so that
        # integer coordinates keep valid interpolation weights (with ceil,
        # floor == ceil would zero out all four terms below)
        cand_h_ceil = torch.clamp(cand_h_floor + 1, max=H - 1).to(torch.long)
        cand_w_ceil = torch.clamp(cand_w_floor + 1, max=W - 1).to(torch.long)
        cand_h_floor = cand_h_floor.to(torch.long)
        cand_w_floor = cand_w_floor.to(torch.long)

        # Perform the bilinear sampling
        cand_samples_feat = (
            heatmap[cand_h_floor, cand_w_floor] * (1 - h_frac) * (1 - w_frac)
            + heatmap[cand_h_floor, cand_w_ceil] * (1 - h_frac) * w_frac
            + heatmap[cand_h_ceil, cand_w_floor] * h_frac * (1 - w_frac)
            + heatmap[cand_h_ceil, cand_w_ceil] * h_frac * w_frac
        )
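        # Worked example: sampling at (h, w) = (1.5, 2.25) mixes neighbors
        # (1, 2), (1, 3), (2, 2), (2, 3) with weights 0.375, 0.125, 0.375,
        # 0.125 respectively (the four weights always sum to 1)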

        return cand_samples_feat

    def detect_local_max(
        self, heatmap, cand_h, cand_w, H, W, normalized_seg_length, device
    ):
        """Detection by local maximum search."""
        # Compute the distance threshold
        dist_thresh = 0.5 * (2**0.5) + self.lambda_radius * normalized_seg_length
        # Make it N x 64
        dist_thresh = torch.repeat_interleave(
            dist_thresh[..., None], self.num_samples, dim=-1
        )
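        # The search radius grows linearly with segment length:
        # dist_thresh = sqrt(2) / 2 + lambda_radius * (length / image diagonal);
        # e.g. with the default lambda_radius=2.0, a segment spanning a quarter
        # of the diagonal is searched within about 0.71 + 0.5 = 1.21 pixels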

        # Compute the candidate points
        cand_points = torch.cat([cand_h[..., None], cand_w[..., None]], dim=-1)
        cand_points_round = torch.round(cand_points)  # N x 64 x 2

        # Construct the local patch grid of (2r + 1) x (2r + 1) offsets
        # (7 x 7 = 49 for the default radius of 3, trimmed below to the
        # 29 offsets inside the radius-r circle)
        patch_mask = torch.zeros(
            [
                int(2 * self.local_patch_radius + 1),
                int(2 * self.local_patch_radius + 1),
            ],
            device=device,
        )
        patch_center = torch.tensor(
            [[self.local_patch_radius, self.local_patch_radius]],
            device=device,
            dtype=torch.float32,
        )
        H_patch_points, W_patch_points = torch.where(patch_mask >= 0)
        patch_points = torch.cat(
            [H_patch_points[..., None], W_patch_points[..., None]], dim=-1
        )
        # Fetch the circle region
        patch_center_dist = torch.sqrt(
            torch.sum((patch_points - patch_center) ** 2, dim=-1)
        )
        patch_points = patch_points[patch_center_dist <= self.local_patch_radius, :]
        # Shift [0, 0] to the center
        patch_points = patch_points - self.local_patch_radius

        # Construct local patch mask
        patch_points_shifted = (
            torch.unsqueeze(cand_points_round, dim=2) + patch_points[None, None, ...]
        )
        patch_dist = torch.sqrt(
            torch.sum(
                (torch.unsqueeze(cand_points, dim=2) - patch_points_shifted) ** 2,
                dim=-1,
            )
        )
        patch_dist_mask = patch_dist < dist_thresh[..., None]

        # Get all points => num_points_center x num_patch_points x 2
        points_H = torch.clamp(patch_points_shifted[:, :, :, 0], min=0, max=H - 1).to(
            torch.long
        )
        points_W = torch.clamp(patch_points_shifted[:, :, :, 1], min=0, max=W - 1).to(
            torch.long
        )
        points = torch.cat([points_H[..., None], points_W[..., None]], dim=-1)

        # Sample the feature (N x num_samples x num_patch_points)
        sampled_feat = heatmap[points[:, :, :, 0], points[:, :, :, 1]]
        # Filtering using the valid mask
        sampled_feat = sampled_feat * patch_dist_mask.to(torch.float32)
        if len(sampled_feat) == 0:
            sampled_feat_lmax = torch.empty(0, self.num_samples, device=device)
        else:
            sampled_feat_lmax, _ = torch.max(sampled_feat, dim=-1)

        return sampled_feat_lmax
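

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumption: random synthetic inputs stand in
    # for the junction and heatmap predictions of an upstream network)
    rng = np.random.default_rng(0)
    H, W = 128, 128
    fake_junctions = rng.integers(0, min(H, W), size=(32, 2)).astype(np.float32)
    fake_heatmap = rng.random((H, W)).astype(np.float32)

    module = LineSegmentDetectionModule(detect_thresh=0.5, inlier_thresh=0.9)
    line_map, junctions, heatmap = module.detect(fake_junctions, fake_heatmap)
    print("line map shape:", tuple(line_map.shape))
    print("detected segments:", int(torch.triu(line_map, diagonal=1).sum()))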