Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,853 Bytes
4c954ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import torch
import torch.nn as nn
class MultitaskHead(nn.Module):
    """Parallel per-task convolutional heads concatenated along channels.

    For every entry ``c`` of the flattened ``head_size`` a small head
    ``Conv3x3(input_channels -> m) -> ReLU -> Conv1x1(m -> c)`` is built,
    with ``m = int(input_channels / 4)``. ``forward`` runs every head on the
    same input and concatenates the results along dim 1.

    Args:
        input_channels: Channel count of the incoming feature map.
        num_class: Total number of output channels; must equal the sum of
            all entries in ``head_size``.
        head_size: Nested iterable of per-head output channel counts, e.g.
            ``[[2], [1], [3]]``. Lists, tuples, or a one-shot generator are
            accepted; it is iterated exactly once.

    Raises:
        AssertionError: If ``num_class`` differs from the flattened sum of
            ``head_size``. This is an ``assert`` and is skipped under
            ``python -O``.
    """

    def __init__(self, input_channels, num_class, head_size):
        super(MultitaskHead, self).__init__()
        # Flatten exactly once so tuples and one-shot generators work and the
        # same values drive both the validation and the module construction.
        out_channels = [c for group in head_size for c in group]
        # Validate before allocating any layers so a bad config fails fast.
        assert num_class == sum(out_channels), (
            f"num_class={num_class} does not match sum of head_size "
            f"entries ({sum(out_channels)})"
        )
        m = int(input_channels / 4)
        # Layer creation order per head (3x3 conv, ReLU, 1x1 conv) is kept so
        # seeded weight initialisation and state_dict keys are unchanged.
        self.heads = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(input_channels, m, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(m, c, kernel_size=1),
            )
            for c in out_channels
        )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs on dim 1."""
        return torch.cat([head(x) for head in self.heads], dim=1)
class AngleDistanceHead(nn.Module):
    """Per-task convolutional heads with a special 2-channel output layer.

    Every head starts with ``Conv3x3(input_channels -> m) -> ReLU`` where
    ``m = int(input_channels / 4)``. The final layer depends on the head's
    output channel count ``c``:

    * ``c == 2``: ``CosineSineLayer(m)``. That class is defined outside this
      view; presumably it maps ``m`` channels to 2 (TODO confirm).
    * otherwise: ``Conv1x1(m -> c)``.

    ``forward`` runs every head on the same input and concatenates the
    results along dim 1.

    Args:
        input_channels: Channel count of the incoming feature map.
        num_class: Total number of output channels; must equal the sum of
            all entries in ``head_size``.
        head_size: Nested iterable of per-head output channel counts, e.g.
            ``[[1], [2], [3]]``. Lists, tuples, or a one-shot generator are
            accepted; it is iterated exactly once.

    Raises:
        AssertionError: If ``num_class`` differs from the flattened sum of
            ``head_size``. This is an ``assert`` and is skipped under
            ``python -O``.
    """

    def __init__(self, input_channels, num_class, head_size):
        super(AngleDistanceHead, self).__init__()
        # Flatten exactly once so tuples and one-shot generators work and the
        # same values drive both the validation and the module construction.
        out_channels = [c for group in head_size for c in group]
        # Validate before allocating any layers so a bad config fails fast.
        assert num_class == sum(out_channels), (
            f"num_class={num_class} does not match sum of head_size "
            f"entries ({sum(out_channels)})"
        )
        m = int(input_channels / 4)
        heads = []
        for c in out_channels:
            # Create layers in the same order as before (3x3 conv, ReLU, final
            # layer) so seeded initialisation is reproducible. The final layer
            # also stays at Sequential index 2, keeping state_dict keys
            # compatible with existing checkpoints.
            conv = nn.Conv2d(input_channels, m, kernel_size=3, padding=1)
            act = nn.ReLU(inplace=True)
            if c == 2:
                final = CosineSineLayer(m)
            else:
                final = nn.Conv2d(m, c, kernel_size=1)
            heads.append(nn.Sequential(conv, act, final))
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the outputs on dim 1."""
        return torch.cat([head(x) for head in self.heads], dim=1)