Fabrice-TIERCELIN commited on
Commit
32614b8
·
verified ·
1 Parent(s): 9e23cf9

Fix indentation

Browse files
Files changed (1) hide show
  1. app.py +17 -4
app.py CHANGED
@@ -9,6 +9,7 @@ from huggingface_hub import snapshot_download
9
  from models import AudioDiffusion, DDPMScheduler
10
  from audioldm.audio.stft import TacotronSTFT
11
  from audioldm.variational_autoencoder import AutoencoderKL
 
12
 
13
  # Automatic device detection
14
  if torch.cuda.is_available():
@@ -55,7 +56,7 @@ class Tango:
55
  def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
56
  # Generate audio for a single prompt string
57
  with torch.no_grad():
58
- latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress, length = 20)
59
  mel = self.vae.decode_first_stage(latents)
60
  wave = self.vae.decode_to_waveform(mel)
61
  return wave
@@ -112,18 +113,30 @@ def text2audio(
112
  start = time.time()
113
  output_wave = tango.generate(prompt, steps, guidance, output_number)
114
 
115
- output_filename_1 = "tmp1_.wav"
116
  wavio.write(output_filename_1, output_wave[0], rate = 16000, sampwidth = 2)
117
 
 
 
 
 
118
  if (2 <= output_number):
119
- output_filename_2 = "tmp2_.wav"
120
  wavio.write(output_filename_2, output_wave[1], rate = 16000, sampwidth = 2)
 
 
 
 
121
  else:
122
  output_filename_2 = None
123
 
124
  if (output_number == 3):
125
- output_filename_3 = "tmp3_.wav"
126
  wavio.write(output_filename_3, output_wave[2], rate = 16000, sampwidth = 2)
 
 
 
 
127
  else:
128
  output_filename_3 = None
129
 
 
9
  from models import AudioDiffusion, DDPMScheduler
10
  from audioldm.audio.stft import TacotronSTFT
11
  from audioldm.variational_autoencoder import AutoencoderKL
12
+ from pydub import AudioSegment
13
 
14
  # Automatic device detection
15
  if torch.cuda.is_available():
 
56
  def generate(self, prompt, steps = 100, guidance = 3, samples = 1, disable_progress = True):
57
  # Generate audio for a single prompt string
58
  with torch.no_grad():
59
+ latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress = disable_progress)
60
  mel = self.vae.decode_first_stage(latents)
61
  wave = self.vae.decode_to_waveform(mel)
62
  return wave
 
113
  start = time.time()
114
  output_wave = tango.generate(prompt, steps, guidance, output_number)
115
 
116
+ output_filename_1 = "tmp1.wav"
117
  wavio.write(output_filename_1, output_wave[0], rate = 16000, sampwidth = 2)
118
 
119
+ if (output_format == "mp3"):
120
+ AudioSegment.from_wav("tmp1.wav").export("tmp1.mp3", format = "mp3")
121
+ output_filename_1 = "tmp1.mp3"
122
+
123
  if (2 <= output_number):
124
+ output_filename_2 = "tmp2.wav"
125
  wavio.write(output_filename_2, output_wave[1], rate = 16000, sampwidth = 2)
126
+
127
+ if (output_format == "mp3"):
128
+ AudioSegment.from_wav("tmp2.wav").export("tmp2.mp3", format = "mp3")
129
+ output_filename_2 = "tmp2.mp3"
130
  else:
131
  output_filename_2 = None
132
 
133
  if (output_number == 3):
134
+ output_filename_3 = "tmp3.wav"
135
  wavio.write(output_filename_3, output_wave[2], rate = 16000, sampwidth = 2)
136
+
137
+ if (output_format == "mp3"):
138
+ AudioSegment.from_wav("tmp3.wav").export("tmp3.mp3", format = "mp3")
139
+ output_filename_3 = "tmp3.mp3"
140
  else:
141
  output_filename_3 = None
142