hungchiayu1 committed on
Commit
df31906
1 Parent(s): 86a3494

Created tango2 pipeline

Files changed (1)
  1. app.py +172 -4
app.py CHANGED
@@ -11,6 +11,165 @@ from pydub import AudioSegment
  from gradio import Markdown
  import spaces
 
+ import torch
+ from diffusers.models.autoencoder_kl import AutoencoderKL
+ from diffusers.models.unet_2d_condition import UNet2DConditionModel
+ from diffusers import DiffusionPipeline, AudioPipelineOutput, DDPMScheduler
+ from transformers import CLIPTextModel, T5EncoderModel, AutoModel, T5Tokenizer, T5TokenizerFast
+ from typing import Union
+ from diffusers.utils.torch_utils import randn_tensor
+ from tqdm import tqdm
+
+
+ class Tango2Pipeline(DiffusionPipeline):
+
+     def __init__(
+         self,
+         vae: AutoencoderKL,
+         text_encoder: T5EncoderModel,
+         tokenizer: Union[T5Tokenizer, T5TokenizerFast],
+         unet: UNet2DConditionModel,
+         scheduler: DDPMScheduler
+     ):
+         super().__init__()
+
+         self.register_modules(vae=vae,
+                               text_encoder=text_encoder,
+                               tokenizer=tokenizer,
+                               unet=unet,
+                               scheduler=scheduler
+                               )
+
+     def _encode_prompt(self, prompt):
+         device = self.text_encoder.device
+
+         batch = self.tokenizer(
+             prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+         )
+         input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+         encoder_hidden_states = self.text_encoder(
+             input_ids=input_ids, attention_mask=attention_mask
+         )[0]
+
+         boolean_encoder_mask = (attention_mask == 1).to(device)
+
+         return encoder_hidden_states, boolean_encoder_mask
+
+     def _encode_text_classifier_free(self, prompt, num_samples_per_prompt):
+         device = self.text_encoder.device
+         batch = self.tokenizer(
+             prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
+         )
+         input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)
+
+         with torch.no_grad():
+             prompt_embeds = self.text_encoder(
+                 input_ids=input_ids, attention_mask=attention_mask
+             )[0]
+
+         prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+         attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+         # get unconditional embeddings for classifier-free guidance
+         uncond_tokens = [""] * len(prompt)
+
+         max_length = prompt_embeds.shape[1]
+         uncond_batch = self.tokenizer(
+             uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt",
+         )
+         uncond_input_ids = uncond_batch.input_ids.to(device)
+         uncond_attention_mask = uncond_batch.attention_mask.to(device)
+
+         with torch.no_grad():
+             negative_prompt_embeds = self.text_encoder(
+                 input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
+             )[0]
+
+         negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+         uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+         # For classifier-free guidance we need two forward passes;
+         # concatenate the unconditional and text embeddings into a single batch so the U-Net runs once.
+         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+         prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
+         boolean_prompt_mask = (prompt_mask == 1).to(device)
+
+         return prompt_embeds, boolean_prompt_mask
+
+     def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device):
+         shape = (batch_size, num_channels_latents, 256, 16)
+         latents = randn_tensor(shape, generator=None, device=device, dtype=dtype)
+         # scale the initial noise by the standard deviation required by the scheduler
+         latents = latents * inference_scheduler.init_noise_sigma
+         return latents
+
+     @torch.no_grad()
+     def inference(self, prompt, inference_scheduler, num_steps=20, guidance_scale=3, num_samples_per_prompt=1,
+                   disable_progress=True):
+         device = self.text_encoder.device
+         classifier_free_guidance = guidance_scale > 1.0
+         batch_size = len(prompt) * num_samples_per_prompt
+
+         if classifier_free_guidance:
+             prompt_embeds, boolean_prompt_mask = self._encode_text_classifier_free(prompt, num_samples_per_prompt)
+         else:
+             prompt_embeds, boolean_prompt_mask = self._encode_prompt(prompt)
+             prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
+             boolean_prompt_mask = boolean_prompt_mask.repeat_interleave(num_samples_per_prompt, 0)
+
+         inference_scheduler.set_timesteps(num_steps, device=device)
+         timesteps = inference_scheduler.timesteps
+
+         num_channels_latents = self.unet.config.in_channels
+         latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, prompt_embeds.dtype, device)
+
+         num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order
+         progress_bar = tqdm(range(num_steps), disable=disable_progress)
+
+         for i, t in enumerate(timesteps):
+             # expand the latents if we are doing classifier-free guidance
+             latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents
+             latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t)
+
+             noise_pred = self.unet(
+                 latent_model_input, t, encoder_hidden_states=prompt_embeds,
+                 encoder_attention_mask=boolean_prompt_mask
+             ).sample
+
+             # perform guidance
+             if classifier_free_guidance:
+                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+             # compute the previous noisy sample x_t -> x_t-1
+             latents = inference_scheduler.step(noise_pred, t, latents).prev_sample
+
+             # advance the progress bar on scheduler-order boundaries
+             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0):
+                 progress_bar.update(1)
+
+         return latents
+
+     @torch.no_grad()
+     def __call__(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
+         """Generate audio for a single prompt string."""
+         with torch.no_grad():
+             latents = self.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
+             mel = self.vae.decode_first_stage(latents)
+             wave = self.vae.decode_to_waveform(mel)
+
+         return AudioPipelineOutput(audios=wave)
+
+
  # Automatic device detection
  if torch.cuda.is_available():
      device_type = "cuda"
@@ -79,13 +238,22 @@ class Tango:
  # Initialize TANGO
 
  tango = Tango(device="cpu")
- tango.vae.to(device_type)
- tango.stft.to(device_type)
- tango.model.to(device_type)
+
+ pipe = Tango2Pipeline(vae=tango.vae,
+                       text_encoder=tango.model.text_encoder,
+                       tokenizer=tango.model.tokenizer,
+                       unet=tango.model.unet,
+                       scheduler=tango.scheduler
+                       )
+ pipe.to(device_type)
+ #tango.vae.to(device_type)
+ #tango.stft.to(device_type)
+ #tango.model.to(device_type)
 
  @spaces.GPU(duration=60)
  def gradio_generate(prompt, output_format, steps, guidance):
-     output_wave = tango.generate(prompt, steps, guidance)
+     output_wave = pipe(prompt, steps, guidance).audios[0]  ## the pipeline automatically uses flash (scaled-dot-product) attention on torch >= 2.0
+     #output_wave = tango.generate(prompt, steps, guidance)
      # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
      output_filename = "temp.wav"
      wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)
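
For reference, a minimal usage sketch of the new pipeline outside the Gradio handler. It reuses the `Tango` wrapper and `Tango2Pipeline` defined in app.py and the 16 kHz output rate the Space already writes; the prompt, output path, and device string are illustrative only and not part of this commit:

```python
# Illustrative sketch only (not part of the commit): wire the TANGO components
# into the new Tango2Pipeline and write one generated clip to disk.
import wavio

tango = Tango(device="cpu")                       # same constructor app.py already uses
pipe = Tango2Pipeline(vae=tango.vae,
                      text_encoder=tango.model.text_encoder,
                      tokenizer=tango.model.tokenizer,
                      unet=tango.model.unet,
                      scheduler=tango.scheduler)
pipe.to("cuda")                                   # or "cpu" when no GPU is available

out = pipe("A dog barks while birds chirp", steps=100, guidance=3)
wavio.write("sample.wav", out.audios[0], rate=16000, sampwidth=2)  # TANGO decodes 16 kHz audio
```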