File size: 3,981 Bytes
bf601e4
fb964ec
0976e91
 
bf601e4
d1c29b6
3237429
cb54c63
bf601e4
7751dfb
 
 
 
 
 
bf601e4
7751dfb
 
56744a5
 
 
 
7751dfb
 
7026019
bf601e4
 
 
3237429
9758099
3237429
 
bf601e4
 
 
7751dfb
d1c29b6
 
 
 
 
 
 
 
 
0976e91
 
 
 
 
 
 
 
 
 
 
aad7f4e
0976e91
 
 
bf601e4
 
3237429
bf601e4
3237429
0976e91
bf601e4
0976e91
 
bf601e4
7751dfb
bf601e4
 
334cac1
a5218e8
3237429
7751dfb
d427ed3
 
7751dfb
bf601e4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import gradio as gr
import os
import random
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation

# examples
# Download the demo images from the Space repository into the working
# directory with wget. The folder names separate the cases filed under
# "buenos_resultados" (good results) from "malos_resultados" (bad results).
_EXAMPLES_URL = "https://huggingface.co/spaces/alkzar90/rock-glacier-segmentation/resolve/main/example_images"
_EXAMPLES_BY_FOLDER = (
    ("buenos_resultados", ("073.png", "356.png", "599.png", "630.png", "673.png")),
    ("malos_resultados", ("019.png", "261.png", "524.png", "716.png", "898.png")),
)

for _folder, _file_names in _EXAMPLES_BY_FOLDER:
    for _file_name in _file_names:
        # Each file is saved under its own name, overwriting any earlier copy.
        os.system(f"wget -O {_file_name} {_EXAMPLES_URL}/{_folder}/{_file_name}")

# model-setting
# Directory holding the fine-tuned SegFormer checkpoint.
MODEL_PATH = "./best_model_mixto/"

# Target device for the predicted label tensor (CPU only).
device = torch.device("cpu")

# Input pipeline: Resize(128) followed by conversion to a tensor. With an int
# argument torchvision's Resize scales the shorter side (see torchvision docs).
_preprocessing_steps = [
    transforms.Resize(128),
    transforms.ToTensor(),
]
preprocessor = transforms.Compose(_preprocessing_steps)

# Load the weights and switch the network to evaluation mode.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).eval()

# inference-functions
def upscale_logits(logit_outputs, size):
    """Resize a batch of logits to ``size`` with bilinear interpolation.

    The original note describes the intent as scaling the logits to
    (4W)x(4H) so they match the dimensions of the model input again.

    Args:
        logit_outputs: Tensor of logits forwarded to
            ``torch.nn.functional.interpolate``.
        size: Target spatial size, passed through unchanged.

    Returns:
        The interpolated tensor.
    """
    resize_options = {
        "size": size,
        "mode": "bilinear",
        "align_corners": False,
    }
    upscaled = nn.functional.interpolate(logit_outputs, **resize_options)
    return upscaled


def visualize_instance_seg_mask(mask):
    """Paint every class label in ``mask`` with a randomly drawn RGB color.

    Args:
        mask: 2-D array of class labels.

    Returns:
        Float array of shape ``(mask.shape[0], mask.shape[1], 3)`` whose
        channel values are divided by 255. Colors are drawn from ``random``
        on every call, so they differ between calls unless the generator is
        seeded.
    """
    mask = np.asarray(mask)
    # ``inverse`` maps each pixel to the index of its label in ``labels``.
    labels, inverse = np.unique(mask, return_inverse=True)
    # One (R, G, B) triple per label, drawn in sorted-label order.
    # NOTE(review): the first channel is drawn from {0, 1} only, which looks
    # like a typo for 255; kept unchanged here -- confirm the intent.
    palette = np.array(
        [
            (random.randint(0, 1),
             random.randint(0, 255),
             random.randint(0, 255))
            for _ in labels
        ],
        dtype=float,
    ).reshape(-1, 3)
    # Vectorized palette lookup instead of a per-pixel Python double loop.
    image = palette[inverse.reshape(-1)].reshape(mask.shape[0], mask.shape[1], 3)
    return image / 255


def query_image(img):
    """Segment an image and return a colorized class mask.

    The image is run through the module-level ``preprocessor``, the model
    produces logits, the logits are bilinearly resized to the spatial size of
    the preprocessed input, and the per-pixel argmax is colorized.

    Args:
        img: Image accepted by ``preprocessor`` (the demo passes a PIL image).

    Returns:
        The array produced by ``visualize_instance_seg_mask`` for the
        predicted label map, with the same height and width as the
        preprocessed input.
    """
    inputs = preprocessor(img).unsqueeze(0)
    with torch.no_grad():
        preds = model(inputs)["logits"]
    # Resize to the full (height, width) of the model input. The previous
    # target, ``preds.shape[2]``, was the logits' own height: a single int is
    # applied to both spatial dims, so nothing was upscaled and non-square
    # inputs were squashed into a square.
    preds_upscale = upscale_logits(preds, inputs.shape[-2:])
    predict_label = torch.argmax(preds_upscale, dim=1).to(device)
    # Drop the batch dimension and hand a NumPy label map to the colorizer.
    result = predict_label[0, :, :].detach().cpu().numpy()
    return visualize_instance_seg_mask(result)

# demo
# Example images offered in the UI; one single-element row per image.
_example_rows = [
    [file_name]
    for file_name in (
        "073.png", "356.png", "599.png", "630.png", "673.png",
        "019.png", "261.png", "524.png", "716.png", "898.png",
    )
]

# Input and output image components share the same layout settings.
_input_image = gr.Image(type="pil").style(full_width=True, height=256)
_output_image = gr.Image().style(full_width=True, height=256)

demo = gr.Interface(
    fn=query_image,
    inputs=[_input_image],
    outputs=[_output_image],
    title="Skyguard: segmentador de glaciares de roca  🛰️ +️ 🛡️ ️",
    description="Modelo de segmentación de imágenes para detectar glaciares de roca.<br> Se entrenó un modelo [nvidia/SegFormer](https://huggingface.co/nvidia/mit-b0) con _fine-tuning_ en el [rock-glacier-dataset](https://huggingface.co/datasets/alkzar90/rock-glacier-dataset)",
    examples=_example_rows,
    cache_examples=False,  # example caching is turned off
)

# Start the web app.
demo.launch()