Sa-m commited on
Commit
a2e588c
1 Parent(s): 5225eaa

Update models/common.py

Browse files
Files changed (1) hide show
  1. models/common.py +28 -0
models/common.py CHANGED
@@ -53,6 +53,34 @@ class ReOrg(nn.Module):
53
  return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  class Concat(nn.Module):
57
  def __init__(self, dimension=1):
58
  super(Concat, self).__init__()
 
53
  return torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)
54
 
55
 
56
+ class Merge(nn.Module):
57
+ def __init__(self,ch=()):
58
+ super(Merge, self).__init__()
59
+
60
+ def forward(self, x):
61
+
62
+ return [x[0],x[1],x[2]]
63
+
64
+
65
+ class Refine(nn.Module):
66
+
67
+ def __init__(self, c2, k, s, ch): # ch_in, ch_out, kernel, stride, padding, groups
68
+ super(Refine, self).__init__()
69
+ self.refine = nn.ModuleList()
70
+ for c in ch:
71
+ self.refine.append(Conv(c, c2, k, s))
72
+
73
+ def forward(self, x):
74
+ for i, f in enumerate(x):
75
+ if i == 0:
76
+ r = self.refine[i](f)
77
+ else:
78
+ r_p = self.refine[i](f)
79
+ r_p = F.interpolate(r_p, r.size()[2:], mode="bilinear", align_corners=False)
80
+ r = r + r_p
81
+ return r
82
+
83
+
84
  class Concat(nn.Module):
85
  def __init__(self, dimension=1):
86
  super(Concat, self).__init__()