Dionyssos commited on
Commit
d8e2a3d
1 Parent(s): fe62fb4
audiocraft/__init__.py ADDED
File without changes
audiocraft/conv.py CHANGED
@@ -114,20 +114,7 @@ class NormConv1d(nn.Module):
114
  return x
115
 
116
 
117
- class NormConv2d(nn.Module):
118
- """Wrapper around Conv2d and normalization applied to this conv
119
- to provide a uniform interface across normalization approaches.
120
- """
121
- def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
122
- super().__init__()
123
- self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
124
- self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
125
- self.norm_type = norm
126
 
127
- def forward(self, x):
128
- x = self.conv(x)
129
- x = self.norm(x)
130
- return x
131
 
132
 
133
  class NormConvTranspose1d(nn.Module):
@@ -147,30 +134,25 @@ class NormConvTranspose1d(nn.Module):
147
  return x
148
 
149
 
150
- class NormConvTranspose2d(nn.Module):
151
- """Wrapper around ConvTranspose2d and normalization applied to this conv
152
- to provide a uniform interface across normalization approaches.
153
- """
154
- def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
155
- super().__init__()
156
- self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
157
- self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
158
 
159
- def forward(self, x):
160
- x = self.convtr(x)
161
- x = self.norm(x)
162
- return x
163
 
164
 
165
  class StreamableConv1d(nn.Module):
166
  """Conv1d with some builtin handling of asymmetric or causal padding
167
  and normalization.
168
  """
169
- def __init__(self, in_channels: int, out_channels: int,
170
- kernel_size: int, stride: int = 1, dilation: int = 1,
171
- groups: int = 1, bias: bool = True, causal: bool = False,
172
- norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
173
- pad_mode: str = 'reflect'):
 
 
 
 
 
 
 
174
  super().__init__()
175
  # warn user on unusual setup between dilation and stride
176
  if stride > 1 and dilation > 1:
@@ -192,12 +174,15 @@ class StreamableConv1d(nn.Module):
192
  extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
193
  if self.causal:
194
  # Left padding for causal
195
- x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
 
196
  else:
197
  # Asymmetric padding required for odd strides
198
  padding_right = padding_total // 2
199
  padding_left = padding_total - padding_right
200
  x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
 
 
201
  return self.conv(x)
202
 
203
 
@@ -230,13 +215,10 @@ class StreamableConvTranspose1d(nn.Module):
230
  # as removing it here would require also passing the length at the matching layer
231
  # in the encoder.
232
  if self.causal:
233
- # Trim the padding on the right according to the specified ratio
234
- # if trim_right_ratio = 1.0, trim everything from right
235
- padding_right = math.ceil(padding_total * self.trim_right_ratio)
236
- padding_left = padding_total - padding_right
237
- y = unpad1d(y, (padding_left, padding_right))
238
  else:
239
  # Asymmetric padding required for odd strides
 
240
  padding_right = padding_total // 2
241
  padding_left = padding_total - padding_right
242
  y = unpad1d(y, (padding_left, padding_right))
 
114
  return x
115
 
116
 
 
 
 
 
 
 
 
 
 
117
 
 
 
 
 
118
 
119
 
120
  class NormConvTranspose1d(nn.Module):
 
134
  return x
135
 
136
 
 
 
 
 
 
 
 
 
137
 
 
 
 
 
138
 
139
 
140
  class StreamableConv1d(nn.Module):
141
  """Conv1d with some builtin handling of asymmetric or causal padding
142
  and normalization.
143
  """
144
+ def __init__(self,
145
+ in_channels,
146
+ out_channels,
147
+ kernel_size,
148
+ stride=1,
149
+ dilation=1,
150
+ groups=1,
151
+ bias=True,
152
+ causal=False,
153
+ norm='none',
154
+ norm_kwargs={},
155
+ pad_mode='reflect'):
156
  super().__init__()
157
  # warn user on unusual setup between dilation and stride
158
  if stride > 1 and dilation > 1:
 
174
  extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
175
  if self.causal:
176
  # Left padding for causal
177
+ # x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
178
+ print('\n \n\n\nn\n\n\nnCAUSAL N\n\n\n')
179
  else:
180
  # Asymmetric padding required for odd strides
181
  padding_right = padding_total // 2
182
  padding_left = padding_total - padding_right
183
  x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
184
+ # print(f'\n \/n\n\n\nANTICaus N {x.shape=}\n')
185
+ # ANTICaus CONV OLD_SHAPE=torch.Size([1, 512, 280]) x.shape=torch.Size([1, 512, 282])
186
  return self.conv(x)
187
 
188
 
 
215
  # as removing it here would require also passing the length at the matching layer
216
  # in the encoder.
217
  if self.causal:
218
+ print('\n \n\n\nn\n\n\nnCAUSAL T\n\n\n\n\n')
 
 
 
 
219
  else:
220
  # Asymmetric padding required for odd strides
221
+ # print('\n \n\n\nn\n\n\nnANTICAUSAL T\n\n\n')
222
  padding_right = padding_total // 2
223
  padding_left = padding_total - padding_right
224
  y = unpad1d(y, (padding_left, padding_right))
audiocraft/lm.py CHANGED
@@ -435,7 +435,7 @@ class LMModel(StreamingModule):
435
  # print('Set All to Special')
436
 
437
  # RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
438
- # next_token[:] = self.special_token_id
439
 
440
 
441
 
@@ -451,7 +451,7 @@ class LMModel(StreamingModule):
451
  unconditional_state.clear()
452
 
453
  out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
454
- print(f'{out_codes.shape=} {out_codes.min()} {out_codes.max()}\n')
455
  out_start_offset = start_offset if remove_prompts else 0
456
  out_codes = out_codes[..., out_start_offset:max_gen_len]
457
 
 
435
  # print('Set All to Special')
436
 
437
  # RUNS with = 2047 just different of self.special_token_id -> 2047 is drill noise
438
+ # next_token[:] = self.special_token_id
439
 
440
 
441
 
 
451
  unconditional_state.clear()
452
 
453
  out_codes, _, _ = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
454
+ print(f' <=> CODES {out_codes.shape=} {out_codes.min()} {out_codes.max()}\n') # ARRIVES here also if special
455
  out_start_offset = start_offset if remove_prompts else 0
456
  out_codes = out_codes[..., out_start_offset:max_gen_len]
457
 
audiocraft/seanet.py CHANGED
@@ -143,5 +143,8 @@ class SEANetDecoder(nn.Module):
143
  self.model = nn.Sequential(*model)
144
 
145
  def forward(self, z):
 
 
146
  y = self.model(z)
 
147
  return y
 
143
  self.model = nn.Sequential(*model)
144
 
145
  def forward(self, z):
146
+ print(f'\n Enter seanet with shape {z.shape}\n') # arrives here with (1,128,35)
147
+ # how can this convnet care for the value that is in z so it crashes?
148
  y = self.model(z)
149
+ print(f'\n Exit seanet with shape {y.shape}\n') # arrives here with (1,128,35)
150
  return y