fffiloni committed
Commit 18b0529
1 Parent(s): 17fe826

Update app.py

Files changed (1)
1. app.py +15 -137
app.py CHANGED
@@ -1,146 +1,24 @@
 import gradio as gr
 
-"""
-Audio processing tools to convert between spectrogram images and waveforms.
-"""
-import io
-import typing as T
-
-import numpy as np
-from PIL import Image
-import pydub
-from scipy.io import wavfile
-import torch
-import torchaudio
 
+import torch  # needed for torch.float16 below
+from spectro import wav_bytes_from_spectrogram_image
 from diffusers import StableDiffusionPipeline
 
 model_id = "riffusion/riffusion-model-v1"
 pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
 pipe = pipe.to("cuda")
 
-def get_spectro(prompt):
-    image = pipe(prompt).images[0]
-    return image
-
-def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
-    """
-    Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.
-    """
-
-    max_volume = 50
-    power_for_image = 0.25
-    Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)
-
-    sample_rate = 44100  # [Hz]
-    clip_duration_ms = 5000  # [ms]
-
-    bins_per_image = 512
-    n_mels = 512
-
-    # FFT parameters
-    window_duration_ms = 100  # [ms]
-    padded_duration_ms = 400  # [ms]
-    step_size_ms = 10  # [ms]
-
-    # Derived parameters
-    num_samples = int(image.width / float(bins_per_image) * clip_duration_ms) * sample_rate
-    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
-    hop_length = int(step_size_ms / 1000.0 * sample_rate)
-    win_length = int(window_duration_ms / 1000.0 * sample_rate)
-
-    samples = waveform_from_spectrogram(
-        Sxx=Sxx,
-        n_fft=n_fft,
-        hop_length=hop_length,
-        win_length=win_length,
-        num_samples=num_samples,
-        sample_rate=sample_rate,
-        mel_scale=True,
-        n_mels=n_mels,
-        max_mel_iters=200,
-        num_griffin_lim_iters=32,
-    )
-
-    wav_bytes = io.BytesIO()
-    wavfile.write(wav_bytes, sample_rate, samples.astype(np.int16))
-    wav_bytes.seek(0)
-
-    duration_s = float(len(samples)) / sample_rate
-
-    return wav_bytes, duration_s
-
-def spectrogram_from_image(
-    image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
-) -> np.ndarray:
-    """
-    Compute a spectrogram magnitude array from a spectrogram image.
-    TODO(hayk): Add image_from_spectrogram and call this out as the reverse.
-    """
-    # Convert to a numpy array of floats
-    data = np.array(image).astype(np.float32)
-
-    # Flip Y and take a single channel
-    data = data[::-1, :, 0]
-
-    # Invert
-    data = 255 - data
-
-    # Rescale to max volume
-    data = data * max_volume / 255
-
-    # Reverse the power curve
-    data = np.power(data, 1 / power_for_image)
-
-    return data
-
-def waveform_from_spectrogram(
-    Sxx: np.ndarray,
-    n_fft: int,
-    hop_length: int,
-    win_length: int,
-    num_samples: int,
-    sample_rate: int,
-    mel_scale: bool = True,
-    n_mels: int = 512,
-    max_mel_iters: int = 200,
-    num_griffin_lim_iters: int = 32,
-    device: str = "cuda:0",
-) -> np.ndarray:
-    """
-    Reconstruct a waveform from a spectrogram.
-    This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim
-    algorithm to approximate the phase.
-    """
-    Sxx_torch = torch.from_numpy(Sxx).to(device)
-
-    # TODO(hayk): Make this a class that caches the two transforms
-
-    if mel_scale:
-        mel_inv_scaler = torchaudio.transforms.InverseMelScale(
-            n_mels=n_mels,
-            sample_rate=sample_rate,
-            f_min=0,
-            f_max=10000,
-            n_stft=n_fft // 2 + 1,
-            norm=None,
-            mel_scale="htk",
-            max_iter=max_mel_iters,
-        ).to(device)
-
-        Sxx_torch = mel_inv_scaler(Sxx_torch)
-
-    griffin_lim = torchaudio.transforms.GriffinLim(
-        n_fft=n_fft,
-        win_length=win_length,
-        hop_length=hop_length,
-        power=1.0,
-        n_iter=num_griffin_lim_iters,
-    ).to(device)
-
-    waveform = griffin_lim(Sxx_torch).cpu().numpy()
-
-    return waveform
-
-
-gr.Interface(fn=get_spectro, inputs=[gr.Textbox()], outputs=[gr.Image()]).launch()
+def predict(prompt):
+    spec = pipe(prompt).images[0]
+    wav = wav_bytes_from_spectrogram_image(spec)
+    with open("output.wav", "wb") as f:
+        f.write(wav[0].getbuffer())
+    return 'output.wav'
+
+gr.Interface(
+    predict,
+    inputs="text",
+    outputs=gr.outputs.Audio(type='filepath'),
+    title="Riffusion",
+).launch(share=True, debug=True)
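
A note on the helpers this commit deletes from app.py (presumably moved into the new spectro module): the derived STFT parameters in wav_bytes_from_spectrogram_image are fully determined by its millisecond constants. Worked out at sample_rate = 44100 Hz:

# Derived STFT parameters at sample_rate = 44100 Hz, from the constants above
win_length = int(100 / 1000.0 * 44100)  # 100 ms window -> 4410 samples
n_fft = int(400 / 1000.0 * 44100)       # 400 ms padded -> 17640-point FFT
hop_length = int(10 / 1000.0 * 44100)   # 10 ms step -> 441 samples
# A 512-pixel-wide spectrogram covers the full 5000 ms clip,
# i.e. 5 s * 44100 Hz = 220500 samples of audio.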
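
The TODO in spectrogram_from_image asks for its inverse, image_from_spectrogram. A minimal sketch that just reverses the four steps (power curve, volume rescale, invert, Y-flip), reusing the numpy/PIL imports from the deleted code; the name comes from the TODO, but the body is an assumption, not part of this commit:

def image_from_spectrogram(
    Sxx: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
) -> Image.Image:
    # Apply the power curve (inverse of np.power(data, 1 / power_for_image))
    data = np.power(Sxx, power_for_image)
    # Rescale from [0, max_volume] back to [0, 255]
    data = data * 255 / max_volume
    # Invert
    data = 255 - data
    # Flip Y and convert to an 8-bit RGB image
    return Image.fromarray(data.astype(np.uint8)[::-1, :]).convert("RGB")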
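
A quick shape check for waveform_from_spectrogram with those parameters (hypothetical snippet, assuming the function is importable from spectro after this commit; note that num_samples is accepted but never used in the body shown, and that max_iter on InverseMelScale assumes the torchaudio 0.13-era API this code targets):

import numpy as np
from spectro import waveform_from_spectrogram  # assumed location after this commit

Sxx = np.random.rand(512, 512).astype(np.float32)  # stand-in 512-mel spectrogram
samples = waveform_from_spectrogram(
    Sxx=Sxx,
    n_fft=17640,
    hop_length=441,
    win_length=4410,
    num_samples=220500,  # accepted but unused by the body above
    sample_rate=44100,
)
print(samples.shape)  # (225351,) == (512 - 1) * 441, about 5.1 s at 44.1 kHz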
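
Finally, the new predict indexes wav[0], which implies the spectro version of wav_bytes_from_spectrogram_image returns a (wav_bytes, duration_s) tuple, matching the T.Tuple[io.BytesIO, float] annotation on the deleted in-file version. Unpacking the tuple makes that explicit (a readability sketch, not part of the commit):

def predict(prompt):
    spec = pipe(prompt).images[0]
    wav_bytes, duration_s = wav_bytes_from_spectrogram_image(spec)  # assumed tuple return
    with open("output.wav", "wb") as f:
        f.write(wav_bytes.getbuffer())
    return "output.wav"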