import deeplsd.models.deeplsd_inference as deeplsd_inference
import numpy as np
import torch

from ...settings import DATA_PATH
from ..base_model import BaseModel


class DeepLSD(BaseModel):
    default_conf = {
        "min_length": 15,
        "max_num_lines": None,
        "force_num_lines": False,
        "model_conf": {
            "detect_lines": True,
            "line_detection_params": {
                "merge": False,
                "grad_nfa": True,
                "filtering": "normal",
                "grad_thresh": 3,
            },
        },
    }
    required_data_keys = ["image"]

    def _init(self, conf):
        if self.conf.force_num_lines:
            assert (
                self.conf.max_num_lines is not None
            ), "Missing max_num_lines parameter"
        ckpt = DATA_PATH / "weights/deeplsd_md.tar"
        if not ckpt.is_file():
            self.download_model(ckpt)
        ckpt = torch.load(ckpt, map_location="cpu")
        self.net = deeplsd_inference.DeepLSD(conf.model_conf).eval()
        self.net.load_state_dict(ckpt["model"])
        self.set_initialized()

    def download_model(self, path):
        import subprocess

        if not path.parent.is_dir():
            path.parent.mkdir(parents=True, exist_ok=True)
        link = "https://cvg-data.inf.ethz.ch/DeepLSD/deeplsd_md.tar"
        cmd = ["wget", link, "-O", path]
        print("Downloading DeepLSD model...")
        subprocess.run(cmd, check=True)

    def _forward(self, data):
        image = data["image"]
        lines, line_scores, valid_lines = [], [], []
        if image.shape[1] == 3:
            # Convert to grayscale
            scale = image.new_tensor([0.299, 0.587, 0.114]).view(1, 3, 1, 1)
            image = (image * scale).sum(1, keepdim=True)

        # Forward pass
        with torch.no_grad():
            segs = self.net({"image": image})["lines"]

        # Line scores are the sqrt of the length
        for seg in segs:
            lengths = np.linalg.norm(seg[:, 0] - seg[:, 1], axis=1)
            segs = seg[lengths >= self.conf.min_length]
            scores = np.sqrt(lengths[lengths >= self.conf.min_length])

            # Keep the best lines
            indices = np.argsort(-scores)
            if self.conf.max_num_lines is not None:
                indices = indices[: self.conf.max_num_lines]
                segs = segs[indices]
                scores = scores[indices]

            # Pad if necessary
            n = len(segs)
            valid_mask = np.ones(n, dtype=bool)
            if self.conf.force_num_lines:
                pad = self.conf.max_num_lines - n
                segs = np.concatenate(
                    [segs, np.zeros((pad, 2, 2), dtype=np.float32)], axis=0
                )
                scores = np.concatenate(
                    [scores, np.zeros(pad, dtype=np.float32)], axis=0
                )
                valid_mask = np.concatenate(
                    [valid_mask, np.zeros(pad, dtype=bool)], axis=0
                )

            lines.append(segs)
            line_scores.append(scores)
            valid_lines.append(valid_mask)

        # Batch if possible
        if len(image) == 1 or self.conf.force_num_lines:
            lines = torch.tensor(lines, dtype=torch.float, device=image.device)
            line_scores = torch.tensor(
                line_scores, dtype=torch.float, device=image.device
            )
            valid_lines = torch.tensor(
                valid_lines, dtype=torch.bool, device=image.device
            )

        return {"lines": lines, "line_scores": line_scores, "valid_lines": valid_lines}

    def loss(self, pred, data):
        raise NotImplementedError