# pigeon-avatar / app.py
# Author: ItzRoBeerT — commit 3957fb5 ("Update app.py"), 2.76 kB
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
# Device selection: prefer CUDA, then the macOS Metal (MPS) backend,
# and fall back to the CPU when neither accelerator is available.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# Model configuration: Hugging Face Hub ids for the image generator and the
# image-description model, plus the pinned moondream2 revision.
model_id_image = "sd-legacy/stable-diffusion-v1-5"
model_id_image_description = "vikhyatk/moondream2"
revision = "2024-08-26"

# Stable Diffusion weight dtype: bfloat16 when a CUDA GPU is present (GPU
# optimisation), full float32 everywhere else.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Load both models once, at import time, so every request reuses them.
print("Cargando modelo de descripción de imágenes...")
# moondream2 ships custom modelling code, hence trust_remote_code=True; the
# revision pins that remote code to a known version.
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description, trust_remote_code=True, revision=revision
)
tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)

print("Cargando modelo de Stable Diffusion...")
pipe_sd = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)

# Device placement / memory options.
if device == "cuda":
    # Sequential CPU offload keeps the weights on the CPU and moves each
    # submodule to the GPU only while it runs, which frees GPU memory
    # gradually for small GPUs.
    # Do NOT call .to("cuda") first: that would load the whole pipeline into
    # GPU memory (the OOM risk offloading is meant to avoid) only for
    # diffusers to move it back to the CPU when offloading is enabled.
    pipe_sd.enable_sequential_cpu_offload()
else:
    pipe_sd = pipe_sd.to(device)

# Compute attention in slices to lower peak memory at some cost in speed.
pipe_sd.enable_attention_slicing()
def generate_description(image_path):
    """Ask the moondream2 model to describe an uploaded image.

    Args:
        image_path: Path to the image file (the Gradio input component is
            created with ``type="filepath"``).

    Returns:
        The answer produced by ``model_description.answer_question``. It is
        used as free text in the Stable Diffusion prompt.
    """
    # Open inside a context manager so the file handle is released promptly.
    # ``convert`` loads the pixel data into a new in-memory image before the
    # file is closed. It also normalises RGBA / palette / greyscale uploads
    # (e.g. transparent PNGs) to 3-channel RGB.
    # NOTE(review): assumes the moondream2 vision encoder expects RGB input;
    # confirm against the model's preprocessing code.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    enc_image = model_description.encode_image(image)
    return model_description.answer_question(
        enc_image, "Describe this image to create an avatar", tokenizer_description
    )
def generate_image_by_description(description, avatar_style=None):
    """Render a pigeon profile avatar with Stable Diffusion from a description.

    Args:
        description: Free-text description of the pigeon (e.g. the output of
            ``generate_description``).
        avatar_style: Optional style name such as ``"Pixel Art"``; any falsy
            value (``None``, ``""``) means no style instruction is added.

    Returns:
        The first image in the result returned by ``pipe_sd``.
    """
    # Stable Diffusion v1.5's CLIP text encoder only reads the first 77 tokens
    # of a prompt; anything beyond is truncated (diffusers just logs a warning).
    # Model-generated descriptions easily exceed that limit. The style
    # instruction therefore goes *before* the description; appended after it,
    # the style was usually cut off and ignored.
    prompt = "Create a pigeon profile avatar."
    if avatar_style:
        prompt += f" Use {avatar_style} style."
    # Drop surrounding whitespace and any trailing period so the prompt does
    # not end in "..".
    cleaned_description = str(description).strip().rstrip(".")
    prompt += f" Use the following description: {cleaned_description}."
    result = pipe_sd(prompt)
    return result.images[0]
def process_and_generate(image, avatar_style):
    """Gradio click handler: describe the uploaded pigeon photo, then draw an avatar.

    Args:
        image: Filepath of the uploaded image, or a falsy value (``None``) when
            nothing has been uploaded.
        avatar_style: The selected style option, or ``None`` when none is chosen.

    Returns:
        The generated avatar image for the output component.

    Raises:
        gr.Error: If no image was uploaded. Gradio shows the message to the user.
    """
    # Without an upload the image input is empty (None). Fail early with a
    # readable message in the UI instead of an opaque error from Image.open
    # deep inside the description step.
    if not image:
        raise gr.Error("Please upload an image of the pigeon first.")
    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)
# User interface.
# Left column: the pigeon photo, an optional style and the trigger button.
# Right column: the generated avatar.
with gr.Blocks() as demo:
    with gr.Row():
        # Inputs
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                label="Upload an Image of the Pigeon",
                type="filepath",
                height=300,
            )
            avatar_style = gr.Radio(
                choices=["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button(value="Generate Avatar", variant="primary")
        # Output
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Image(
                label="Generated Avatar",
                type="numpy",
                height=300,
            )

    # Event wiring must happen inside the Blocks context.
    generate_button.click(
        fn=process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()