Margaritamawyin commited on
Commit
43cf32d
1 Parent(s): ac0605a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
5
+ from PIL import Image
6
+
7
+ # Cargar el modelo y el preprocesador
8
+ device = torch.device("cpu")
9
+ model = MaskFormerForInstanceSegmentation.from_pretrained("facebook/maskformer-swin-tiny-ade").to(device)
10
+ model.eval()
11
+ preprocessor = MaskFormerFeatureExtractor.from_pretrained("facebook/maskformer-swin-tiny-ade")
12
+
13
+ # Función de consulta para Gradio
14
+ def query_image(img):
15
+ # Procesar la imagen con el preprocesador
16
+ inputs = preprocessor(images=img, return_tensors="pt")
17
+ with torch.no_grad():
18
+ outputs = model(**inputs)
19
+
20
+ # Obtener la máscara de segmentación (asegúrate de que esta lógica coincida con tu configuración)
21
+ mask = torch.argmax(outputs.logits[0], dim=0).cpu().detach().numpy()
22
+
23
+ # Crear una máscara binaria solo para la clase de "regla" (de acuerdo a tu código original)
24
+ rule_class_id = 1 # ID de la clase "regla"
25
+ rule_mask = (mask == rule_class_id).astype(np.uint8)
26
+
27
+ # Crear una imagen RGB para visualizar la máscara
28
+ mask_image = np.stack([rule_mask] * 3, axis=-1)
29
+
30
+ return Image.fromarray((mask_image * 255).astype(np.uint8))
31
+
32
+ # Crear la interfaz Gradio
33
+ demo = gr.Interface(
34
+ query_image,
35
+ inputs=[gr.Image()],
36
+ outputs="image",
37
+ title="Rule Segmentation Demo",
38
+ description="Please upload an image to see rule segmentation",
39
+ )
40
+
41
+ # Lanzar la interfaz Gradio
42
+ demo.launch()