import gradio as gr
import cv2
from component import bbox
from component import skeleton
from component import control

# Load example image (cv2.imread returns None if the file is missing)
image_example = cv2.imread("examples/a.png")
assert image_example is not None, "examples/a.png not found"
image_example = cv2.cvtColor(image_example, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; Gradio expects RGB
uni_height = 800

# Create the interface with a cute theme
with gr.Blocks(theme=gr.themes.Soft()) as interface:
    # Title with kawaii emojis
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(f"""
            <div style='text-align: center; 
                        padding: 25px; 
                        border-radius: 12px;
                        background: #f3f4f6;
                        box-shadow: 0 10px 25px rgba(0, 0, 0, 0.08),
                                    0 2px 4px rgba(0, 0, 0, 0.05);
                        margin: 20px 0;
                        border: 1px solid rgba(0,0,0,0.1);
                        width: 100%;
                        position: relative;
                        transition: transform 0.2s ease, box-shadow 0.2s ease;'>
                <h1 style='margin: 0; 
                        color: #BE830E;
                        font-family: "Comic Sans MS", cursive;
                        text-shadow: 2px 2px 4px rgba(190, 131, 14, 0.1);
                        letter-spacing: 1.2px;
                        position: relative;
                        z-index: 2;'>
                    🖐️ HandCraft 🖐️
                </h1>
                <h3 style='margin: 10px 0 0;
                        color: #374151;
                        font-family: "Comic Sans MS", cursive;
                        font-weight: 500;
                        position: relative;
                        z-index: 2;
                        line-height: 1.4;
                        text-shadow: 0 1px 2px rgba(0,0,0,0.05);'>
                    🐾✨ Anatomically Correct Restoration of Malformed Hands in Diffusion Generated Images ✨🐾
                </h3>
            </div>
            """)
    
    # Shared input image at the top
    input_image = gr.Image(
        type="numpy", 
        label="📸 Upload Your Image with Hands", 
        height=uni_height,
        value=image_example
    )
    
    # Button to trigger the cascade
    generate_btn = gr.Button("✨🪄 Generate Control Mask 🪄✨", variant="primary", size="lg")
    
    # State variables carry intermediate results between the chained steps
    bbox_mask_state = gr.State()   # bounding-box mask from Step 1, consumed by Step 3
    keypoints_state = gr.State()   # keypoint string from Step 2, consumed by Step 3
    skeleton_state = gr.State()    # skeleton image from Step 2, consumed by Step 3
    
    # Results section with tabs for each step
    with gr.Tabs():
        with gr.TabItem("🐾 Step 1: Malformed Hand Detection"):
            with gr.Row():
                with gr.Column(scale=1):
                    output_bbox_result = gr.Textbox(label="🏷️ Detected Hands: Count, Class & Confidence")
                    include_standard = gr.Checkbox(
                        label="🤲 Include Standard Hands", 
                        value=False
                    )
                    expand_ratio = gr.Slider(
                        minimum=0.5, 
                        maximum=2, 
                        step=0.01, 
                        value=1, 
                        label="📏 Bounding Box Expand Ratio"
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column(scale=1):
                            output_bbox_vis = gr.Image(type="numpy", label="📦 Bounding Box", height=uni_height)
                        with gr.Column(scale=1):
                            output_bbox_mask = gr.Image(type="numpy", label="🎭 Bounding Box Mask", height=uni_height)
        
        with gr.TabItem("💃 Step 2: Body Pose Estimation"):
            with gr.Row():
                with gr.Column(scale=1):
                    output_keypoints = gr.Textbox(label="📊 Key Points String")
                with gr.Column(scale=2):
                    output_skeleton = gr.Image(type="numpy", label="💪 Body Skeleton", height=uni_height)
        
        with gr.TabItem("🎨 Step 3: Control Image Generation"):
            with gr.Row():
                with gr.Column(scale=1):
                    hand_template = gr.Radio(
                        ["opened-palm", "fist-back"], 
                        label="👋 Hand Template", 
                        value="opened-palm"
                    )
                    control_expand = gr.Slider(
                        minimum=0.5, 
                        maximum=2, 
                        step=0.01, 
                        value=1,
                        label="🔍 Control Image Expand Ratio"
                    )
                    include_undetected = gr.Checkbox(
                        label="🔎 Include Undetected Hand", 
                        value=False
                    )
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Column():
                            output_viz = gr.Image(type="numpy", label="👁️ Visualization Image", height=300)
                        with gr.Column():
                            output_control = gr.Image(type="numpy", label="🎮 Control Image", height=300)
                    with gr.Row():
                        with gr.Column():
                            output_control_mask = gr.Image(type="numpy", label="🎭 Control Mask", height=300)
                        with gr.Column():
                            output_union_mask = gr.Image(type="numpy", label="🔄 Union Mask", height=300)
                            
        with gr.TabItem("🎉 Output Control Image"):
            gr.Markdown("""
            ### ✨🌈 Control Image 🌈✨
            Your control image is ready! (ノ◕ヮ◕)ノ*:・゚✧
            """)
            with gr.Row():
                with gr.Column():
                    output_final_control = gr.Image(type="numpy", label="👐 Final Control Image", interactive=False, height=uni_height)
                with gr.Column():
                    gr.Markdown("""
                    ### 🌟✨ How to Use Your Control Image ✨🌟
                    
                    1. 🪄 Take this control image to your favorite Stable Diffusion model
                    2. 🎀 Feed it to ControlNet as the conditioning image (see the sketch below)
                    3. 🍬 Tweak the parameters until the hands look just right!
                    
                    """)
    
    # Citation information with cute emojis
    with gr.Accordion("📚✨ Citation Information ✨📚", open=False):
        gr.Markdown("""
        If you find this tool helpful for your research, please cite our paper: 📝
        
        ```bibtex
        @InProceedings{2025_wacv_handcraft,
            author    = {Qin, Zhenyue and Zhang, Yiqun and Liu, Yang and Campbell, Dylan},
            title     = {HandCraft: Anatomically Correct Restoration of Malformed Hands in Diffusion Generated Images},
            booktitle = {Proceedings of the Winter Conference on Applications of Computer Vision (WACV)},
            month     = {February},
            year      = {2025},
            pages     = {3925-3933}
        }
        ```
        
        Thank you for using HandCraft! ✨👐✨
        """)
    
    # Step functions: intermediate results are forwarded to later steps via State
    def run_step1(image, include_std, expand_r):
        # Step 1: Run hand detection
        bbox_result, bbox_vis, bbox_mask = bbox(image, include_std, expand_r)
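        # bbox_mask is returned twice: for on-screen display and for the State consumed by Step 3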
        return bbox_result, bbox_vis, bbox_mask, bbox_mask
    
    def run_step2(image):
        # Step 2: Run pose estimation
        keypoints, skeleton_img = skeleton(image)
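        # Outputs are duplicated: visible components now, State values for Step 3 later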
        return keypoints, skeleton_img, keypoints, skeleton_img
    
    def run_step3(image, bbox_mask, keypoints, control_exp, hand_tmpl, skeleton_img, include_undetect):
        # Step 3: Generate control images
        viz, control_img, control_mask, union_mask = control(
            image, bbox_mask, keypoints, control_exp, hand_tmpl, skeleton_img, include_undetect
        )
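        # control_img is returned twice: for the Step 3 tab and for the final output tab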
        return viz, control_img, control_mask, union_mask, control_img
    
    # Chain the three steps: each .then() fires once the previous handler returns
    generate_btn.click(
        fn=run_step1,
        inputs=[input_image, include_standard, expand_ratio],
        outputs=[output_bbox_result, output_bbox_vis, output_bbox_mask, bbox_mask_state]
    ).then(
        fn=run_step2,
        inputs=[input_image],
        outputs=[output_keypoints, output_skeleton, keypoints_state, skeleton_state]
    ).then(
        fn=run_step3,
        inputs=[
            input_image, 
            bbox_mask_state, 
            keypoints_state,
            control_expand,
            hand_template,
            skeleton_state,
            include_undetected
        ],
        outputs=[output_viz, output_control, output_control_mask, output_union_mask, output_final_control]
    )

# Launch the interface
interface.launch(server_name="0.0.0.0", server_port=7860, share=True)