Spaces:
Runtime error
Runtime error
С Чичерин
commited on
Commit
•
8029b4a
1
Parent(s):
d3426a1
added stylematte, fix gpu issue
Browse files
app.py
CHANGED
@@ -3,7 +3,7 @@ from test import inference_img
|
|
3 |
from models import *
|
4 |
|
5 |
device='cpu'
|
6 |
-
model =
|
7 |
model = model.to(device)
|
8 |
checkpoint = f"stylematte.pth"
|
9 |
state_dict = torch.load(checkpoint, map_location=f'{device}')
|
|
|
3 |
from models import *
|
4 |
|
5 |
device='cpu'
|
6 |
+
model = StyleMatte()
|
7 |
model = model.to(device)
|
8 |
checkpoint = f"stylematte.pth"
|
9 |
state_dict = torch.load(checkpoint, map_location=f'{device}')
|
models.py
CHANGED
@@ -8,7 +8,7 @@ from typing import List
|
|
8 |
from itertools import chain
|
9 |
|
10 |
from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
|
11 |
-
|
12 |
class EncoderDecoder(nn.Module):
|
13 |
def __init__(
|
14 |
self,
|
@@ -284,9 +284,9 @@ class SegForm(nn.Module):
|
|
284 |
return upsampled_logits
|
285 |
|
286 |
|
287 |
-
class
|
288 |
def __init__(self):
|
289 |
-
super(
|
290 |
# configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
291 |
# configuration.num_labels = 1 ## set output as 1
|
292 |
self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
|
@@ -476,6 +476,6 @@ class BoxFilter(nn.Module):
|
|
476 |
return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
|
477 |
|
478 |
if __name__ == '__main__':
|
479 |
-
model =
|
480 |
-
out=model(torch.randn(1,3,640,480).
|
481 |
print(out.shape)
|
|
|
8 |
from itertools import chain
|
9 |
|
10 |
from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
|
11 |
+
device='cpu'
|
12 |
class EncoderDecoder(nn.Module):
|
13 |
def __init__(
|
14 |
self,
|
|
|
284 |
return upsampled_logits
|
285 |
|
286 |
|
287 |
+
class StyleMatte(nn.Module):
|
288 |
def __init__(self):
|
289 |
+
super(StyleMatte, self).__init__()
|
290 |
# configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
|
291 |
# configuration.num_labels = 1 ## set output as 1
|
292 |
self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
|
|
|
476 |
return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
|
477 |
|
478 |
if __name__ == '__main__':
|
479 |
+
model = StyleMatte().to(device)
|
480 |
+
out=model(torch.randn(1,3,640,480).to(devuce))
|
481 |
print(out.shape)
|
test.py
CHANGED
@@ -26,7 +26,7 @@ import logging
|
|
26 |
import time
|
27 |
from omegaconf import OmegaConf
|
28 |
config = OmegaConf.load("base.yaml")
|
29 |
-
device = "
|
30 |
|
31 |
def conv3x3(in_planes, out_planes, stride=1):
|
32 |
"3x3 convolution with padding"
|
@@ -720,7 +720,7 @@ def get_masked_local_from_global_test(global_result, local_result):
|
|
720 |
return fusion_result
|
721 |
def inference_once( model, scale_img, scale_trimap=None):
|
722 |
pred_list = []
|
723 |
-
tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).
|
724 |
input_t = tensor_img
|
725 |
input_t = input_t/255.0
|
726 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
|
|
26 |
import time
|
27 |
from omegaconf import OmegaConf
|
28 |
config = OmegaConf.load("base.yaml")
|
29 |
+
device = "cpu"
|
30 |
|
31 |
def conv3x3(in_planes, out_planes, stride=1):
|
32 |
"3x3 convolution with padding"
|
|
|
720 |
return fusion_result
|
721 |
def inference_once( model, scale_img, scale_trimap=None):
|
722 |
pred_list = []
|
723 |
+
tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).to(device)
|
724 |
input_t = tensor_img
|
725 |
input_t = input_t/255.0
|
726 |
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|