import sys
from pathlib import Path
import subprocess
import logging
import torch
from ..utils.base_model import BaseModel

logger = logging.getLogger(__name__)

# Location of the GlueStick checkout, relative to this file. It is appended to
# ``sys.path`` so that the ``gluestick`` package imported just below resolves;
# the append must therefore stay ahead of those imports.
gluestick_path = Path(__file__).parent / "../../third_party/GlueStick"
sys.path.append(str(gluestick_path))

from gluestick import batch_to_np
from gluestick.models.two_view_pipeline import TwoViewPipeline

# CUDA device when available, CPU otherwise.
# NOTE(review): not referenced by the code visible in this module; presumably
# kept for external importers -- confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class GlueStick(BaseModel):
    """Point-and-line matcher backed by the GlueStick ``TwoViewPipeline``.

    ``_init`` makes sure the checkpoint is available locally (downloading it
    with ``wget`` when missing) and builds the pipeline; ``_forward`` runs the
    pipeline on an image pair and reduces its output to the matched keypoints
    and matched line segments.
    """

    # Default options, presumably merged with user overrides by ``BaseModel``.
    # NOTE(review): ``name`` and ``use_lines`` are not read by ``_init``, which
    # hard-codes both values in the pipeline configuration.
    default_conf = {
        "name": "two_view_pipeline",
        "model_name": "checkpoint_GlueStick_MD.tar",
        "use_lines": True,
        "max_keypoints": 1000,
        "max_lines": 300,
        "force_num_keypoints": False,
    }
    required_inputs = [
        "image0",
        "image1",
    ]

    # Download URL of every known checkpoint, keyed by checkpoint file name.
    gluestick_models = {
        "checkpoint_GlueStick_MD.tar": "https://github.com/cvg/GlueStick/releases/download/v0.1_arxiv/checkpoint_GlueStick_MD.tar",
    }

    def _init(self, conf):
        """Fetch the checkpoint if necessary and build the matching pipeline.

        Args:
            conf: Mapping providing ``model_name``, ``max_keypoints``,
                ``max_lines`` and ``force_num_keypoints``.

        Raises:
            KeyError: If the checkpoint is missing locally and ``model_name``
                has no entry in ``gluestick_models``.
            subprocess.CalledProcessError: If ``wget`` exits with an error.
        """
        model_path = gluestick_path / "resources" / "weights" / conf["model_name"]

        # Download the model.
        if not model_path.exists():
            # ``parents=True``: ``resources/`` may not exist either.
            model_path.parent.mkdir(parents=True, exist_ok=True)
            link = self.gluestick_models[conf["model_name"]]
            # ``wget -O`` creates its output file even when the transfer
            # fails, so download next to the target and only move the file
            # into place once wget succeeded. Otherwise a truncated file would
            # satisfy the ``exists()`` check on the next run.
            partial_path = model_path.with_name(model_path.name + ".part")
            cmd = ["wget", link, "-O", str(partial_path)]
            logger.info(f"Downloading the Gluestick model with `{cmd}`.")
            try:
                subprocess.run(cmd, check=True)
                partial_path.replace(model_path)
            finally:
                # Only still present when the download or the move failed.
                if partial_path.exists():
                    partial_path.unlink()
        logger.info("Loading GlueStick model...")

        gluestick_conf = {
            "name": "two_view_pipeline",
            "use_lines": True,
            "extractor": {
                "name": "wireframe",
                "sp_params": {
                    "force_num_keypoints": conf["force_num_keypoints"],
                    "max_num_keypoints": conf["max_keypoints"],
                },
                "wireframe_params": {
                    "merge_points": True,
                    "merge_line_endpoints": True,
                },
                "max_n_lines": conf["max_lines"],
            },
            "matcher": {
                "name": "gluestick",
                "weights": str(model_path),
                "trainable": False,
            },
            "ground_truth": {
                "from_pose_depth": False,
            },
        }
        self.net = TwoViewPipeline(gluestick_conf)

    def _forward(self, data):
        """Match keypoints and line segments between the two input images.

        Args:
            data: Input mapping handed unchanged to the pipeline (the class
                declares ``image0`` and ``image1`` as required inputs).

        Returns:
            dict: The pipeline predictions after ``batch_to_np``, where

            * ``keypoints0`` / ``keypoints1`` are replaced by the matched
              keypoints as torch tensors (row ``i`` of one matches row ``i``
              of the other),
            * ``lines0`` / ``lines1`` are replaced by the matched line
              segments (same row-wise pairing),
            * ``raw_lines0`` / ``raw_lines1`` hold all line segments returned
              by the pipeline,

            merged with ``data``; on a key collision the ``data`` entry wins.
        """
        pred = self.net(data)

        # NOTE(review): presumably converts tensors to NumPy and drops the
        # batch dimension -- confirm against gluestick.batch_to_np.
        pred = batch_to_np(pred)
        kp0, kp1 = pred["keypoints0"], pred["keypoints1"]
        line_seg0, line_seg1 = pred["lines0"], pred["lines1"]

        # Entry i of ``matches0`` is the index of the keypoint in image 1
        # matched to keypoint i of image 0; -1 marks an unmatched keypoint.
        point_matches = pred["matches0"]
        point_is_matched = point_matches != -1
        matched_kps0 = kp0[point_is_matched]
        matched_kps1 = kp1[point_matches[point_is_matched]]

        # Same convention for the line segments.
        line_matches = pred["line_matches0"]
        line_is_matched = line_matches != -1
        matched_lines0 = line_seg0[line_is_matched]
        matched_lines1 = line_seg1[line_matches[line_is_matched]]

        pred["raw_lines0"], pred["raw_lines1"] = line_seg0, line_seg1
        pred["lines0"], pred["lines1"] = matched_lines0, matched_lines1
        pred["keypoints0"] = torch.from_numpy(matched_kps0)
        pred["keypoints1"] = torch.from_numpy(matched_kps1)
        # ``data`` is unpacked last, so its entries override predictions that
        # share a key.
        return {**pred, **data}