import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
import os
from diffusers import DiffusionPipeline
from custom_pipeline import FLUXPipelineWithIntermediateOutputs
from transformers import pipeline

# Korean -> English translation pipeline (Helsinki-NLP opus-mt-ko-en), loaded
# once at import time and used by translate_if_korean() for Hangul prompts.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Constants
# Largest signed 32-bit integer; inclusive upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound of the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 2048
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
# Default denoising step count; the UI slider allows 1-4 steps.
DEFAULT_INFERENCE_STEPS = 1

# Device and model setup
# NOTE(review): diffusers' Flux examples load FLUX.1-schnell in bfloat16; fp16
# is supported but can change outputs (some text-encoder activations need
# clipping). Confirm float16 is intentional for the target GPU.
dtype = torch.float16
# Load the custom FLUX pipeline (the project-local class that exposes
# generate_images) and move it to "cuda" at import time.
pipe = FLUXPipelineWithIntermediateOutputs.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=dtype
).to("cuda")
# Hand unused cached blocks in PyTorch's CUDA caching allocator back after loading.
torch.cuda.empty_cache()

# UI label lookup. The interface is English-only, so every label maps to
# itself; routing UI text through this dict keeps all strings in one place.
english_labels = {
    label: label
    for label in (
        "Generated Image",
        "Prompt",
        "Enhance Image",
        "Advanced Options",
        "Seed",
        "Randomize Seed",
        "Width",
        "Height",
        "Inference Steps",
        "Inspiration Gallery",
    )
}

def translate_if_korean(text):
    """Return *text* translated to English if it contains Hangul, else unchanged.

    A string counts as Korean when any character is a Hangul Compatibility
    Jamo (U+3131-U+3163) or a precomposed Hangul syllable (U+AC00-U+D7A3).
    In that case the whole string is passed to the module-level ``translator``
    pipeline and the ``translation_text`` of its first result is returned.
    """
    def _is_hangul(ch):
        return '\u3131' <= ch <= '\u3163' or '\uac00' <= ch <= '\ud7a3'

    # map() is lazy, so any() still stops at the first Hangul character.
    if not any(map(_is_hangul, text)):
        return text
    first_result = translator(text)[0]
    return first_result['translation_text']

# Main streaming inference callback, shared by the button and the live triggers.
@spaces.GPU(duration=25)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    """Generate an image for *prompt*, yielding each image the pipeline produces.

    Args:
        prompt: Text prompt. Korean text is translated to English first.
        seed: Seed for the torch generator. Replaced by a random seed in
            ``[0, MAX_SEED]`` when ``None`` or when ``randomize_seed`` is true.
        width: Output width in pixels.
        height: Output height in pixels.
        randomize_seed: Whether to ignore *seed* and draw a fresh one.
        num_inference_steps: Denoising step count passed to the pipeline.

    Yields:
        ``(image, seed, latency_text)`` for every image yielded by
        ``pipe.generate_images``. These are presumably intermediate previews
        followed by the final image (per the pipeline class name). The seed
        actually used is returned so the UI can display it.
    """
    prompt = translate_if_korean(prompt)

    if seed is None or randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio number/slider components and API callers may deliver floats
    # (e.g. 1024.0). torch.Generator.manual_seed requires an int, and the
    # pipeline's size/step arguments are presumably int-typed as well.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)
    generator = torch.Generator().manual_seed(seed)

    # perf_counter is monotonic and intended for measuring elapsed time.
    start_time = time.perf_counter()

    for img in pipe.generate_images(
        prompt=prompt,
        guidance_scale=0,  # FLUX.1-schnell is run without classifier-free guidance here
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ):
        latency = f"Processing Time: {(time.perf_counter() - start_time):.2f} seconds"
        yield img, seed, latency

# Examples-gallery callback: always generates with a freshly randomized seed.
def generate_example_image(prompt):
    """Stream images for an example *prompt* using a random seed.

    This must itself be a generator function. Gradio decides whether to
    stream results by inspecting the callback, so ``return``-ing
    ``generate_image(...)``'s generator object would hand Gradio an
    unconsumed generator as the output value.

    Yields:
        ``(image, seed)`` pairs, matching the two outputs (image, seed) the
        examples gallery is wired to. The latency text from
        ``generate_image`` is dropped.
    """
    for img, used_seed, _latency in generate_image(prompt, randomize_seed=True):
        yield img, used_seed

# Prompts shown in the "Inspiration Gallery". The first is Korean ("an animated
# illustration of Wiener schnitzel"), so translate_if_korean() would translate
# it before generation. The rest are already English.
_KOREAN_EXAMPLE = "비너 슈니첼의 애니메이션 일러스트레이션"
_ENGLISH_EXAMPLES = (
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust",
)
examples = [_KOREAN_EXAMPLE, *_ENGLISH_EXAMPLES]

# Custom page CSS: hide Gradio's default footer.
css = (
    "\n"
    "footer {\n"
    "    visibility: hidden;\n"
    "}\n"
)

# --- Gradio UI ---
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Column(elem_id="app-container"):
        with gr.Row():
            # Left: generated image. Right: prompt, button and generation controls.
            with gr.Column(scale=3):
                result = gr.Image(label=english_labels["Generated Image"], show_label=False, interactive=False)
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=english_labels["Prompt"],
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                enhanceBtn = gr.Button(f"🚀 {english_labels['Enhance Image']}")

                # gr.Column accepts no label, so passing the title string there
                # fed it into a layout parameter. An Accordion is the titled
                # container; open=True keeps the controls visible.
                with gr.Accordion(english_labels["Advanced Options"], open=True):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        seed = gr.Number(label=english_labels["Seed"], value=42, precision=0)
                        randomize_seed = gr.Checkbox(label=english_labels["Randomize Seed"], value=True)
                    with gr.Row():
                        width = gr.Slider(label=english_labels["Width"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
                        height = gr.Slider(label=english_labels["Height"], minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
                        num_inference_steps = gr.Slider(label=english_labels["Inference Steps"], minimum=1, maximum=4, step=1, value=DEFAULT_INFERENCE_STEPS)

        with gr.Row():
            gr.Markdown(f"### 🌟 {english_labels['Inspiration Gallery']}")
        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=generate_example_image,
                inputs=[prompt],
                outputs=[result, seed],
                # With cache_examples=False, Gradio does not call fn when an
                # example is clicked (that requires run_on_click=True). The
                # click only copies the example text into the prompt box.
                cache_examples=False,
            )

    # Event handling.
    # The button and the live triggers share one input list. Without
    # randomize_seed/num_inference_steps, generate_image would silently use
    # its defaults (always randomize, DEFAULT_INFERENCE_STEPS) and ignore the
    # checkbox and steps slider.
    generation_inputs = [prompt, seed, width, height, randomize_seed, num_inference_steps]
    generation_outputs = [result, seed, latency]

    enhanceBtn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="hidden",
        show_api=False,
        queue=False,
    )

    # Regenerate as the prompt or size/step controls are edited. With
    # "always_last", while a run is pending only the most recent further
    # trigger is kept and run afterwards.
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False,
    )

demo.launch()