fx
Browse files- audiocraft/__init__.py +0 -0
- audiocraft/conv.py +18 -36
- audiocraft/lm.py +2 -2
- audiocraft/seanet.py +3 -0
audiocraft/__init__.py
ADDED
File without changes
|
audiocraft/conv.py
CHANGED
@@ -114,20 +114,7 @@ class NormConv1d(nn.Module):
|
|
114 |
return x
|
115 |
|
116 |
|
117 |
-
class NormConv2d(nn.Module):
|
118 |
-
"""Wrapper around Conv2d and normalization applied to this conv
|
119 |
-
to provide a uniform interface across normalization approaches.
|
120 |
-
"""
|
121 |
-
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
122 |
-
super().__init__()
|
123 |
-
self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
|
124 |
-
self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
|
125 |
-
self.norm_type = norm
|
126 |
|
127 |
-
def forward(self, x):
|
128 |
-
x = self.conv(x)
|
129 |
-
x = self.norm(x)
|
130 |
-
return x
|
131 |
|
132 |
|
133 |
class NormConvTranspose1d(nn.Module):
|
@@ -147,30 +134,25 @@ class NormConvTranspose1d(nn.Module):
|
|
147 |
return x
|
148 |
|
149 |
|
150 |
-
class NormConvTranspose2d(nn.Module):
|
151 |
-
"""Wrapper around ConvTranspose2d and normalization applied to this conv
|
152 |
-
to provide a uniform interface across normalization approaches.
|
153 |
-
"""
|
154 |
-
def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
|
155 |
-
super().__init__()
|
156 |
-
self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
|
157 |
-
self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
|
158 |
|
159 |
-
def forward(self, x):
|
160 |
-
x = self.convtr(x)
|
161 |
-
x = self.norm(x)
|
162 |
-
return x
|
163 |
|
164 |
|
165 |
class StreamableConv1d(nn.Module):
|
166 |
"""Conv1d with some builtin handling of asymmetric or causal padding
|
167 |
and normalization.
|
168 |
"""
|
169 |
-
def __init__(self,
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
super().__init__()
|
175 |
# warn user on unusual setup between dilation and stride
|
176 |
if stride > 1 and dilation > 1:
|
@@ -192,12 +174,15 @@ class StreamableConv1d(nn.Module):
|
|
192 |
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
193 |
if self.causal:
|
194 |
# Left padding for causal
|
195 |
-
x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
|
|
196 |
else:
|
197 |
# Asymmetric padding required for odd strides
|
198 |
padding_right = padding_total // 2
|
199 |
padding_left = padding_total - padding_right
|
200 |
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
|
|
|
|
201 |
return self.conv(x)
|
202 |
|
203 |
|
@@ -230,13 +215,10 @@ class StreamableConvTranspose1d(nn.Module):
|
|
230 |
# as removing it here would require also passing the length at the matching layer
|
231 |
# in the encoder.
|
232 |
if self.causal:
|
233 |
-
|
234 |
-
# if trim_right_ratio = 1.0, trim everything from right
|
235 |
-
padding_right = math.ceil(padding_total * self.trim_right_ratio)
|
236 |
-
padding_left = padding_total - padding_right
|
237 |
-
y = unpad1d(y, (padding_left, padding_right))
|
238 |
else:
|
239 |
# Asymmetric padding required for odd strides
|
|
|
240 |
padding_right = padding_total // 2
|
241 |
padding_left = padding_total - padding_right
|
242 |
y = unpad1d(y, (padding_left, padding_right))
|
|
|
114 |
return x
|
115 |
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
|
|
|
|
|
|
|
|
118 |
|
119 |
|
120 |
class NormConvTranspose1d(nn.Module):
|
|
|
134 |
return x
|
135 |
|
136 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
|
|
|
|
|
|
|
|
|
138 |
|
139 |
|
140 |
class StreamableConv1d(nn.Module):
|
141 |
"""Conv1d with some builtin handling of asymmetric or causal padding
|
142 |
and normalization.
|
143 |
"""
|
144 |
+
def __init__(self,
|
145 |
+
in_channels,
|
146 |
+
out_channels,
|
147 |
+
kernel_size,
|
148 |
+
stride=1,
|
149 |
+
dilation=1,
|
150 |
+
groups=1,
|
151 |
+
bias=True,
|
152 |
+
causal=False,
|
153 |
+
norm='none',
|
154 |
+
norm_kwargs={},
|
155 |
+
pad_mode='reflect'):
|
156 |
super().__init__()
|
157 |
# warn user on unusual setup between dilation and stride
|
158 |
if stride > 1 and dilation > 1:
|
|
|
174 |
extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
|
175 |
if self.causal:
|
176 |
# Left padding for causal
|
177 |
+
# x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
|
178 |
+
print('\n \n\n\nn\n\n\nnCAUSAL N\n\n\n')
|
179 |
else:
|
180 |
# Asymmetric padding required for odd strides
|
181 |
padding_right = padding_total // 2
|
182 |
padding_left = padding_total - padding_right
|
183 |
x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
|
184 |
+
# print(f'\n \/n\n\n\nANTICaus N {x.shape=}\n')
|
185 |
+
# ANTICaus CONV OLD_SHAPE=torch.Size([1, 512, 280]) x.shape=torch.Size([1, 512, 282])
|
186 |
return self.conv(x)
|
187 |
|
188 |
|
|
|
215 |
# as removing it here would require also passing the length at the matching layer
|
216 |
# in the encoder.
|
217 |
if self.causal:
|
218 |
+
print('\n \n\n\nn\n\n\nnCAUSAL T\n\n\n\n\n')
|
|
|
|
|
|
|
|
|
219 |
else:
|
220 |
# Asymmetric padding required for odd strides
|
221 |
+
# print('\n \n\n\nn\n\n\nnANTICAUSAL T\n\n\n')
|
222 |
padding_right = padding_total // 2
|
223 |
padding_left = padding_total - padding_right
|
224 |
y = unpad1d(y, (padding_left, padding_right))
|
audiocraft/lm.py
CHANGED
@@ -435,7 +435,7 @@ class LMModel(StreamingModule):
|
|
435 |
# print('Set All to Special')
|
436 |
|
437 |
# RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
|
438 |
-
# next_token[:] = self.special_token_id
|
439 |
|
440 |
|
441 |
|
@@ -451,7 +451,7 @@ class LMModel(StreamingModule):
|
|
451 |
unconditional_state.clear()
|
452 |
|
453 |
out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
454 |
-
print(f'{out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')
|
455 |
out_start_offset = start_offset if remove_prompts else 0
|
456 |
out_codes = out_codes[..., out_start_offset:max_gen_len]
|
457 |
|
|
|
435 |
# print('Set All to Special')
|
436 |
|
437 |
# RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
|
438 |
+
# next_token[:] = self.special_token_id
|
439 |
|
440 |
|
441 |
|
|
|
451 |
unconditional_state.clear()
|
452 |
|
453 |
out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
|
454 |
+
print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n') # ARRIVES here also if special
|
455 |
out_start_offset = start_offset if remove_prompts else 0
|
456 |
out_codes = out_codes[..., out_start_offset:max_gen_len]
|
457 |
|
audiocraft/seanet.py
CHANGED
@@ -143,5 +143,8 @@ class SEANetDecoder(nn.Module):
|
|
143 |
self.model = nn.Sequential(*model)
|
144 |
|
145 |
def forward(self, z):
|
|
|
|
|
146 |
y = self.model(z)
|
|
|
147 |
return y
|
|
|
143 |
self.model = nn.Sequential(*model)
|
144 |
|
145 |
def forward(self, z):
|
146 |
+
print(f'\n Enter seanet with shape {z.shape}\n') # arrives here with (1,128,35)
|
147 |
+
# how can this convnet care for the value that is in z so it crashes?
|
148 |
y = self.model(z)
|
149 |
+
print(f'\n Exit seanet with shape {y.shape}\n') # arrives here with (1,128,35)
|
150 |
return y
|