import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import DiffusionPipeline
from PIL import Image
import io

# Load the FLUX.1-schnell diffusion pipeline once, at import time, with
# bfloat16 weights, on the GPU when one is available and on the CPU otherwise.
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype
).to(device)

# Largest signed 32-bit integer; upper bound for randomly drawn seeds and the
# seed slider.
MAX_SEED = np.iinfo(np.int32).max
# Upper bound (in pixels) for the width/height sliders.
MAX_IMAGE_SIZE = 2048

import numpy as np
from collections import Counter

def get_prominent_colors(image, num_colors=5):
    """
    Get the most prominent colors from an image, focusing on edges.

    Pixels whose gradient magnitude lies in the top ~10% (and is non-zero)
    are treated as edge pixels; their colors are counted and the most
    frequent ones are returned.

    Args:
        image: A PIL image or anything ``np.array`` accepts, either 2-D
            (single channel) or 3-D (height, width, channels).
        num_colors: Maximum number of colors to return.

    Returns:
        A list of ``(color, count)`` pairs, most frequent first; ties keep
        first-encountered (row-major) order. ``color`` is a tuple of channel
        values for multi-channel images and a plain number for single-channel
        images. An empty list is returned when the image is smaller than two
        pixels along either axis or contains no edges.
    """
    img_array = np.array(image)

    # Per-pixel intensity: plain mean over channels; 2-D input is used as is.
    intensity = img_array if img_array.ndim == 2 else img_array.mean(axis=2)
    intensity = np.asarray(intensity, dtype=np.float64)

    # np.gradient needs at least two samples along each axis.
    if intensity.shape[0] < 2 or intensity.shape[1] < 2:
        return []

    # A single gradient call yields both axes: rows (y) first, columns (x) second.
    gradient_y, gradient_x = np.gradient(intensity)
    gradient_magnitude = np.sqrt(gradient_x**2 + gradient_y**2)

    # Keep roughly the top 10% of gradients. ">=" keeps pixels tied at the
    # threshold (common in flat-colour artwork, where a strict ">" selects
    # nothing); "> 0" keeps flat regions out when most of the image has zero
    # gradient and the percentile itself is 0.
    edge_threshold = np.percentile(gradient_magnitude, 90)
    edge_mask = (gradient_magnitude >= edge_threshold) & (gradient_magnitude > 0)

    edge_colors = img_array[edge_mask]

    # tolist() converts to plain Python numbers in one C-level pass, which
    # makes the values cheap to hash for counting.
    if edge_colors.ndim == 1:
        colors = edge_colors.tolist()
    else:
        colors = [tuple(color) for color in edge_colors.tolist()]

    return Counter(colors).most_common(num_colors)

def create_tshirt_preview(design_image, tshirt_color="white",
                          template_path="image.jpeg",
                          width_ratio=0.35, top_ratio=0.2):
    """
    Paste the design onto the t-shirt template image.

    Args:
        design_image: A PIL image, or an array accepted by ``Image.fromarray``.
        tshirt_color: Accepted for interface compatibility. The template is
            NOT recoloured, so this value currently has no effect.
        template_path: Path of the t-shirt template image.
        width_ratio: Design width as a fraction of the template width.
        top_ratio: Vertical position of the design's top edge, as a fraction
            of the template height.

    Returns:
        A new PIL image: the template with the design pasted onto it.
    """
    # Read the template fully, then close the file right away. copy() forces
    # the pixel data to load, so the result doesn't depend on the closed file.
    with Image.open(template_path) as template:
        tshirt = template.copy()
    tshirt_width, tshirt_height = tshirt.size

    if not isinstance(design_image, Image.Image):
        design_image = Image.fromarray(design_image)

    # Scale the design to a fraction of the shirt width, keeping its aspect
    # ratio. Clamp to 1 px so tiny templates never request a zero-size resize.
    design_width = max(1, int(tshirt_width * width_ratio))
    design_height = max(1, int(design_width * design_image.size[1] / design_image.size[0]))
    design_image = design_image.resize((design_width, design_height), Image.Resampling.LANCZOS)

    # Centre horizontally; put the top edge at top_ratio of the shirt height.
    x = (tshirt_width - design_width) // 2
    y = int(tshirt_height * top_ratio)

    # For RGBA designs the alpha channel is the paste mask, so transparent
    # areas leave the shirt visible.
    mask = design_image.split()[3] if design_image.mode == 'RGBA' else None

    tshirt.paste(design_image, (x, y), mask)

    return tshirt

def enhance_prompt_for_tshirt(prompt, style=None):
    """
    Append t-shirt-oriented quality terms, plus optional style terms, to a prompt.

    Args:
        prompt: The user's design description (formatted with an f-string,
            so non-string values are stringified).
        style: Optional style key such as "minimal" or "vintage". Falsy or
            unknown values add no style terms.

    Returns:
        One comma-separated prompt string.
    """
    extras_by_style = {
        "minimal": ["simple geometric shapes", "clean lines", "minimalist illustration"],
        "vintage": ["distressed effect", "retro typography", "vintage illustration"],
        "artistic": ["hand-drawn style", "watercolor effect", "artistic illustration"],
        "geometric": ["abstract shapes", "geometric patterns", "modern design"],
        "typography": ["bold typography", "creative lettering", "text-based design"],
        "realistic": ["realistic", "cinematic", "photograph"],
    }

    # The user prompt leads, followed by the fixed quality terms.
    pieces = [
        f"{prompt}",
        "create t-shirt design",
        "with centered composition",
        "high quality",
        "professional design",
        "clear background",
    ]

    if style and style in extras_by_style:
        pieces.extend(extras_by_style[style])

    return ", ".join(pieces)

@spaces.GPU()
def infer(prompt, style=None, tshirt_color="white", seed=42, randomize_seed=False,
          width=1024, height=1024, num_inference_steps=4,
          progress=gr.Progress(track_tqdm=True)):
    """
    Generate a t-shirt design from a text prompt and render it on the shirt template.

    Returns:
        A tuple ``(design_image, tshirt_preview, seed)``. The seed actually
        used is returned, so a randomly drawn seed can be shown back in the UI.
    """
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed

    # Keyword arguments are evaluated in order: the prompt is enhanced first,
    # then the seeded generator is created.
    output = pipe(
        prompt=enhance_prompt_for_tshirt(prompt, style),
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator().manual_seed(seed),
        guidance_scale=0.0,
    )
    design = output.images[0]

    return design, create_tshirt_preview(design, tshirt_color), seed

# Available t-shirt colors: dropdown label -> hex color code.
TSHIRT_COLORS = {
    "White": "#FFFFFF",
    "Black": "#000000",
    "Navy": "#000080",
    "Gray": "#808080"
}

# Example rows in the order [prompt, style, t-shirt color]. The color must be
# a key of TSHIRT_COLORS and the style one of `styles`.
examples = [
    ["Cool geometric mountain landscape", "minimal", "White"],
    ["Vintage motorcycle with flames", "vintage", "Black"],
    ["flamingo in scenic forest", "realistic", "White"],
    ["Adventure Starts typography", "typography", "White"]
]

# Style names offered in the Style dropdown.
styles = [
    "minimal",
    "vintage",
    "artistic",
    "geometric",
    "typography",
    "realistic"
]

# Custom CSS passed to gr.Blocks. The selectors target the elem_id /
# elem_classes given to components in the layout (#col-container,
# .main-title, .subtitle, .design-input, .results-row).
# NOTE(review): .generate-button, .result-image and .preview-image are
# assigned in the layout but have no rules here.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px !important;
    padding: 20px;
}
.main-title {
    text-align: center;
    color: #2d3748;
    margin-bottom: 1rem;
    font-family: 'Poppins', sans-serif;
}
.subtitle {
    text-align: center;
    color: #4a5568;
    margin-bottom: 2rem;
    font-family: 'Inter', sans-serif;
    font-size: 0.95rem;
    line-height: 1.5;
}
.design-input {
    border: 2px solid #e2e8f0;
    border-radius: 10px;
    padding: 12px !important;
    margin-bottom: 1rem !important;
    font-size: 1rem;
    transition: all 0.3s ease;
}
.results-row {
    display: grid;
    grid-template-columns: 1fr 1fr;
    gap: 20px;
    margin-top: 20px;
}
"""

# Build the Gradio UI. Layout is determined by the nesting of the
# Blocks/Column/Row/Accordion context managers below.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            # 👕Deradh's T-Shirt Design Generator
            """,
            elem_classes=["main-title"]
        )
        
        gr.Markdown(
            """
            Create unique t-shirt designs using Deradh's AI. 
            Describe your design idea and select a style to generate professional-quality artwork 
            perfect for custom t-shirts.
            """,
            elem_classes=["subtitle"]
        )
        
        # Input row: prompt, style, shirt color, and the generate button.
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Design Description",
                    show_label=False,
                    max_lines=1,
                    placeholder="Describe your t-shirt design idea",
                    container=False,
                    elem_classes=["design-input"]
                )
            with gr.Column(scale=1):
                # The empty choice means "no style": enhance_prompt_for_tshirt
                # adds style terms only for a truthy, known style.
                style = gr.Dropdown(
                    choices=[""] + styles,
                    value="",
                    label="Style",
                    container=False
                )
            with gr.Column(scale=1):
                # The selected label (e.g. "White") is passed to infer, which
                # forwards it to create_tshirt_preview.
                # NOTE(review): confirm the preview actually uses this value.
                tshirt_color = gr.Dropdown(
                    choices=list(TSHIRT_COLORS.keys()),
                    value="White",
                    label="T-Shirt Color",
                    container=False
                )
            run_button = gr.Button(
                "✨ Generate",
                scale=0,
                elem_classes=["generate-button"]
            )
        
        # Output row: the raw design and its t-shirt mock-up, side by side.
        with gr.Row(elem_classes=["results-row"]):
            result = gr.Image(
                label="Generated Design",
                show_label=True,
                elem_classes=["result-image"]
            )
            preview = gr.Image(
                label="T-Shirt Preview",
                show_label=True,
                elem_classes=["preview-image"]
            )
        
        with gr.Accordion("🔧 Advanced Settings", open=False):
            with gr.Group():
                # infer also writes the seed it used back into this slider.
                seed = gr.Slider(
                    label="Design Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    step=1,
                    value=0,
                )
                # On by default, so the slider value is ignored unless unchecked.
                randomize_seed = gr.Checkbox(
                    label="Randomize Design",
                    value=True
                )
                
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=MAX_IMAGE_SIZE,
                        step=32,
                        value=1024,
                    )
                
                # NOTE(review): FLUX.1-schnell is typically run with very few
                # steps; the 50-step maximum may be more than needed — confirm.
                num_inference_steps = gr.Slider(
                    label="Generation Quality (Steps)",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=4,
                )
        
        # Only prompt/style/color come from each example row; infer's defaults
        # fill in the remaining arguments (seed=42, no randomization, 1024x1024,
        # 4 steps). cache_examples=True has Gradio precompute and store these
        # outputs.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, style, tshirt_color],
            outputs=[result, preview, seed],
            cache_examples=True
        )
        
        # Generate on button click or when Enter is pressed in the prompt box.
        gr.on(
            triggers=[run_button.click, prompt.submit],
            fn=infer,
            inputs=[prompt, style, tshirt_color, seed, randomize_seed, width, height, num_inference_steps],
            outputs=[result, preview, seed]
        )

# Start the Gradio server.
demo.launch()