alkzar90 commited on
Commit
d1c29b6
1 Parent(s): 11b107f

add upscale function

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from transformers import (SegformerFeatureExtractor,
4
  SegformerForSemanticSegmentation)
5
 
@@ -13,13 +14,22 @@ model = SegformerForSemanticSegmentation.from_pretrained(MODEL_PATH)
13
  model.eval()
14
 
15
 
 
 
 
 
 
 
 
 
 
16
  def query_image(img):
17
  """Función para generar predicciones a la escala origina"""
18
  inputs = preprocessor(images=img, return_tensors="pt")
19
  with torch.no_grad():
20
  #preds = model(inputs.unsqueeze(0).to(device))["logits"]
21
  preds = model(**inputs)["logits"]
22
- preds_upscale = upscale_logits_modified(preds, image.shape[2])
23
  predict_label = torch.argmax(preds_upscale, dim=1).to(device)
24
  return predict_label[0,:,:].detach().cpu().numpy()
25
 
 
1
  import gradio as gr
2
  import torch
3
+ from torch import nn
4
  from transformers import (SegformerFeatureExtractor,
5
  SegformerForSemanticSegmentation)
6
 
 
14
  model.eval()
15
 
16
 
17
+ def upscale_logits(logit_outputs, size):
18
+ """Escala los logits a (4W)x(4H) para recobrar dimensiones originales del input"""
19
+ return nn.functional.interpolate(
20
+ logit_outputs,
21
+ size=size,
22
+ mode="bilinear",
23
+ align_corners=False
24
+ )
25
+
26
  def query_image(img):
27
  """Función para generar predicciones a la escala origina"""
28
  inputs = preprocessor(images=img, return_tensors="pt")
29
  with torch.no_grad():
30
  #preds = model(inputs.unsqueeze(0).to(device))["logits"]
31
  preds = model(**inputs)["logits"]
32
+ preds_upscale = upscale_logits(preds, image.shape[2])
33
  predict_label = torch.argmax(preds_upscale, dim=1).to(device)
34
  return predict_label[0,:,:].detach().cpu().numpy()
35