# Spaces: Runtime error
import json
from pathlib import Path

import gradio as gr
import torch
from PIL import Image
from torchvision import models, transforms
# Load the preprocessing settings from the config.json that sits next to this
# script. Resolving the path from __file__ instead of the current working
# directory lets the app start regardless of where it is launched from.
_CONFIG_PATH = Path(__file__).resolve().parent / "config.json"
# JSON text is UTF-8 (RFC 8259); don't depend on the platform default encoding.
with _CONFIG_PATH.open(encoding="utf-8") as config_file:
    config = json.load(config_file)
# Expression names; a model output index maps to the label at that position.
class_labels = "Angry Disgust Fear Happy Sad Surprise Neutral".split()
# Build a ResNet-50 and replace its final fully-connected layer so it emits
# one logit per entry in class_labels.
model = models.resnet50()
model.fc = torch.nn.Linear(model.fc.in_features, len(class_labels))
# Load the fine-tuned weights, stored next to this script, onto the CPU.
# weights_only=True limits unpickling to tensors and plain containers, so a
# tampered checkpoint cannot run arbitrary code while being loaded.
_WEIGHTS_PATH = Path(__file__).resolve().parent / "pytorch_model.bin"
_state_dict = torch.load(
    _WEIGHTS_PATH,
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(_state_dict)
# Switch to evaluation mode for inference.
model.eval()
# Image preprocessing pipeline: resize to (image_size, image_size), convert to
# a float tensor, then normalize with the mean/std from config.json.
_image_side = config["image_size"]
_normalize_cfg = config["transformations"]["Normalize"]
transform = transforms.Compose(
    [
        transforms.Resize((_image_side, _image_side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_normalize_cfg["mean"], std=_normalize_cfg["std"]),
    ]
)
def predict(image):
    """Classify the facial expression shown in *image*.

    Args:
        image: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the form is submitted without an image.

    Returns:
        The predicted label from ``class_labels``, or an empty string when
        no image was provided.
    """
    if image is None:
        # Nothing to classify; without this guard transform(None) raises.
        return ""
    # The ResNet-50 stem expects 3 input channels, so grayscale or RGBA
    # uploads are converted to RGB before resizing/normalizing.
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dim
    with torch.no_grad():
        logits = model(batch)
    predicted_index = logits.argmax(dim=1).item()
    return class_labels[predicted_index]
# Gradio UI: one PIL image in, the predicted expression label out.
# The user-facing title and description are in Spanish; predict returns the
# English label names from class_labels.
image_input = gr.Image(type="pil")
label_output = gr.Textbox(label="Predicted Expression")
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title="Reconocimiento Facial de Expresiones",
    description=(
        "Sube una imagen de una cara para clasificar la expresión facial en una de las siete categorías: "
        "Enfadado, Disgustado, Miedo, Feliz, Triste, Sorprendido y Neutral."
    ),
)
# Start the Gradio app.
iface.launch()