File size: 1,868 Bytes
bf601e4
0976e91
 
bf601e4
d1c29b6
bf601e4
 
 
 
 
 
 
 
 
 
 
 
 
d1c29b6
 
 
 
 
 
 
 
 
0976e91
 
 
 
 
 
 
 
 
 
 
aad7f4e
0976e91
 
 
bf601e4
 
 
 
11b107f
0976e91
bf601e4
0976e91
 
bf601e4
 
 
 
0976e91
bf601e4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import random
import numpy as np
import torch
from torch import nn
from transformers import (SegformerFeatureExtractor,
                          SegformerForSemanticSegmentation)


# Directory holding the fine-tuned SegFormer weights and config.
MODEL_PATH = "./best_model_test/"

# All inference in this app runs on the CPU.
device = torch.device("cpu")

# The image preprocessor configuration is taken from the public
# SegFormer-B0 ADE20k checkpoint rather than from MODEL_PATH.
preprocessor = SegformerFeatureExtractor.from_pretrained(
    "nvidia/segformer-b0-finetuned-ade-512-512"
)

# Load the fine-tuned model and put it in inference mode straight away
# (nn.Module.eval() returns the module itself, so it can be chained).
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).eval()


def upscale_logits(logit_outputs, size):
    """Resample a batch of logits to a new spatial size.

    Intended to bring the model's reduced-resolution logits back up to the
    dimensions of the original input image.

    Args:
        logit_outputs: 4-D logits tensor laid out as (N, C, H, W), as
            required by bilinear interpolation.
        size: Target spatial size, either an int or an (H, W) pair.

    Returns:
        A tensor of shape (N, C, *size) produced by bilinear interpolation
        with ``align_corners=False``.
    """
    resize = nn.functional.interpolate
    resized = resize(logit_outputs, size=size, mode="bilinear", align_corners=False)
    return resized


def visualize_instance_seg_mask(mask):
    """Colorize a label mask by giving every distinct label a random RGB color.

    Args:
        mask: Array of per-pixel class labels (typically 2-D, (H, W)).

    Returns:
        A float64 array of shape ``mask.shape + (3,)`` with values in [0, 1].
        All pixels sharing a label share the same color. Colors are drawn from
        the ``random`` module, so they differ between calls unless it is seeded.
    """
    mask = np.asarray(mask)
    # `inverse` maps every pixel to the index of its label in `labels`, which
    # lets us colorize the whole mask with a single palette lookup instead of
    # a per-pixel Python loop.
    labels, inverse = np.unique(mask, return_inverse=True)
    # One random 0-255 color per label. All three channels must span the full
    # byte range; the red channel was previously limited to {0, 1}.
    # reshape(-1, 3) keeps the palette 2-D even when the mask is empty.
    palette = np.array(
        [[random.randint(0, 255) for _channel in range(3)] for _label in labels],
        dtype=np.float64,
    ).reshape(-1, 3)
    # NumPy 1.x returns a flat `inverse`, 2.x may return it in mask's shape;
    # reshaping handles both.
    image = palette[inverse.reshape(mask.shape)]
    return image / 255


def query_image(img):
    """Segment ``img`` and return a colorized mask at the image's own resolution.

    Args:
        img: Input image from the Gradio component (a PIL image, since the
            interface declares ``type="pil"``). Anything ``np.asarray`` turns
            into an (H, W[, C]) array is handled the same way.

    Returns:
        The colorized mask produced by ``visualize_instance_seg_mask`` with
        spatial size (H, W) of the input, or ``None`` when no image was given.
    """
    if img is None:
        # Gradio passes None when the user submits without an image.
        return None
    # The logits come out at a lower resolution than the input, so we need the
    # original (H, W) to resample them. PIL's `.size` is (W, H); the array view
    # is (H, W[, C]) for both PIL images and ndarrays.
    height, width = np.asarray(img).shape[:2]
    inputs = preprocessor(images=img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs)["logits"]
    # Previously this passed `logits.shape[2]`, which resized the logits to a
    # square of their own height, so no upscaling to the input size happened.
    upscaled = upscale_logits(logits, (height, width))
    predicted = torch.argmax(upscaled, dim=1)
    # No autograd graph exists under no_grad, so detach() is unnecessary.
    result = predicted[0].cpu().numpy()
    return visualize_instance_seg_mask(result)


# Gradio UI: one image in, one colorized segmentation mask out.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image(type="pil")],
    outputs="image",
    title="SegFormer Model for rock glacier image segmentation",
)

# Start the server only when executed as a script (e.g. `python app.py`), so
# importing this module does not block on a running web server.
if __name__ == "__main__":
    demo.launch()