File size: 2,759 Bytes
4a0cc75
8e4035c
88a3ed8
4a0cc75
8e4035c
 
d5e510d
 
8e4035c
 
d5e510d
8e4035c
d5e510d
 
3957fb5
8e4035c
 
4a0cc75
 
 
d5e510d
4a0cc75
d5e510d
 
 
 
4a0cc75
d5e510d
 
 
8e4035c
d5e510d
 
 
 
8e4035c
d5e510d
 
 
 
 
8e4035c
d5e510d
 
 
 
4a0cc75
d5e510d
 
88a3ed8
8e4035c
 
 
 
4a0cc75
ff6b31d
 
d5e510d
8e4035c
d5e510d
 
 
ff6b31d
 
 
d5e510d
 
 
4a0cc75
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image

# Device selection: prefer CUDA, then Apple Silicon (MPS), and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# Hugging Face Hub model identifiers.
model_id_image = "sd-legacy/stable-diffusion-v1-5"  # text-to-image (Stable Diffusion)
model_id_image_description = "vikhyatk/moondream2"  # image-description model
revision = "2024-08-26"  # pinned revision for the description model / tokenizer

# Weight dtype for the Stable Diffusion pipeline: bfloat16 when CUDA is present
# (GPU optimisation), float32 everywhere else (CPU / MPS).
# NOTE(review): bfloat16 may be slow or unsupported on older (pre-Ampere) GPUs;
# float16 is the more common choice for SD 1.5 there — confirm target hardware.
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load both models once at import time; the request handlers below reuse these globals.
print("Cargando modelo de descripción de imágenes...")
# moondream2 relies on custom modelling code (encode_image / answer_question), hence
# trust_remote_code=True; the revision is pinned so that remote code cannot change silently.
model_description = AutoModelForCausalLM.from_pretrained(
    model_id_image_description, trust_remote_code=True, revision=revision
)
tokenizer_description = AutoTokenizer.from_pretrained(model_id_image_description, revision=revision)
# NOTE(review): model_description is never moved to `device`, so unless its remote code
# does so itself it runs on CPU — confirm whether that is intended.

print("Cargando modelo de Stable Diffusion...")
pipe_sd = StableDiffusionPipeline.from_pretrained(model_id_image, torch_dtype=torch_dtype)

# Memory optimisations.
pipe_sd.enable_attention_slicing()
if device == "cuda":
    # Sequential CPU offload streams submodules onto the GPU only while they run, which
    # keeps VRAM usage low on small GPUs. The pipeline must NOT be moved to CUDA first:
    # per the diffusers memory guide, doing so loads the full model into VRAM up front
    # and largely cancels the memory savings.
    pipe_sd.enable_sequential_cpu_offload()
else:
    pipe_sd = pipe_sd.to(device)

def generate_description(image_path):
    """Describe the image stored at *image_path* using the moondream2 model.

    Args:
        image_path: Path to an image file on disk (presumably the upload path
            handed over by the Gradio ``type="filepath"`` image input).

    Returns:
        The text answer returned by ``model_description.answer_question`` for
        the prompt "Describe this image to create an avatar".
    """
    # Open inside a context manager so the underlying file handle is released
    # immediately. convert() returns a fully decoded copy, so it remains usable
    # after the file is closed; converting to RGB also normalises palette, RGBA
    # and grayscale uploads to a 3-channel image for the vision encoder.
    with Image.open(image_path) as source:
        image_rgb = source.convert("RGB")

    enc_image = model_description.encode_image(image_rgb)
    return model_description.answer_question(
        enc_image, "Describe this image to create an avatar", tokenizer_description
    )

def generate_image_by_description(description, avatar_style=None):
    """Render a pigeon profile avatar with Stable Diffusion from a description.

    Args:
        description: Free-text description of the pigeon (e.g. the caption
            produced by ``generate_description``).
        avatar_style: Optional style label such as "Pixel Art". Falsy values
            (``None``, ``""``) add no style instruction.

    Returns:
        The first image from the pipeline output's ``images`` list.
    """
    # Stable Diffusion 1.5's CLIP text encoder only reads the first 77 tokens;
    # anything beyond that is truncated (diffusers merely logs a warning).
    # Model captions can be long, so place the short, high-priority parts
    # (subject and style) first and the free-form description last. If
    # truncation happens, it then cuts the tail of the description rather
    # than dropping the user's chosen style.
    prompt = "Create a pigeon profile avatar."
    if avatar_style:
        prompt += f" Use {avatar_style} style."

    # Captions usually end with a period already; strip it to avoid "..".
    cleaned_description = str(description).strip().rstrip(".")
    prompt += f" Use the following description: {cleaned_description}."

    result = pipe_sd(prompt)
    return result.images[0]

def process_and_generate(image, avatar_style):
    """Gradio click handler: describe the uploaded photo, then generate an avatar.

    Args:
        image: File path of the uploaded image, or ``None`` when the user has not
            uploaded anything (Gradio passes ``None`` for an empty image input).
        avatar_style: Selected style label, or ``None`` if no option is chosen.

    Returns:
        The generated avatar image from ``generate_image_by_description``.

    Raises:
        gr.Error: If no image was uploaded. The UI then shows a readable message
            instead of an opaque failure from opening ``None`` as an image.
    """
    if image is None:
        raise gr.Error("Please upload an image of the pigeon first.")

    description = generate_description(image)
    return generate_image_by_description(description, avatar_style)

# Gradio UI: inputs in the left column, generated avatar in the right column.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: photo upload, optional style selector and the trigger button.
        with gr.Column(scale=2, min_width=300):
            selected_image = gr.Image(
                type="filepath",
                label="Upload an Image of the Pigeon",
                height=300,
            )
            avatar_style = gr.Radio(
                ["Realistic", "Pixel Art", "Imaginative", "Cartoon"],
                label="(optional) Select the avatar style:",
            )
            generate_button = gr.Button("Generate Avatar", variant="primary")

        # Right column: output slot for the generated avatar.
        with gr.Column(scale=2, min_width=300):
            generated_image = gr.Image(
                type="numpy",
                label="Generated Avatar",
                height=300,
            )

    # Wire the button to the describe-then-generate pipeline.
    generate_button.click(
        process_and_generate,
        inputs=[selected_image, avatar_style],
        outputs=generated_image,
    )

demo.launch()