"""Cog predictor. Run `bash scripts/download_models.sh` first to prepare the model weight files."""
import os
import shutil
from argparse import Namespace
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
from cog import BasePredictor, Input, Path

# Directory (relative to the working directory) holding the downloaded model
# weights; Predictor.setup joins every weight/checkpoint path onto it.
checkpoints = "checkpoints"


class Predictor(BasePredictor):
    """Cog predictor that animates a source face from a driving audio clip.

    ``setup`` loads every model once; ``predict`` runs the pipeline:
    3DMM coefficient extraction -> audio-to-coefficient -> face rendering.
    """

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        # Single source of truth for the device used by setup and predict.
        self.device = "cuda"
        device = self.device

        path_of_lm_croper = os.path.join(
            checkpoints, "shape_predictor_68_face_landmarks.dat"
        )
        path_of_net_recon_model = os.path.join(checkpoints, "epoch_20.pth")
        dir_of_BFM_fitting = os.path.join(checkpoints, "BFM_Fitting")
        wav2lip_checkpoint = os.path.join(checkpoints, "wav2lip.pth")

        # NOTE: "auido" is the actual spelling of the shipped weight/config
        # file names; do not "fix" it here.
        audio2pose_checkpoint = os.path.join(checkpoints, "auido2pose_00140-model.pth")
        audio2pose_yaml_path = os.path.join("src", "config", "auido2pose.yaml")

        audio2exp_checkpoint = os.path.join(checkpoints, "auido2exp_00300-model.pth")
        audio2exp_yaml_path = os.path.join("src", "config", "auido2exp.yaml")

        free_view_checkpoint = os.path.join(
            checkpoints, "facevid2vid_00189-model.pth.tar"
        )

        # init model
        self.preprocess_model = CropAndExtract(
            path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device
        )

        self.audio_to_coeff = Audio2Coeff(
            audio2pose_checkpoint,
            audio2pose_yaml_path,
            audio2exp_checkpoint,
            audio2exp_yaml_path,
            wav2lip_checkpoint,
            device,
        )

        # Two renderers sharing the face-vid2vid weights but using different
        # mapping networks/configs: "full" preprocessing uses the "still"
        # config, every other preprocess mode uses the default one.
        self.animate_from_coeff = {
            "full": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00109-model.pth.tar"),
                os.path.join("src", "config", "facerender_still.yaml"),
                device,
            ),
            "others": AnimateFromCoeff(
                free_view_checkpoint,
                os.path.join(checkpoints, "mapping_00229-model.pth.tar"),
                os.path.join("src", "config", "facerender.yaml"),
                device,
            ),
        }

    def _extract_reference_coeff(self, video_path, results_dir, message):
        """Extract 3DMM coefficients from a reference video.

        Frames are written to ``results_dir/<video file stem>``; ``message`` is
        printed before extraction. Returns the coefficient path reported by
        ``self.preprocess_model.generate``.
        """
        videoname = os.path.splitext(os.path.basename(video_path))[0]
        frame_dir = os.path.join(results_dir, videoname)
        os.makedirs(frame_dir, exist_ok=True)
        print(message)
        coeff_path, _, _ = self.preprocess_model.generate(video_path, frame_dir)
        return coeff_path

    def predict(
        self,
        source_image: Path = Input(
            description="Upload the source image, it can be video.mp4 or picture.png",
        ),
        driven_audio: Path = Input(
            description="Upload the driven audio, accepts .wav and .mp4 file",
        ),
        enhancer: str = Input(
            description="Choose a face enhancer",
            choices=["gfpgan", "RestoreFormer"],
            default="gfpgan",
        ),
        preprocess: str = Input(
            description="how to preprocess the images",
            choices=["crop", "resize", "full"],
            default="full",
        ),
        ref_eyeblink: Path = Input(
            description="path to reference video providing eye blinking",
            default=None,
        ),
        ref_pose: Path = Input(
            description="path to reference video providing pose",
            default=None,
        ),
        still: bool = Input(
            description="can crop back to the original videos for the full body aniamtion when preprocess is full",
            default=True,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        The "results" working directory is wiped and recreated on every call.
        The first file in it whose name contains "enhanced.mp4" (in sorted
        order) is copied to /tmp/out.mp4, and that path is returned.

        Raises:
            ValueError: if no 3DMM coefficients can be extracted from
                ``source_image``.
            RuntimeError: if rendering produced no "*enhanced.mp4" file.
        """
        animate_from_coeff = self.animate_from_coeff[
            "full" if preprocess == "full" else "others"
        ]

        args = load_default()
        args.pic_path = str(source_image)
        args.audio_path = str(driven_audio)
        device = self.device

        # crop image and extract 3dmm from image
        results_dir = "results"
        if os.path.exists(results_dir):
            shutil.rmtree(results_dir)
        os.makedirs(results_dir)
        first_frame_dir = os.path.join(results_dir, "first_frame_dir")
        os.makedirs(first_frame_dir)

        print("3DMM Extraction for source image")
        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
            args.pic_path, first_frame_dir, preprocess, source_image_flag=True
        )
        if first_coeff_path is None:
            # Fail loudly: returning None from a `-> Path` predictor only
            # surfaces as an unrelated output-validation error.
            raise ValueError("Can't get the coeffs of the input")

        ref_eyeblink_coeff_path = None
        if ref_eyeblink is not None:
            ref_eyeblink_coeff_path = self._extract_reference_coeff(
                ref_eyeblink,
                results_dir,
                "3DMM Extraction for the reference video providing eye blinking",
            )

        ref_pose_coeff_path = None
        if ref_pose is not None:
            if ref_pose == ref_eyeblink:
                # Same video was already processed above; reuse its coefficients.
                ref_pose_coeff_path = ref_eyeblink_coeff_path
            else:
                ref_pose_coeff_path = self._extract_reference_coeff(
                    ref_pose,
                    results_dir,
                    "3DMM Extraction for the reference video providing pose",
                )

        # audio2ceoff
        batch = get_data(
            first_coeff_path,
            args.audio_path,
            device,
            ref_eyeblink_coeff_path,
            still=still,
        )
        coeff_path = self.audio_to_coeff.generate(
            batch, results_dir, args.pose_style, ref_pose_coeff_path
        )

        # coeff2video
        print("coeff2video")
        data = get_facerender_data(
            coeff_path,
            crop_pic_path,
            first_coeff_path,
            args.audio_path,
            args.batch_size,
            args.input_yaw,
            args.input_pitch,
            args.input_roll,
            expression_scale=args.expression_scale,
            still_mode=still,
            preprocess=preprocess,
        )
        animate_from_coeff.generate(
            data,
            results_dir,
            args.pic_path,
            crop_info,
            enhancer=enhancer,
            background_enhancer=args.background_enhancer,
            preprocess=preprocess,
        )

        # The renderer is expected to leave a "*enhanced.mp4" file in
        # results_dir (presumably because an enhancer is always selected);
        # sort so the choice is deterministic if several match.
        enhanced_videos = sorted(
            f for f in os.listdir(results_dir) if "enhanced.mp4" in f
        )
        if not enhanced_videos:
            raise RuntimeError(
                f"No '*enhanced.mp4' output was produced in {results_dir!r}"
            )

        output = "/tmp/out.mp4"
        shutil.copy(os.path.join(results_dir, enhanced_videos[0]), output)

        return Path(output)


def load_default():
    """Return a fresh ``argparse.Namespace`` of fixed inference options.

    A new object is built on every call, so callers may set extra attributes
    on it without affecting later calls.
    """
    options = {
        # Read by Predictor.predict (pose, batching, head-angle overrides,
        # expression strength, background enhancer).
        "pose_style": 0,
        "batch_size": 2,
        "expression_scale": 1.0,
        "input_yaw": None,
        "input_pitch": None,
        "input_roll": None,
        "background_enhancer": None,
        # Not read by Predictor.predict in this file; presumably kept for
        # parity with the face-reconstruction options — verify before removing.
        "face3dvis": False,
        "net_recon": "resnet50",
        "init_path": None,
        "use_last_fc": False,
        "bfm_folder": "./checkpoints/BFM_Fitting/",
        "bfm_model": "BFM_model_front.mat",
        "focal": 1015.0,
        "center": 112.0,
        "camera_d": 10.0,
        "z_near": 5.0,
        "z_far": 15.0,
    }
    return Namespace(**options)