teticio committed on
Commit
072978d
1 Parent(s): 5e9c370

allow steps to be different from 1000

Browse files
audiodiffusion/__init__.py CHANGED
@@ -4,12 +4,12 @@ import torch
4
  import numpy as np
5
  from PIL import Image
6
  from tqdm.auto import tqdm
7
- from diffusers import DDPMPipeline
8
  from librosa.beat import beat_track
 
9
 
10
  from .mel import Mel
11
 
12
- VERSION = "1.1.2"
13
 
14
 
15
  class AudioDiffusion:
@@ -60,7 +60,7 @@ class AudioDiffusion:
60
  raw_audio: np.ndarray = None,
61
  slice: int = 0,
62
  start_step: int = 0,
63
- steps: int = 1000,
64
  generator: torch.Generator = None
65
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
66
  """Generate random mel spectrogram from audio input and convert to audio.
@@ -70,7 +70,7 @@ class AudioDiffusion:
70
  raw_audio (np.ndarray): audio as numpy array
71
  slice (int): slice number of audio to convert
72
  start_step (int): step to start from
73
- steps (int): number of de-noising steps to perform
74
  generator (torch.Generator): random number generator or None
75
 
76
  Returns:
@@ -80,6 +80,10 @@ class AudioDiffusion:
80
 
81
  # It would be better to derive a class from DDPMDiffusionPipeline
82
  # but currently the return type ImagePipelineOutput cannot be imported.
 
 
 
 
83
  images = torch.randn(
84
  (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
85
  self.ddpm.unet.sample_size),
@@ -94,16 +98,17 @@ class AudioDiffusion:
94
  input_image.height))
95
  input_image = ((input_image / 255) * 2 - 1)
96
  if start_step > 0:
97
- images[0][0] = self.ddpm.scheduler.add_noise(
98
- torch.tensor(input_image[np.newaxis, np.newaxis, :]), images,
99
- steps - start_step)
100
 
101
  images = images.to(self.ddpm.device)
102
- self.ddpm.scheduler.set_timesteps(steps)
103
- for t in self.progress_bar(self.ddpm.scheduler.timesteps[start_step:]):
104
  model_output = self.ddpm.unet(images, t)['sample']
105
- images = self.ddpm.scheduler.step(
106
- model_output, t, images, generator=generator)['prev_sample']
 
 
107
  images = (images / 2 + 0.5).clamp(0, 1)
108
  images = images.cpu().permute(0, 2, 3, 1).numpy()
109
 
 
4
  import numpy as np
5
  from PIL import Image
6
  from tqdm.auto import tqdm
 
7
  from librosa.beat import beat_track
8
+ from diffusers import DDPMPipeline, DDPMScheduler
9
 
10
  from .mel import Mel
11
 
12
+ VERSION = "1.1.3"
13
 
14
 
15
  class AudioDiffusion:
 
60
  raw_audio: np.ndarray = None,
61
  slice: int = 0,
62
  start_step: int = 0,
63
+ steps: int = None,
64
  generator: torch.Generator = None
65
  ) -> Tuple[Image.Image, Tuple[int, np.ndarray]]:
66
  """Generate random mel spectrogram from audio input and convert to audio.
 
70
  raw_audio (np.ndarray): audio as numpy array
71
  slice (int): slice number of audio to convert
72
  start_step (int): step to start from
73
+ steps (int): number of de-noising steps to perform (defaults to num_train_timesteps)
74
  generator (torch.Generator): random number generator or None
75
 
76
  Returns:
 
80
 
81
  # It would be better to derive a class from DDPMDiffusionPipeline
82
  # but currently the return type ImagePipelineOutput cannot be imported.
83
+ if steps is None:
84
+ steps = self.ddpm.scheduler.num_train_timesteps
85
+ scheduler = DDPMScheduler(num_train_timesteps=steps)
86
+ scheduler.set_timesteps(steps)
87
  images = torch.randn(
88
  (1, self.ddpm.unet.in_channels, self.ddpm.unet.sample_size,
89
  self.ddpm.unet.sample_size),
 
98
  input_image.height))
99
  input_image = ((input_image / 255) * 2 - 1)
100
  if start_step > 0:
101
+ images[0][0] = scheduler.add_noise(
102
+ torch.tensor(input_image[np.newaxis, np.newaxis, :]),
103
+ images, steps - start_step)
104
 
105
  images = images.to(self.ddpm.device)
106
+ for t in self.progress_bar(scheduler.timesteps[start_step:]):
 
107
  model_output = self.ddpm.unet(images, t)['sample']
108
+ images = scheduler.step(model_output,
109
+ t,
110
+ images,
111
+ generator=generator)['prev_sample']
112
  images = (images / 2 + 0.5).clamp(0, 1)
113
  images = images.cpu().permute(0, 2, 3, 1).numpy()
114
 
notebooks/test_model.ipynb CHANGED
The diff for this file is too large to render. See raw diff