小 效懈褔械褉懈薪 commited on
Commit
8029b4a
1 Parent(s): d3426a1

added stylematte, fix gpu issue

Browse files
Files changed (3) hide show
  1. app.py +1 -1
  2. models.py +5 -5
  3. test.py +2 -2
app.py CHANGED
@@ -3,7 +3,7 @@ from test import inference_img
3
  from models import *
4
 
5
  device='cpu'
6
- model = MaskForm()
7
  model = model.to(device)
8
  checkpoint = f"stylematte.pth"
9
  state_dict = torch.load(checkpoint, map_location=f'{device}')
 
3
  from models import *
4
 
5
  device='cpu'
6
+ model = StyleMatte()
7
  model = model.to(device)
8
  checkpoint = f"stylematte.pth"
9
  state_dict = torch.load(checkpoint, map_location=f'{device}')
models.py CHANGED
@@ -8,7 +8,7 @@ from typing import List
8
  from itertools import chain
9
 
10
  from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
11
-
12
  class EncoderDecoder(nn.Module):
13
  def __init__(
14
  self,
@@ -284,9 +284,9 @@ class SegForm(nn.Module):
284
  return upsampled_logits
285
 
286
 
287
- class MaskForm(nn.Module):
288
  def __init__(self):
289
- super(MaskForm, self).__init__()
290
  # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
291
  # configuration.num_labels = 1 ## set output as 1
292
  self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
@@ -476,6 +476,6 @@ class BoxFilter(nn.Module):
476
  return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
477
 
478
  if __name__ == '__main__':
479
- model = MaskForm().cuda()
480
- out=model(torch.randn(1,3,640,480).cuda())
481
  print(out.shape)
 
8
  from itertools import chain
9
 
10
  from transformers import SegformerForSemanticSegmentation,Mask2FormerForUniversalSegmentation
11
+ device='cpu'
12
  class EncoderDecoder(nn.Module):
13
  def __init__(
14
  self,
 
284
  return upsampled_logits
285
 
286
 
287
+ class StyleMatte(nn.Module):
288
  def __init__(self):
289
+ super(StyleMatte, self).__init__()
290
  # configuration = SegformerConfig.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
291
  # configuration.num_labels = 1 ## set output as 1
292
  self.fpn = FPN_fuse(feature_channels=[256, 256, 256, 256],fpn_out=256)
 
476
  return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)
477
 
478
  if __name__ == '__main__':
479
+ model = StyleMatte().to(device)
480
+ out=model(torch.randn(1,3,640,480).to(devuce))
481
  print(out.shape)
test.py CHANGED
@@ -26,7 +26,7 @@ import logging
26
  import time
27
  from omegaconf import OmegaConf
28
  config = OmegaConf.load("base.yaml")
29
- device = "cuda"
30
 
31
  def conv3x3(in_planes, out_planes, stride=1):
32
  "3x3 convolution with padding"
@@ -720,7 +720,7 @@ def get_masked_local_from_global_test(global_result, local_result):
720
  return fusion_result
721
  def inference_once( model, scale_img, scale_trimap=None):
722
  pred_list = []
723
- tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).cuda()
724
  input_t = tensor_img
725
  input_t = input_t/255.0
726
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
 
26
  import time
27
  from omegaconf import OmegaConf
28
  config = OmegaConf.load("base.yaml")
29
+ device = "cpu"
30
 
31
  def conv3x3(in_planes, out_planes, stride=1):
32
  "3x3 convolution with padding"
 
720
  return fusion_result
721
  def inference_once( model, scale_img, scale_trimap=None):
722
  pred_list = []
723
+ tensor_img = torch.from_numpy(scale_img[:, :, :]).permute(2, 0, 1).to(device)
724
  input_t = tensor_img
725
  input_t = input_t/255.0
726
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],