import gradio as gr
from gradio_client import Client
import os

# Hugging Face token, read from the environment, for authenticated Space calls.
hf_token = os.environ.get("HF_TKN")

from style_template import styles
STYLE_NAMES = list(styles.keys())
DEFAULT_STYLE_NAME = "(No style)"

def get_instantID(portrait_in, condition_pose, prompt, style):
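    """Stylize the input face with the InstantID Space, following the conditional pose."""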
    client = Client("https://fffiloni-instantid.hf.space/", hf_token=hf_token)
    negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, (frame:1.2), deformed, ugly, deformed eyes, blur, out of focus, blurry, deformed, monochrome, gun, weapon"
    result = client.predict(
        portrait_in,      # filepath in 'Upload a photo of your face' Image component
        condition_pose,   # filepath in 'Upload a reference pose image (optional)' Image component
        prompt,           # str in 'Prompt' Textbox component
        negative_prompt,  # str in 'Negative Prompt' Textbox component
        style,            # Literal['(No style)', 'Watercolor', 'Film Noir', 'Neon', 'Jungle', 'Mars', 'Vibrant Color', 'Snow', 'Line art'] in 'Style template' Dropdown component
        True,             # bool in 'Enhance non-face region' Checkbox component
        20,               # float (numeric value between 20 and 100) in 'Number of sample steps' Slider component
        0.8,              # float (numeric value between 0 and 1.5) in 'IdentityNet strength (for fidelity)' Slider component
        0.8,              # float (numeric value between 0 and 1.5) in 'Image adapter strength (for detail)' Slider component
        5,                # float (numeric value between 0.1 and 10.0) in 'Guidance scale' Slider component
        0,                # float (numeric value between 0 and 2147483647) in 'Seed' Slider component
        True,             # bool in 'Randomize seed' Checkbox component
        api_name="/generate_image"
    )
    print(result)
    return result[0]

def get_video_i2vgen(image_in, prompt):
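    """Animate the image with ModelScope's I2VGen-XL Space, guided by a short prompt."""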
    client = Client("https://modelscope-i2vgen-xl.hf.space/")
    result = client.predict(
        image_in,
        prompt,
        fn_index=1
    )
    print(result)
    return result

def get_video_svd(image_in):
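    """Animate the image with the Stable Video Diffusion Space (no text conditioning)."""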
    client = Client("https://multimodalart-stable-video-diffusion.hf.space/--replicas/ej45m/")
    result = client.predict(
        image_in,  # filepath in 'Upload your image' Image component
        0,         # float (numeric value between 0 and 9223372036854775807) in 'Seed' Slider component
        True,      # bool in 'Randomize seed' Checkbox component
        127,       # float (numeric value between 1 and 255) in 'Motion bucket id' Slider component
        6,         # float (numeric value between 5 and 30) in 'Frames per second' Slider component
        api_name="/video"
    )
    print(result)
    return result[0]["video"]

def load_sample_shot(camera_shot):
    # Map each standard camera shot to its bundled conditional pose image;
    # "custom" (or any unknown value) clears the conditional shot.
    sample_shots = {
        "close-up": "camera_shots/close_up_shot.jpeg",
        "medium close-up": "camera_shots/medium_close_up.jpeg",
        "medium shot": "camera_shots/medium_shot.png",
        "cowboy shot": "camera_shots/cowboy_shot.jpeg",
        "medium full shot": "camera_shots/medium_full_shot.png",
        "full shot": "camera_shots/full_shot.jpeg",
    }
    return sample_shots.get(camera_shot)

def use_custom_cond():
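    """Switch the camera-shot dropdown back to 'custom' when the conditional shot is cleared."""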
    return "custom"

def get_short_caption(image_in):
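    """Ask the moondream1 Space for a one-sentence description of the image."""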
    client = Client("https://vikhyatk-moondream1.hf.space/")
    result = client.predict(
        image_in,  # filepath in 'image' Image component
        "Describe what is happening in one sentence",  # str in 'Question' Textbox component
        api_name="/answer_question"
    )
    print(result)
    return result

def infer(image_in, camera_shot, conditional_pose, prompt, style, chosen_model):
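    """Full pipeline: stylized InstantID portrait -> short caption -> image-to-video."""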
    
    # A custom shot requires an uploaded conditional pose image.
    if camera_shot == "custom" and conditional_pose is None:
        raise gr.Error("No custom conditional shot found!")

    iid_img = get_instantID(image_in, conditional_pose, prompt, style)

    short_cap = get_short_caption(iid_img)
    
    if chosen_model == "i2vgen-xl":
        video_res = get_video_i2vgen(iid_img, short_cap)
    elif chosen_model == "stable-video":
        # Note: the stable-video path animates the original input image,
        # not the InstantID result.
        video_res = get_video_svd(image_in)
    else:
        raise gr.Error(f"Unknown model: {chosen_model}")
    
    print(video_res)
    
    return video_res


css = """
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML("""
        <h2 style="text-align: center;">
            InstantID-2V
        </h2>
        <p style="text-align: center;">
            Generate an animated camera shot from an input face
        </p>
        """)
        
        with gr.Row():
            with gr.Column():
                face_in = gr.Image(type="filepath", label="Face to copy", value="monalisa.png")
            with gr.Column():
                with gr.Row():
                    camera_shot = gr.Dropdown(
                        label = "Camera Shot", 
                        info = "Use standard camera shots vocabulary, or drop your custom shot as conditional pose (1280*720 ratio is recommended)",
                        choices = [
                            "custom", "close-up", "medium close-up", "medium shot", "cowboy shot", "medium full shot", "full shot"
                        ],
                        value = "custom"
                    )
                    style = gr.Dropdown(label="Style template", info="InstantID legacy templates", choices=STYLE_NAMES, value=DEFAULT_STYLE_NAME)

                condition_shot = gr.Image(type="filepath", label="Custom conditional shot (Important) [1280*720 recommended]")
                prompt = gr.Textbox(label="Short Prompt (keeping it short is better)")
                chosen_model = gr.Radio(label="Choose a model", choices=["i2vgen-xl", "stable-video"], value="i2vgen-xl", interactive=False, visible=False)
            
        with gr.Column():
            submit_btn = gr.Button("Submit")
            video_out = gr.Video()

    
    camera_shot.change(
        fn = load_sample_shot,
        inputs = camera_shot,
        outputs = condition_shot,
        queue=False
    )
    condition_shot.clear(
        fn = use_custom_cond,
        inputs = None,
        outputs = camera_shot,
        queue=False,
    )
    submit_btn.click(
        fn = infer,
        inputs = [
            face_in,
            camera_shot,
            condition_shot,
            prompt,
            style,
            chosen_model
        ],
        outputs = [
            video_out
        ]
    )

demo.queue(max_size=3).launch(share=False, show_error=True, show_api=False)