from scipy.interpolate import interp1d, PchipInterpolator

import numpy as np
from PIL import Image
import cv2
import torch


def _lowe_ratio(pair):
    """Return the best / second-best descriptor distance of one knnMatch pair."""
    best, second = pair
    if second.distance == 0:
        # Both neighbours are exact duplicates of the query descriptor, so the
        # match is maximally ambiguous; score it as the worst ratio instead of
        # dividing by zero.
        return 1.0
    return best.distance / second.distance


def _matched_points(kp1, kp2, match):
    """Return the (x, y) keypoint locations joined by ``match`` in both images."""
    return kp1[match.queryIdx].pt, kp2[match.trainIdx].pt


def _save_match_figure(img1_rgb, img2_rgb, kp1, kp2, good, output_path, margin=10):
    """Stack both images vertically on a white canvas, draw ``good`` and save it."""
    h1, w1 = img1_rgb.shape[:2]
    h2, w2 = img2_rgb.shape[:2]

    # White background so that the margin and any width padding stay white.
    canvas = np.full((h1 + h2 + margin, max(w1, w2), 3), 255, dtype=np.uint8)
    canvas[:h1, :w1] = img1_rgb
    canvas[h1 + margin:, :w2] = img2_rgb

    # Cycle through 6 distinct colors so neighbouring matches can be told apart.
    color_list = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    for color_idx, m in enumerate(good):
        pt1, pt2 = _matched_points(kp1, kp2, m[0])
        start = (int(pt1[0]), int(pt1[1]))
        # The second image sits below the first one, shifted by its height + margin.
        end = (int(pt2[0]), int(pt2[1]) + h1 + margin)
        color = color_list[color_idx % len(color_list)]
        # Anti-aliased drawing avoids the zigzag artifact in lines.
        cv2.line(canvas, start, end, color, 1, lineType=cv2.LINE_AA)
        # Empty circles mark both end points of the match.
        cv2.circle(canvas, start, 3, color, lineType=cv2.LINE_AA)
        cv2.circle(canvas, end, 3, color, lineType=cv2.LINE_AA)

    Image.fromarray(canvas).save(output_path)
    print(f"Save the sift matches to {output_path}")


def sift_match(
    img1, img2,
    thr=0.5,
    topk=5, method="max_dist",
    output_path="sift_matches.png",
):
    """Match SIFT keypoints between two images and return up to ``topk`` pairs.

    A figure showing the selected matches (both images stacked vertically) is
    always written to ``output_path``, even when no match is selected.

    Args:
        img1: First RGB image (a PIL image; anything ``np.array`` turns into an
            (h, w, 3) uint8 array is handled the same way).
        img2: Second RGB image, same conventions as ``img1``.
        thr: Ratio-test threshold used by ``"max_dist"`` and ``"random"``; a
            match is kept when ``best < thr * second_best``. A smaller
            threshold means fewer points.
        topk: Maximum number of matches to return.
        method: Selection strategy, one of
            ``"max_dist"``: ratio-test survivors with the largest displacement
            between the two keypoint locations;
            ``"random"``: a random subset of the ratio-test survivors;
            ``"max_score"``: the ``topk`` matches with the smallest ratio;
            ``"max_score_even"``: like ``"max_score"``, but a candidate closer
            than 50 pixels (in ``img1``) to an already accepted one is dropped,
            so fewer than ``topk`` matches may be returned.
        output_path: Where the match visualization is saved.

    Returns:
        A tensor of shape (2, k, 2) with ``k <= topk``: index 0 holds the
        (x, y) locations in ``img1``, index 1 the matching locations in
        ``img2``. Returns ``None`` when no match is selected.
    """
    assert method in ["max_dist", "random", "max_score", "max_score_even"]

    # 1. RGB copies for drawing, grayscale versions for SIFT.
    img1_rgb = np.array(img1)
    img2_rgb = np.array(img2)
    gray1 = cv2.cvtColor(img1_rgb, cv2.COLOR_RGB2GRAY)
    gray2 = cv2.cvtColor(img2_rgb, cv2.COLOR_RGB2GRAY)

    # 2. Keypoints + descriptors, then 2-nearest-neighbour brute-force matching.
    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(gray1, None)
    kp2, des2 = sift.detectAndCompute(gray2, None)

    if des1 is None or des2 is None:
        # detectAndCompute yields no descriptors for a featureless image;
        # knnMatch cannot be called on None.
        matches = []
    else:
        # Keep only complete (best, second-best) pairs: with fewer than two
        # descriptors in img2, knnMatch returns shorter entries.
        matches = [pair for pair in cv2.BFMatcher().knnMatch(des1, des2, k=2) if len(pair) == 2]

    good = []  # list of [DMatch], the layout cv2.drawMatchesKnn expects
    point_list = []  # one (2, 2) tensor per selected match: [pt in img1, pt in img2]

    if method in ("max_score", "max_score_even"):
        # Smallest ratio first, i.e. the most distinctive matches.
        matches = sorted(matches, key=_lowe_ratio)
        min_anchor_dist = 50  # pixels, only used by "max_score_even"

        anchor_points_list = []
        for m, _ in matches[:topk]:
            pt1, pt2 = _matched_points(kp1, kp2, m)

            # Skip candidates too close to an already accepted point so the
            # selection stays evenly distributed over img1.
            if method == "max_score_even" and any(
                np.linalg.norm(np.array(pt1) - np.array(anchor)) < min_anchor_dist
                for anchor in anchor_points_list
            ):
                continue

            good.append([m])
            anchor_points_list.append(pt1)
            point_list.append(torch.stack([torch.tensor(pt1), torch.tensor(pt2)]))

    elif method in ("max_dist", "random"):
        distance_list = []
        for m, n in matches:
            # Lowe's ratio test.
            if m.distance < thr * n.distance:
                pt1, pt2 = _matched_points(kp1, kp2, m)
                good.append([m])
                distance_list.append(np.linalg.norm(np.array(pt1) - np.array(pt2)))
                point_list.append(torch.stack([torch.tensor(pt1), torch.tensor(pt2)]))

        # Indices of the survivors sorted by increasing displacement.
        idx = np.argsort(np.array(distance_list))
        if method == "max_dist":
            # Keep the topk largest displacements. An explicit start index is
            # used because idx[-topk:] would keep everything for topk == 0.
            idx = idx[max(len(idx) - topk, 0):]
        else:
            idx = np.random.choice(idx, min(topk, len(idx)), replace=False)

        point_list = [point_list[i] for i in idx]
        good = [good[i] for i in idx]

    _save_match_figure(img1_rgb, img2_rgb, kp1, kp2, good, output_path)

    if not point_list:
        return None

    # (topk, 2, 2) -> (f, topk, 2) with f=2 (before interpolation)
    return torch.stack(point_list).permute(1, 0, 2)


def interpolate_trajectory(points_torch, num_frames, t=None):
    """Resample every point track to ``num_frames`` positions with PCHIP.

    Args:
        points_torch: Tensor of shape (f, topk, 2) holding f keyframe
            positions for each of the topk points.
        num_frames: Number of output frames, sampled uniformly on [0, 1].
        t: Optional keyframe times, one per input frame. Defaults to f
            uniformly spaced values on [0, 1].

    Returns:
        Tensor of shape (num_frames, topk, 2) with the interpolated tracks.
    """
    num_points = points_torch.shape[1]
    tracks = points_torch.permute(1, 0, 2)  # (topk, f, 2)

    # Keyframe times and query times are shared by every track.
    knots = np.linspace(0, 1, points_torch.shape[0]) if t is None else t
    query_t = np.linspace(0, 1, num_frames)

    resampled = []
    for track_idx in range(num_points):
        keyframes = tracks[track_idx].cpu().numpy()  # (f, 2)

        # The two coordinates are interpolated independently.
        interp_first = PchipInterpolator(knots, keyframes[:, 0])
        interp_second = PchipInterpolator(knots, keyframes[:, 1])

        resampled.append(list(zip(interp_first(query_t), interp_second(query_t))))

    # (topk, num_frames, 2) -> (num_frames, topk, 2)
    return torch.tensor(resampled).permute(1, 0, 2)


# diffusion feature matching
def point_tracking(
    F0,
    F1,
    handle_points,
    handle_points_init,
    track_dist=5,
):
    """Re-locate each handle point in ``F1`` by nearest-feature search.

    For every point, the feature vector of ``F0`` at the initial location is
    compared (L1 distance over the channel dimension) against all positions of
    ``F1`` inside a square window of radius ``track_dist`` around the current
    location; the point moves to the position with the smallest distance.

    Args:
        F0: 4-D feature map; the last two dimensions are (rows, cols).
        F1: Feature map searched for the new locations; indexed like ``F0``.
            NOTE(review): the squeeze/argmin logic presumably expects a
            leading batch dimension of size 1 -- confirm with the caller.
        handle_points: (num_points, 2) tensor of current locations. The input
            is not modified. Locations outside the feature map are clamped
            to the nearest valid pixel before the search window is built.
        handle_points_init: (num_points, 2) tensor of initial locations in the
            same column order; these are used to index ``F0`` directly and
            are not validated.
        track_dist: Radius of the search window, in feature-map pixels.

    Returns:
        (num_points, 2) tensor with the updated locations, in the same column
        order and dtype as ``handle_points``.
    """
    # NOTE:
    # 1. all row and col are reversed
    # 2. the inputs are swapped into (row, col) order for indexing below
    # torch.stack allocates new tensors, so the caller's inputs stay untouched.
    tracked = torch.stack([handle_points[:, 1], handle_points[:, 0]], dim=-1)
    initial = torch.stack([handle_points_init[:, 1], handle_points_init[:, 0]], dim=-1)

    with torch.no_grad():
        _, _, max_r, max_c = F0.shape

        for i in range(tracked.shape[0]):
            # Reference feature at the initial location.
            f0 = F0[:, :, int(initial[i][0]), int(initial[i][1])]

            # Clamp the window center into the map. Without this, a point far
            # below 0 produces a negative slice end (which Python slicing
            # interprets as "from the end") and a point beyond the map
            # produces an empty window on which argmin fails.
            center_r = min(max(int(tracked[i][0]), 0), max_r - 1)
            center_c = min(max(int(tracked[i][1]), 0), max_c - 1)

            r1, r2 = max(0, center_r - track_dist), min(max_r, center_r + track_dist + 1)
            c1, c2 = max(0, center_c - track_dist), min(max_c, center_c + track_dist + 1)
            F1_neighbor = F1[:, :, r1:r2, c1:c2]

            # L1 distance between the reference feature and every window position.
            all_dist = (f0.unsqueeze(dim=-1).unsqueeze(dim=-1) - F1_neighbor).abs().sum(dim=1)
            all_dist = all_dist.squeeze(dim=0)

            # argmin is a flat index; convert it to (row, col) inside the window.
            row, col = divmod(all_dist.argmin().item(), all_dist.shape[-1])

            # Use the clipped window origin (r1, c1), not center - track_dist,
            # so the result is correct when the window touches a border.
            tracked[i][0] = r1 + row
            tracked[i][1] = c1 + col

        # Swap back to the caller's column order.
        return torch.stack([tracked[:, 1], tracked[:, 0]], dim=-1)  # (num_points, 2)