ivanoctaviogaitansantos's picture
Update app.py
363f1a6 verified
import gradio as gr
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
from PIL import Image
import tempfile
import os
import gc
# --- MODEL CONFIGURATION ---
# Hugging Face Hub repository ids of the two SDXL checkpoints offered in the UI.
MODEL_REALISTIC = "stabilityai/stable-diffusion-xl-base-1.0"
MODEL_PONY_REALISM = "john6666/pony-realism-v23-sdxl"

# Compute device: half precision on CUDA, full precision on CPU.
_has_cuda = torch.cuda.is_available()
device, dtype = ("cuda", torch.float16) if _has_cuda else ("cpu", torch.float32)

# Global pipeline handles; populated once the models are loaded.
pipe_realistic = pipe_pony = None
def load_model(model_id):
    """Load an SDXL pipeline from the Hugging Face Hub and prepare it for inference.

    The checkpoint is loaded in the module-level ``dtype`` (float16 on CUDA,
    float32 otherwise) from safetensors weights. Its scheduler is replaced by a
    ``DPMSolverMultistepScheduler`` built from the original scheduler config,
    and the pipeline is moved to ``device``. On CUDA, when this PyTorch build
    provides ``torch.compile``, the UNet is then wrapped with ``torch.compile``.

    Args:
        model_id: Hub repository id of an SDXL checkpoint.

    Returns:
        The ready-to-use ``StableDiffusionXLPipeline``.
    """
    print(f"⏳ Cargando modelo: {model_id}")
    pipe = StableDiffusionXLPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        low_cpu_mem_usage=True,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Move the weights to the target device *before* compiling, the order
    # diffusers documents for torch.compile usage.
    pipe.to(device)

    # Compile only on CUDA and only when torch.compile exists (torch >= 2.0).
    if device == "cuda" and hasattr(torch, "compile"):
        try:
            # NOTE: torch.compile is lazy. This call only wraps the module, so
            # this except catches setup errors (e.g. unsupported Python/platform).
            # Tracing errors surface on the first forward pass, i.e. during generation.
            pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
        except Exception as e:
            print(f"⚠️ No se pudo compilar el modelo: {e}")

    print(f"✅ Modelo {model_id.split('/')[-1]} listo")
    return pipe
# Load both pipelines once, at import time, so every request reuses them.
print("⏳ Cargando modelos iniciales...")
pipe_realistic, pipe_pony = [
    load_model(repo_id) for repo_id in (MODEL_REALISTIC, MODEL_PONY_REALISM)
]
print("✅ Todos los modelos iniciales listos")
# --- IMAGE GENERATION ---
def generate_image(prompt, negative_prompt, model_choice, steps, guidance, width, height, seed):
    """Generate one image with the selected SDXL pipeline and save a JPEG copy.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Text describing what the image should avoid.
        model_choice: ``"pony"`` selects ``pipe_pony``; any other value
            selects ``pipe_realistic``.
        steps: Number of inference steps (converted to ``int``).
        guidance: Guidance scale passed to the pipeline as ``guidance_scale``.
        width: Output width in pixels (converted to ``int``).
        height: Output height in pixels (converted to ``int``).
        seed: A value ``>= 0`` seeds a ``torch.Generator`` for reproducible
            output. A negative value or ``None`` (e.g. a cleared input box)
            leaves the generator unset.

    Returns:
        ``(image, path)``: the generated PIL image and the path of a
        temporary JPEG copy. Returns ``(None, None)`` if generation or
        saving raises.
    """
    # Module-level pipelines are only read here, so no ``global`` statement is needed.
    pipe = pipe_pony if model_choice == "pony" else pipe_realistic

    generator = None
    # Guard against None before comparing, and cast because manual_seed requires an int.
    if seed is not None and seed >= 0:
        generator = torch.Generator(device=pipe.device).manual_seed(int(seed))

    print(f"🔄 Generando imagen con '{model_choice}'...")
    try:
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=int(steps),
            guidance_scale=guidance,
            width=int(width),
            height=int(height),
            generator=generator,
        )
        image = result.images[0]

        # delete=False: the returned path must still exist after this function returns.
        # NOTE(review): nothing here removes these files; they accumulate in the temp dir.
        with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
            # Write through the already-open handle. Reopening by name while the
            # NamedTemporaryFile is open fails on Windows.
            image.save(temp_file, format="JPEG", quality=95)
        return image, temp_file.name
    except Exception as e:
        print(f"❌ Error al generar imagen: {e}")
        return None, None
    finally:
        # Release memory after every request, successful or not.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()
# --- GRADIO INTERFACE ---
# Two-column layout: generation controls on the left, result and download on the right.
# NOTE: component creation order inside each `with` context defines the on-screen layout.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="violet")) as demo:
    gr.Markdown("# 🖼️✨ Generador de Imágenes Hiperrealistas con SDXL")
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt (Descripción positiva)",
                placeholder="Ej: photorealistic portrait in 9:16, cinematic, natural light...",
                lines=3
            )
            negative_prompt = gr.Textbox(
                label="Negative Prompt (Lo que NO quieres)",
                placeholder="Ej: blurry, text, low quality, deformed...",
                lines=2
            )
            with gr.Row():
                # (label, value) pairs: the value ("realistic" / "pony") is what
                # generate_image receives as model_choice.
                model_choice = gr.Radio(
                    choices=[
                        ("Realista Puro", "realistic"),
                        ("Pony Realism", "pony")
                    ],
                    label="Modelo",
                    value="realistic"
                )
                # precision=0 requests an integer value; the default -1 is labelled "random".
                seed = gr.Number(value=-1, label="Seed (-1 = aleatorio)", precision=0)
            with gr.Accordion("Ajustes Avanzados", open=False):
                steps = gr.Slider(20, 80, value=40, step=5, label="Pasos de inferencia")
                guidance = gr.Slider(1, 15, value=7.5, step=0.5, label="Guidance scale (Creatividad)")
                with gr.Row():
                    # Range 512-1024 with step 64 keeps both dimensions on multiples of 64.
                    width = gr.Slider(512, 1024, value=768, step=64, label="Ancho")
                    height = gr.Slider(512, 1024, value=768, step=64, label="Alto")
            btn = gr.Button("🎨 Generar Imagen", variant="primary")
        # Right column: outputs.
        with gr.Column():
            output_img = gr.Image(label="Resultado", type="pil", height=512)
            download_btn = gr.File(label="Descargar JPG")
    # Inputs are passed positionally, so this list must match generate_image's
    # parameter order; outputs match its (image, path) return value.
    btn.click(
        fn=generate_image,
        inputs=[prompt, negative_prompt, model_choice, steps, guidance, width, height, seed],
        outputs=[output_img, download_btn]
    )
def _launch_app():
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _launch_app()