teticio commited on
Commit
0463954
1 Parent(s): 2dddff0

handle tuple / list for resolutions

Browse files
Files changed (1) hide show
  1. audiodiffusion/__init__.py +4 -3
audiodiffusion/__init__.py CHANGED
@@ -213,11 +213,12 @@ class AudioDiffusionPipeline(DiffusionPipeline):
213
  step_generator = step_generator or generator
214
  # For backwards compatibility
215
  if type(self.unet.sample_size) == int:
216
- self.unet.sample_size = [self.unet.sample_size,
217
- self.unet.sample_size]
218
  if noise is None:
219
  noise = torch.randn(
220
- [batch_size, self.unet.in_channels] + self.unet.sample_size,
 
221
  generator=generator)
222
  images = noise
223
  mask = None
 
213
  step_generator = step_generator or generator
214
  # For backwards compatibility
215
  if type(self.unet.sample_size) == int:
216
+ self.unet.sample_size = (self.unet.sample_size,
217
+ self.unet.sample_size)
218
  if noise is None:
219
  noise = torch.randn(
220
+ (batch_size, self.unet.in_channels, self.unet.sample_size[0],
221
+ self.unet.sample_size[1]),
222
  generator=generator)
223
  images = noise
224
  mask = None