alkzar90's picture
Resize input - output display 256x256
dd515ab
raw history blame
No virus
2.28 kB
import gradio as gr
import random
import numpy as np
import torch
from torch import nn
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
# Local directory containing the fine-tuned SegFormer checkpoint.
MODEL_PATH = "./best_model_mixto/"
# Target torch device: CPU.
device = torch.device("cpu")

# Input pipeline: Resize with an int size (torchvision matches the image's
# shorter edge to 128 px), then convert the PIL image to a float tensor.
preprocessor = transforms.Compose(
    [
        transforms.Resize(128),
        transforms.ToTensor(),
    ]
)

# Load the checkpoint and put it in inference mode; nn.Module.eval()
# returns the module itself, so the call can be chained.
model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH).eval()
def upscale_logits(logit_outputs, size):
    """Resize a batch of logits to ``size`` using bilinear interpolation.

    Used to bring the model's logits back to the spatial size of the model
    input.

    Args:
        logit_outputs: 4-D tensor ``(batch, classes, height, width)`` as
            accepted by ``nn.functional.interpolate`` in bilinear mode.
        size: target spatial size, either an int (square output) or an
            ``(height, width)`` pair.

    Returns:
        Tensor with the same batch and class dimensions, resized to ``size``.
    """
    bilinear_options = {"mode": "bilinear", "align_corners": False}
    resized = nn.functional.interpolate(logit_outputs, size=size, **bilinear_options)
    return resized
def visualize_instance_seg_mask(mask):
"""Agrega colores RGB a cada una de las clases en la mask"""
image = np.zeros((mask.shape[0], mask.shape[1], 3))
labels = np.unique(mask)
label2color = {label: (random.randint(0, 1),
random.randint(0, 255),
random.randint(0, 255)) for label in labels}
for i in range(image.shape[0]):
for j in range(image.shape[1]):
image[i, j, :] = label2color[mask[i, j]]
image = image / 255
return image
def query_image(img):
    """Segment a PIL image and return a colour visualisation of the mask.

    Args:
        img: PIL image from the Gradio input component, or ``None`` when the
            user submits without an image.

    Returns:
        ``(H, W, 3)`` float array in ``[0, 1]`` at the spatial size of the
        preprocessed input, or ``None`` if no image was provided.
    """
    if img is None:
        return None
    inputs = preprocessor(img).unsqueeze(0)
    with torch.no_grad():
        logits = model(inputs)["logits"]
        # Upsample to the preprocessed input's (height, width). A single int
        # such as logits.shape[2] would keep the logits at their own reduced
        # height and force a square output for non-square inputs.
        logits = upscale_logits(logits, tuple(inputs.shape[-2:]))
        predicted = torch.argmax(logits, dim=1)
    result = predicted[0].cpu().numpy()
    return visualize_instance_seg_mask(result)
# Shared display settings for the input and output image components.
_display_style = {"full_width": True, "height": 256, "width": 256}
_input_image = gr.Image(type="pil").style(**_display_style)
_output_image = gr.Image().style(**_display_style)

# Web UI wiring the uploaded image through query_image.
demo = gr.Interface(
    fn=query_image,
    inputs=[_input_image],
    outputs=[_output_image],
    title="Skyguard: segmentador de glaciares de roca 🛰️ +️ 🛡️ ️",
    description="Modelo de segmentación de imágenes para detectar glaciares de roca.<br> Se entrenó un modelo [nvidia/SegFormer](https://huggingface.co/nvidia/mit-b0) con _fine-tuning_ en el [rock-glacier-dataset](https://huggingface.co/datasets/alkzar90/rock-glacier-dataset)",
)
demo.launch()