# leonett's picture
# Create app.py
# a6cd9b7 verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from facenet_pytorch import MTCNN
from model import CAAE # Modelo del repositorio Face-Aging-CAAE
# Model configuration (everything runs on CPU).
device = torch.device("cpu")
# Relative path: resolved against the current working directory.
model_path = "model/CAAE_MORPH.pth"
model = CAAE(latent_dim=128).to(device)
# The checkpoint is fed to load_state_dict, so it only needs to contain tensors;
# weights_only=True stops torch.load from unpickling arbitrary objects from the
# downloaded file (requires torch >= 1.13).
model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
# Switch to inference mode (affects layers such as dropout / batch-norm).
model.eval()
# Face detector used to locate the face before preprocessing.
mtcnn = MTCNN(image_size=128, margin=0, min_face_size=20, device=device)
def align_and_preprocess(image):
    """Detect the face in ``image`` and turn it into a model-ready tensor.

    Args:
        image: The input picture, either a PIL image (what the Gradio
            ``type="pil"`` input delivers) or an array accepted by
            ``Image.fromarray``.

    Returns:
        A ``(1, 3, 128, 128)`` float tensor on ``device`` with values scaled
        to [-1, 1] (ToTensor gives [0, 1], Normalize(0.5, 0.5) maps to [-1, 1]).

    Raises:
        ValueError: If no image is given or no face is detected.
    """
    if image is None:
        raise ValueError("No se proporcionó ninguna imagen.")
    # Gradio already hands over a PIL image; only convert raw arrays.
    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    img = img.convert("RGB")

    # Only ask MTCNN for bounding boxes. Calling ``mtcnn(img)`` returns an
    # already-standardized tensor, which ``transforms.ToTensor`` below rejects
    # with a TypeError.
    boxes, _ = mtcnn.detect(img)
    if boxes is None or len(boxes) == 0:
        raise ValueError("No se detectó una cara en la imagen.")
    # With MTCNN's default select_largest=True the largest face comes first.
    face = img.crop(tuple(float(coord) for coord in boxes[0][:4]))

    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    return transform(face).unsqueeze(0).to(device)
def generate_aged_image(input_image, target_age):
    """Generate an aged version of the face in ``input_image``.

    Args:
        input_image: PIL image from the Gradio input (arrays are also
            accepted, as long as ``Image.fromarray`` can read them).
        target_age: Target age from the slider (0-100); it is divided by 100
            before being passed to the model as a ``(1, 1)`` float tensor.

    Returns:
        A PIL image with the generated face, resized to the input's size.

    Raises:
        gr.Error: On any failure, so Gradio shows the message in the UI
            instead of trying to render an error string as an image.
    """
    try:
        input_tensor = align_and_preprocess(input_image)
        age_tensor = torch.tensor(
            [[target_age / 100.0]], dtype=torch.float32, device=device
        )
        with torch.no_grad():
            output = model(input_tensor, age_tensor)

        # Drop only the batch dimension, then CHW -> HWC for PIL.
        output_image = output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        # Invert Normalize(mean=0.5, std=0.5) used in preprocessing: [-1, 1] -> [0, 1].
        # NOTE(review): assumes the generator outputs the same [-1, 1] range it
        # is fed (tanh-style output) — confirm against the CAAE model definition.
        output_image = (output_image + 1.0) / 2.0
        output_image = np.clip(output_image * 255.0, 0, 255).astype(np.uint8)

        if isinstance(input_image, Image.Image):
            original = input_image
        else:
            original = Image.fromarray(input_image)
        return Image.fromarray(output_image).resize(original.size)
    except Exception as e:
        # UI boundary: surface the problem to the user through Gradio.
        raise gr.Error(f"Error: {str(e)}") from e
# Gradio interface
def app():
    """Build the Gradio interface for the face-aging demo.

    Returns:
        A ``gr.Interface`` that wires ``generate_aged_image`` to an image
        input and a target-age slider. Example rows whose image file is not
        present are left out.
    """
    from pathlib import Path

    candidate_examples = [
        ["example_images/input.jpg", 40],  # Example 1
        ["example_images/input2.jpg", 60],  # Example 2
    ]
    # The example paths are relative to the working directory. A missing file
    # can make Gradio error out while preparing the examples, so keep only the
    # ones that actually exist.
    examples = [row for row in candidate_examples if Path(row[0]).is_file()]

    interface = gr.Interface(
        fn=generate_aged_image,
        inputs=[
            gr.Image(label="Imagen de entrada", type="pil"),
            gr.Slider(0, 100, value=30, step=1, label="Edad objetivo"),
        ],
        outputs=gr.Image(label="Resultado", type="pil"),
        examples=examples or None,
        title="Envejecimiento Facial",
        description="Carga una imagen, elige una edad y genera su versión envejecida.",
    )
    return interface
if __name__ == "__main__":
    # One-time setup before launching the app.
    # 1. Install the dependencies:
    #      pip install gradio torch torchvision facenet-pytorch numpy
    # 2. Fetch the Face-Aging-CAAE model:
    #      git clone https://github.com/ZZUTK/Face-Aging-CAAE.git
    #      wget -P Face-Aging-CAAE/model/ https://raw.githubusercontent.com/ZZUTK/Face-Aging-CAAE/master/model/CAAE_MORPH.pth
    # NOTE(review): the weights are loaded from "model/CAAE_MORPH.pth" relative
    # to the working directory, while the wget above saves them under
    # Face-Aging-CAAE/model/ — move the file or adjust model_path; also confirm
    # that this .pth URL actually exists upstream.
    demo = app()
    demo.launch()