karolmajek commited on
Commit
9dcc578
1 Parent(s): a35f3d3
Files changed (1) hide show
  1. app.py +5 -6
app.py CHANGED
@@ -2,7 +2,7 @@ from matplotlib.pyplot import axis
2
  import gradio as gr
3
  import requests
4
  import numpy as np
5
- # from torch import nn
6
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
7
  import requests
8
 
@@ -30,11 +30,10 @@ def inference(image):
30
  outputs = model(**inputs)
31
 
32
  # First, rescale logits to original image size
33
- # logits = nn.functional.interpolate(outputs.logits.detach().cpu(),
34
- # size=image.size[::-1], # (height, width)
35
- # mode='bilinear',
36
- # align_corners=False)
37
- logits = outputs.logits.detach().cpu()
38
 
39
  # Second, apply argmax on the class dimension
40
  seg = logits.argmax(dim=1)[0]
2
  import gradio as gr
3
  import requests
4
  import numpy as np
5
+ from torch import nn
6
  from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
7
  import requests
8
 
30
  outputs = model(**inputs)
31
 
32
  # First, rescale logits to original image size
33
+ logits = nn.functional.interpolate(outputs.logits.detach().cpu(),
34
+ size=image.size[::-1], # (height, width)
35
+ mode='bilinear',
36
+ align_corners=False)
 
37
 
38
  # Second, apply argmax on the class dimension
39
  seg = logits.argmax(dim=1)[0]