import os
from datetime import datetime
from pathlib import Path

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from einops import repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection
import torch.nn.functional as F
import gc
from huggingface_hub import hf_hub_download

from musepose.models.pose_guider import PoseGuider
from musepose.models.unet_2d_condition import UNet2DConditionModel
from musepose.models.unet_3d import UNet3DConditionModel
from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from musepose.utils.util import get_fps, read_frames, save_videos_grid


class MusePoseInference:
    """Animate a reference image with a pose video using the MusePose pipeline.

    The models are loaded onto the GPU at the start of every
    :meth:`infer_musepose` call and released again when the call finishes
    (successfully or not), so an idle instance holds no VRAM.
    """

    def __init__(self):
        # NOTE(review): these strings look like "<hub repo id>/<subfolder>",
        # yet the UNet loaders below additionally pass ``subfolder="unet"``.
        # Confirm how the loaders resolve them before relying on hub download.
        self.image_gen_model_paths = {
            "pretrained_base_model": "lambdalabs/sd-image-variations-diffusers/unet",
            "pretrained_vae": "stabilityai/sd-vae-ft-mse",
            "image_encoder": "lambdalabs/sd-image-variations-diffusers/image_encoder",
        }
        self.musepose_model_paths = {
            "denoising_unet": os.path.join("pretrained_weights", "MusePose", "denoising_unet.pth"),
            "reference_unet": os.path.join("pretrained_weights", "MusePose", "reference_unet.pth"),
            "pose_guider": os.path.join("pretrained_weights", "MusePose", "pose_guider.pth"),
            # Each component has its own checkpoint; this entry must not reuse
            # the pose guider file.
            "motion_module": os.path.join("pretrained_weights", "MusePose", "motion_module.pth"),
        }
        self.inference_config_path = os.path.join("configs", "inference_v2.yaml")
        # Model handles; populated by infer_musepose() and cleared by release_vram().
        self.vae = None
        self.reference_unet = None
        self.denoising_unet = None
        self.pose_guider = None
        self.image_enc = None
        self.pipe = None
        self.output_dir = os.path.join("assets", "video")
        # NOTE: weights are not downloaded automatically here; call
        # download_models() explicitly when they are missing.

    def infer_musepose(
        self,
        ref_image_path: str,
        pose_video_path: str,
        weight_dtype: str,
        W: int,
        H: int,
        L: int,
        S: int,
        O: int,
        cfg: float,
        seed: int,
        steps: int,
        fps: int,
        skip: int
    ):
        """Generate a video of the reference image following the pose video.

        Args:
            ref_image_path: Path of the reference image (opened as RGB).
            pose_video_path: Path of the pose video whose frames drive the motion.
            weight_dtype: ``"fp16"`` selects ``torch.float16``; any other value
                selects ``torch.float32``.
            W: Width the frames are resized to and the pipeline generates at.
            H: Height the frames are resized to and the pipeline generates at.
            L: Maximum number of pose-video frames to consider (counted before
                ``skip`` is applied).
            S: Passed to the pipeline as ``context_frames``. Must be > ``O``.
            O: Passed to the pipeline as ``context_overlap``.
            cfg: Forwarded positionally to the pipeline right after ``steps``
                (presumably the guidance scale; confirm against the pipeline).
            seed: Seed handed to ``torch.manual_seed``.
            steps: Forwarded positionally to the pipeline (presumably the
                number of denoising steps; confirm against the pipeline).
            fps: Frame rate of the saved videos. ``None`` or a negative value
                falls back to the source frame rate divided by ``skip + 1``.
            skip: Number of frames dropped after every kept frame.

        Returns:
            ``(output_path, output_path_demo)``: absolute paths of the result
            video and of the 3-row comparison video (reference / pose / result).

        Raises:
            ValueError: If ``S <= O``, ``skip < 0``, or no pose frames remain
                after applying ``L`` and ``skip``.
        """
        print(f"Model Paths: {self.musepose_model_paths}\n{self.image_gen_model_paths}\n{self.inference_config_path}")
        print(f"Input Image Path: {ref_image_path}")
        print(f"Pose Video Path: {pose_video_path}")
        print(f"Dtype: {weight_dtype}")
        print(f"Width: {W}")
        print(f"Height: {H}")
        print(f"Video Frame Length: {L}")
        print(f"VIDEO SLICE FRAME LENGTH:: {S}")
        print(f"VIDEO SLICE OVERLAP_FRAME NUMBER: {O}")
        print(f"CFG: {cfg}")
        print(f"Seed: {seed}")
        print(f"Steps: {steps}")
        print(f"FPS: {fps}")
        print(f"Skip: {skip}")

        # Validate before any model is loaded so bad arguments fail fast and
        # cheaply instead of as a ZeroDivisionError after the GPU is filled.
        if S <= O:
            raise ValueError(
                f"Slice frame length S ({S}) must be greater than the overlap O ({O})."
            )
        if skip < 0:
            raise ValueError(f"skip must be >= 0, got {skip}.")

        image_file_name = os.path.splitext(os.path.basename(ref_image_path))[0]
        pose_video_file_name = os.path.splitext(os.path.basename(pose_video_path))[0]
        output_file_name = f"img_{image_file_name}_pose_{pose_video_file_name}"
        output_path = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}.mp4'))
        output_path_demo = os.path.abspath(os.path.join(self.output_dir, f'{output_file_name}_demo.mp4'))

        torch_dtype = torch.float16 if weight_dtype == "fp16" else torch.float32

        try:
            self._load_pipeline(torch_dtype)

            # Seed only after the models are built, as module construction may
            # itself consume random numbers.
            generator = torch.manual_seed(seed)

            print("image: ", ref_image_path, "pose_video: ", pose_video_path)

            ref_image_pil = Image.open(ref_image_path).convert("RGB")

            pose_images = read_frames(pose_video_path)
            src_fps = get_fps(pose_video_path)
            print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")

            frame_stride = skip + 1
            frame_count = min(L, len(pose_images))
            pose_images = pose_images[::frame_stride]
            print("processing length:", len(pose_images))
            src_fps = src_fps // frame_stride
            print("fps", src_fps)
            frame_count = frame_count // frame_stride
            if frame_count <= 0:
                raise ValueError(
                    "No pose frames left to process; check the pose video, L and skip."
                )

            pose_transform = transforms.Compose(
                [transforms.Resize((H, W)), transforms.ToTensor()]
            )
            pose_list = list(pose_images[:frame_count])
            pose_tensor_list = [pose_transform(frame) for frame in pose_list]
            # The outputs are scaled back to the pose video's own resolution.
            original_width, original_height = pose_list[-1].size

            # Repeat the last frame so that the final context window of S
            # frames (advancing by S - O) is completely filled.
            window_step = S - O
            last_segment_frame_num = (frame_count - S) % window_step
            repeat_frame_num = (window_step - last_segment_frame_num) % window_step
            pose_list.extend([pose_list[-1]] * repeat_frame_num)
            pose_tensor_list.extend([pose_tensor_list[-1]] * repeat_frame_num)

            ref_image_tensor = pose_transform(ref_image_pil)  # (c, h, w)
            ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0)  # (1, c, 1, h, w)
            ref_image_tensor = repeat(
                ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=frame_count
            )

            pose_tensor = torch.stack(pose_tensor_list, dim=0)  # (f, c, h, w)
            pose_tensor = pose_tensor.transpose(0, 1)  # (c, f, h, w)
            pose_tensor = pose_tensor.unsqueeze(0)  # (1, c, f, h, w)

            video = self.pipe(
                ref_image_pil,
                pose_list,
                W,
                H,
                len(pose_list),
                steps,
                cfg,
                generator=generator,
                context_frames=S,
                context_stride=1,
                context_overlap=O,
            ).videos

            output_fps = src_fps if fps is None or fps < 0 else fps

            # Drop the padding frames again before saving.
            generated = video[:, :, :frame_count]
            result = self.scale_video(generated, original_width, original_height)
            save_videos_grid(
                result,
                output_path,
                n_rows=1,
                fps=output_fps,
            )

            comparison = torch.cat(
                [ref_image_tensor, pose_tensor[:, :, :frame_count], generated], dim=0
            )
            comparison = self.scale_video(comparison, original_width, original_height)
            save_videos_grid(
                comparison,
                output_path_demo,
                n_rows=3,
                fps=output_fps,
            )
        finally:
            # Free the GPU even when loading or inference fails.
            self.release_vram()
        return output_path, output_path_demo

    def _load_pipeline(self, torch_dtype):
        """Load every model onto the GPU and assemble ``self.pipe``."""
        infer_config = OmegaConf.load(self.inference_config_path)
        base_model_path = self.image_gen_model_paths["pretrained_base_model"]

        self.vae = AutoencoderKL.from_pretrained(
            self.image_gen_model_paths["pretrained_vae"],
        ).to("cuda", dtype=torch_dtype)

        self.reference_unet = UNet2DConditionModel.from_pretrained(
            base_model_path,
            subfolder="unet",
        ).to(dtype=torch_dtype, device="cuda")

        self.denoising_unet = UNet3DConditionModel.from_pretrained_2d(
            Path(base_model_path),
            Path(self.musepose_model_paths["motion_module"]),
            subfolder="unet",
            unet_additional_kwargs=infer_config.unet_additional_kwargs,
        ).to(dtype=torch_dtype, device="cuda")

        self.pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
            dtype=torch_dtype, device="cuda"
        )

        self.image_enc = CLIPVisionModelWithProjection.from_pretrained(
            self.image_gen_model_paths["image_encoder"]
        ).to(dtype=torch_dtype, device="cuda")

        sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
        scheduler = DDIMScheduler(**sched_kwargs)

        # SECURITY NOTE: torch.load unpickles the checkpoint files; only load
        # weights obtained from a trusted source.
        self.denoising_unet.load_state_dict(
            torch.load(self.musepose_model_paths["denoising_unet"], map_location="cpu"),
            strict=False,
        )
        self.reference_unet.load_state_dict(
            torch.load(self.musepose_model_paths["reference_unet"], map_location="cpu"),
        )
        self.pose_guider.load_state_dict(
            torch.load(self.musepose_model_paths["pose_guider"], map_location="cpu"),
        )

        pipe = Pose2VideoPipeline(
            vae=self.vae,
            image_encoder=self.image_enc,
            reference_unet=self.reference_unet,
            denoising_unet=self.denoising_unet,
            pose_guider=self.pose_guider,
            scheduler=scheduler,
        )
        self.pipe = pipe.to("cuda", dtype=torch_dtype)

    def download_models(self):
        """Download any missing MusePose checkpoint from the Hugging Face Hub.

        Files that already exist at their path in ``self.musepose_model_paths``
        are left untouched.
        """
        repo_id = 'jhj0517/MusePose'
        for file_path in self.musepose_model_paths.values():
            if os.path.exists(file_path):
                continue

            weights_dir, filename = os.path.split(file_path)
            os.makedirs(weights_dir, exist_ok=True)

            # Hub file names always use "/", so do not build them with
            # os.path.join (which yields "\\" on Windows).
            remote_filepath = f"MusePose/{filename}"
            # hf_hub_download recreates the repo-relative path below
            # ``local_dir``; pointing it at the parent of ".../MusePose" makes
            # the file land exactly at ``file_path``.
            download_root = os.path.dirname(weights_dir) or "."
            hf_hub_download(repo_id=repo_id, filename=remote_filepath,
                            local_dir=download_root,
                            local_dir_use_symlinks=False)

    def release_vram(self):
        """Drop all model references held by the instance and free cached GPU memory."""
        model_names = (
            'vae', 'reference_unet', 'denoising_unet',
            'pose_guider', 'image_enc', 'pipe'
        )
        for model_name in model_names:
            setattr(self, model_name, None)

        # Collect first: tensors kept alive only by reference cycles must be
        # freed before empty_cache() can hand their blocks back to the driver.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def scale_video(video, width, height):
        """Bilinearly resize the two trailing (spatial) dimensions of a 5-D tensor.

        Args:
            video: 5-D tensor whose last two dimensions are height and width;
                infer_musepose passes tensors laid out as (b, c, f, h, w).
            width: Target width.
            height: Target height.

        Returns:
            Tensor with the first three dimensions unchanged and the last two
            equal to ``(height, width)``.
        """
        lead_dims = video.shape[:2]
        # Fold the first two dims together because F.interpolate in bilinear
        # mode expects a 4-D input. reshape (unlike view) also accepts
        # non-contiguous tensors.
        flat = video.reshape(-1, *video.shape[2:])
        scaled = F.interpolate(flat, size=(height, width), mode='bilinear', align_corners=False)
        return scaled.reshape(*lead_dims, scaled.shape[1], height, width)