ylacombe commited on
Commit
8d3d7b9
1 Parent(s): 6ca328f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -6
app.py CHANGED
@@ -13,10 +13,13 @@ device = "cuda:0" if torch.cuda.is_available() else "cpu"
13
 
14
 
15
  repo_id = "parler-tts/parler-tts-mini-v1"
16
- repo_id_large = "ylacombe/parler-large-v1-og"
 
17
 
18
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
19
  model_large = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_large).to(device)
 
 
20
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
21
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
22
 
@@ -76,19 +79,23 @@ def preprocess(text):
76
  return text
77
 
78
  @spaces.GPU
79
- def gen_tts(text, description, use_large=False):
80
  inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
81
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
82
 
83
  set_seed(SEED)
84
- if use_large:
85
  generation = model_large.generate(
86
  input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
87
  )
88
- else:
89
  generation = model.generate(
90
  input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
91
  )
 
 
 
 
92
  audio_arr = generation.cpu().numpy().squeeze()
93
 
94
  return SAMPLE_RATE, audio_arr
@@ -163,12 +170,12 @@ with gr.Blocks(css=css) as block:
163
  with gr.Column():
164
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
165
  description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
166
- use_large = gr.Checkbox(value=False, label="Use Large checkpoint", info="Generate with Parler-TTS Large v1 instead of Mini v1 - Better but way slower.")
167
  run_button = gr.Button("Generate Audio", variant="primary")
168
  with gr.Column():
169
  audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
170
 
171
- inputs = [input_text, description, use_large]
172
  outputs = [audio_out]
173
  run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
174
  gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)
 
13
 
14
 
15
  repo_id = "parler-tts/parler-tts-mini-v1"
16
+ repo_id_large = "parler-tts/parler-tts-large-v1"
17
+ repo_id_tiny = "parler-tts/parler-tts-tiny-v1"
18
 
19
  model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
20
  model_large = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_large).to(device)
21
+ model_tiny = ParlerTTSForConditionalGeneration.from_pretrained(repo_id_tiny).to(device)
22
+
23
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
24
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
25
 
 
79
  return text
80
 
81
  @spaces.GPU
82
+ def gen_tts(text, description, version_to_use=False):
83
  inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
84
  prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)
85
 
86
  set_seed(SEED)
87
+ if version_to_use=="Large":
88
  generation = model_large.generate(
89
  input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
90
  )
91
+ elif version_to_use=="Miny":
92
  generation = model.generate(
93
  input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
94
  )
95
+ else:
96
+ generation = model_tiny.generate(
97
+ input_ids=inputs.input_ids, prompt_input_ids=prompt.input_ids, attention_mask=inputs.attention_mask, prompt_attention_mask=prompt.attention_mask, do_sample=True, temperature=1.0
98
+ )
99
  audio_arr = generation.cpu().numpy().squeeze()
100
 
101
  return SAMPLE_RATE, audio_arr
 
170
  with gr.Column():
171
  input_text = gr.Textbox(label="Input Text", lines=2, value=default_text, elem_id="input_text")
172
  description = gr.Textbox(label="Description", lines=2, value=default_description, elem_id="input_description")
173
+ version_to_use = gr.Radio(["Tiny", "Mini", "Large"], value="Mini", label="Checkpoint to use", info="The larger the model, the better it is, at the cost of speed.")
174
  run_button = gr.Button("Generate Audio", variant="primary")
175
  with gr.Column():
176
  audio_out = gr.Audio(label="Parler-TTS generation", type="numpy", elem_id="audio_out")
177
 
178
+ inputs = [input_text, description, version_to_use]
179
  outputs = [audio_out]
180
  run_button.click(fn=gen_tts, inputs=inputs, outputs=outputs, queue=True)
181
  gr.Examples(examples=examples, fn=gen_tts, inputs=inputs, outputs=outputs, cache_examples=True)