classifier proto
models/classifier.py
ADDED
@@ -0,0 +1,153 @@
import torch
import torch.nn as nn

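# NOTE (assumption, not part of this commit): conv_nd, normalization, zero_module,
# Upsample, Downsample, AttentionBlock, checkpoint and sequential_checkpoint are
# repo-internal helpers and are assumed to be importable from the project's shared
# model/utility modules.


# Residual block (1-D or 2-D) with optional up/down-sampling of the hidden and skip
# paths and optional gradient checkpointing of the forward pass.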
class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        up=False,
        down=False,
        kernel_size=3,
        do_checkpoint=True,
    ):
        super().__init__()
        self.channels = channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.do_checkpoint = do_checkpoint
        padding = 1 if kernel_size == 3 else 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
        )

        self.updown = up or down

        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x):
        if self.do_checkpoint:
            return checkpoint(self._forward, x)
        else:
            return self._forward(x)

    def _forward(self, x):
        if self.updown:
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        h = self.out_layers(h)
        return self.skip_connection(x) + h

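# Stacks ResBlocks with strided 1-D downsampling to turn a (batch, spec_dim, time)
# spectrogram into per-frame embeddings, runs them through attention blocks, and
# returns the embedding at the first time step as a fixed-size vector.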
class AudioMiniEncoder(nn.Module):
    def __init__(self,
                 spec_dim,
                 embedding_dim,
                 base_channels=128,
                 depth=2,
                 resnet_blocks=2,
                 attn_blocks=4,
                 num_attn_heads=4,
                 dropout=0,
                 downsample_factor=2,
                 kernel_size=3):
        super().__init__()
        self.init = nn.Sequential(
            conv_nd(1, spec_dim, base_channels, 3, padding=1)
        )
        ch = base_channels
        res = []
        self.layers = depth
        for l in range(depth):
            for r in range(resnet_blocks):
                res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
            res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
            ch *= 2
        self.res = nn.Sequential(*res)
        self.final = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            conv_nd(1, ch, embedding_dim, 1)
        )
        attn = []
        for a in range(attn_blocks):
            attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False))
        self.attn = nn.Sequential(*attn)
        self.dim = embedding_dim

    def forward(self, x):
        h = self.init(x)
        h = sequential_checkpoint(self.res, self.layers, h)
        h = self.final(h)
        for blk in self.attn:
            h = checkpoint(blk, h)
        return h[:, :, 0]

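# Wraps AudioMiniEncoder with a linear classification head. Without labels it returns
# logits; with labels it returns a cross-entropy loss. When distribute_zero_label is
# set, a label of 0 is softened so that 20% of its probability mass is spread across
# the remaining classes.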
class AudioMiniEncoderWithClassifierHead(nn.Module):
    def __init__(self, classes, distribute_zero_label=True, **kwargs):
        super().__init__()
        self.enc = AudioMiniEncoder(**kwargs)
        self.head = nn.Linear(self.enc.dim, classes)
        self.num_classes = classes
        self.distribute_zero_label = distribute_zero_label

    def forward(self, x, labels=None):
        h = self.enc(x)
        logits = self.head(h)
        if labels is None:
            return logits
        else:
            if self.distribute_zero_label:
                oh_labels = nn.functional.one_hot(labels, num_classes=self.num_classes)
                zeros_indices = (labels == 0).unsqueeze(-1)
                # Distribute 20% of the probability mass on all classes when zero is specified, to compensate for dataset noise.
                zero_extra_mass = torch.full_like(oh_labels, dtype=torch.float, fill_value=.2/(self.num_classes-1))
                zero_extra_mass[:, 0] = -.2
                zero_extra_mass = zero_extra_mass * zeros_indices
                oh_labels = oh_labels + zero_extra_mass
            else:
                oh_labels = labels
            loss = nn.functional.cross_entropy(logits, oh_labels)
            return loss
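
# Usage sketch (illustrative, not part of this commit): the hyperparameters and shapes
# below are assumptions, chosen only to show how the classifier head could be driven
# once the helper modules above are available.
#
#   model = AudioMiniEncoderWithClassifierHead(classes=10, spec_dim=80, embedding_dim=512)
#   mel = torch.randn(4, 80, 256)                          # (batch, spec_dim, time)
#   logits = model(mel)                                    # (4, 10)
#   loss = model(mel, labels=torch.randint(0, 10, (4,)))   # scalar training loss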