teticio commited on
Commit
4f552a8
1 Parent(s): 13aa297

take channels into account

Browse files
Files changed (1) hide show
  1. scripts/train_vae.py +1 -1
scripts/train_vae.py CHANGED
@@ -44,7 +44,7 @@ class AudioDiffusion(Dataset):
44
  if self.channels == 3:
45
  image = image.convert('RGB')
46
  image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
47
- (image.height, image.width, 3))
48
  image = ((image / 255) * 2 - 1)
49
  return {'image': image}
50
 
 
44
  if self.channels == 3:
45
  image = image.convert('RGB')
46
  image = np.frombuffer(image.tobytes(), dtype="uint8").reshape(
47
+ (image.height, image.width, self.channels))
48
  image = ((image / 255) * 2 - 1)
49
  return {'image': image}
50