"""Gradio app that classifies blindness severity from retinal images.

A pickled PyTorch model is downloaded from the Hugging Face Hub at import
time, loaded on the CPU in inference mode, and served through a simple
Gradio interface.
"""
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from torchvision import transforms
from PIL import Image
import requests
import os

# Hugging Face Hub location of the model checkpoint.
REPO_ID = "macapa/blindness_clas"
MODEL_FILENAME = "best_model_resnet18.pth"

# Direct URL of the checkpoint that is actually downloaded below. The previous
# value pointed at "blindness_model.pth", a file this script never used.
model_url = f"https://huggingface.co/{REPO_ID}/resolve/main/{MODEL_FILENAME}"

# Use the path hf_hub_download reports instead of assuming where the file lands.
model_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILENAME,
    local_dir=".",
)

# Load the PyTorch model.
# The checkpoint is called directly as a module below, so it is a whole pickled
# nn.Module rather than a state_dict; weights_only=False is therefore required
# on torch >= 2.6, where torch.load defaults to weights_only=True.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints from
# a repository you trust (consider pinning `revision=` in hf_hub_download).
model = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
# Switch to inference mode so layers such as BatchNorm/Dropout (presumably
# present -- the file name suggests a ResNet-18) use their eval behaviour
# instead of training-mode statistics.
model.eval()

# Image preprocessing applied before inference.
# NOTE(review): no normalization is applied -- confirm this matches the
# transforms used when the model was trained.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Class labels, indexed by the position of the model's highest output score.
labels = ["No Blindness", "Mild", "Moderate", "Severe", "Proliferative"]


def classify_image(img):
    """Predict the blindness-severity label for a retinal image.

    Args:
        img: A PIL image supplied by Gradio. A NumPy array is also accepted
            and converted to PIL first.

    Returns:
        The entry of ``labels`` corresponding to the highest model score.

    Raises:
        gr.Error: If no image was submitted.
    """
    if img is None:
        raise gr.Error("Please upload an image.")
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Force three channels: uploads may be RGBA, palette or grayscale.
    batch = preprocess(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)
    return labels[predicted.item()]


# Gradio interface definition.
interface = gr.Interface(
    fn=classify_image,
    # type="pil": Gradio's default NumPy input is not accepted by
    # torchvision's Resize transform.
    inputs=gr.Image(type="pil", label="Carga una imagen aquí"),
    outputs=gr.Label(num_top_classes=1),
    title="Blindness Classification",
    description="Classify the severity of blindness from retinal images.",
)

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    interface.launch(share=True)