import gradio as gr import torch import numpy as np from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation from PIL import Image # Cargar el modelo y el preprocesador device = torch.device("cpu") model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade").to(device) model.eval() preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade") # Función de consulta para Gradio def query_image(img): # Procesar la imagen con el preprocesador inputs = preprocessor(images=img, return_tensors="pt") with torch.no_grad(): outputs = model(**inputs) # Obtener la máscara de segmentación (asegúrate de que esta lógica coincida con tu configuración) mask = torch.argmax(outputs.logits[0], dim=0).cpu().detach().numpy() # Crear una máscara binaria solo para la clase de "regla" (de acuerdo a tu código original) rule_class_id = 1 # ID de la clase "regla" rule_mask = (mask == rule_class_id).astype(np.uint8) # Crear una imagen RGB para visualizar la máscara mask_image = np.stack([rule_mask] * 3, axis=-1) return Image.fromarray((mask_image * 255).astype(np.uint8)) # Crear la interfaz Gradio demo = gr.Interface( query_image, inputs=[gr.Image()], outputs="image", title="Rule Segmentation Demo", description="Please upload an image to see rule segmentation", ) # Lanzar la interfaz Gradio demo.launch()