leofltt committed on
Commit
c88b4e1
1 Parent(s): fd6ca3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -16
app.py CHANGED
@@ -8,32 +8,33 @@ from transformers import BarkModel, BarkProcessor
8
 
9
  from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
10
 
11
- SAMPLE_RATE = 16000
12
-
13
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
14
 
15
- # asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
16
- # asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
 
 
17
 
18
- asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
19
 
 
 
20
 
21
- bark_model = BarkModel.from_pretrained("suno/bark")
22
- bark_processor = BarkProcessor.from_pretrained("suno/bark")
23
 
24
 
25
  def translate(audio):
26
- # inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
27
- # generated_ids = asr_model.generate(inputs["input_features"],attention_mask=inputs["attention_mask"],
28
- # forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id["it"],)
29
- # translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
30
- translation = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": "it"})
31
- return translation["text"]
32
 
33
 
34
  def synthesise(text):
35
  inputs = bark_processor(text=text, voice_preset="v2/it_speaker_4",return_tensors="pt")
36
  speech = bark_model.generate(**inputs, do_sample=True)
 
37
  return speech
38
 
39
 
@@ -41,7 +42,7 @@ def speech_to_speech_translation(audio):
41
  translated_text = translate(audio)
42
  synthesised_speech = synthesise(translated_text)
43
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
44
- return SAMPLE_RATE, synthesised_speech
45
 
46
 
47
  title = "Cascaded STST"
@@ -56,7 +57,7 @@ demo = gr.Blocks()
56
 
57
  mic_translate = gr.Interface(
58
  fn=speech_to_speech_translation,
59
- inputs=gr.Audio(source="microphone", type="filepath"),
60
  outputs=gr.Audio(label="Generated Speech", type="numpy"),
61
  title=title,
62
  description=description,
@@ -64,7 +65,7 @@ mic_translate = gr.Interface(
64
 
65
  file_translate = gr.Interface(
66
  fn=speech_to_speech_translation,
67
- inputs=gr.Audio(source="upload", type="filepath"),
68
  outputs=gr.Audio(label="Generated Speech", type="numpy"),
69
  examples=[["./example.wav"]],
70
  title=title,
 
8
 
9
  from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
10
 
 
 
11
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# asr_pipe = pipeline("automatic-speech-recognition", model="facebook/s2t-medium-mustc-multilingual-st", device=device)

# Speech-to-text translation: S2T medium MuST-C multilingual checkpoint.
_S2T_CHECKPOINT = "facebook/s2t-medium-mustc-multilingual-st"
asr_model = Speech2TextForConditionalGeneration.from_pretrained(_S2T_CHECKPOINT)
asr_processor = Speech2TextProcessor.from_pretrained(_S2T_CHECKPOINT)
asr_model.to(device)

# Text-to-speech: the small Bark checkpoint (faster, lighter than full Bark).
_BARK_CHECKPOINT = "suno/bark-small"
bark_model = BarkModel.from_pretrained(_BARK_CHECKPOINT)
bark_processor = BarkProcessor.from_pretrained(_BARK_CHECKPOINT)
bark_model.to(device)
 
24
 
25
 
26
def translate(audio):
    """Translate input speech into Italian text with the S2T model.

    Parameters
    ----------
    audio : raw audio accepted by ``asr_processor`` (assumed 16 kHz mono,
        matching ``sampling_rate=16000`` below — confirm upstream resampling).

    Returns
    -------
    str
        The Italian translation of the spoken input.
    """
    inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
    # Move the input tensors onto the same device as the model;
    # asr_model was moved via .to(device), so CPU tensors would crash on CUDA.
    generated_ids = asr_model.generate(
        inputs["input_features"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id["it"],
    )
    translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
    # batch_decode returns a list of strings; downstream synthesis expects a
    # single text string (the earlier pipeline version returned ["text"]).
    return translation[0]
 
32
 
33
 
34
def synthesise(text):
    """Generate Italian speech for ``text`` with Bark.

    Returns a CPU torch tensor with the batch dimension squeezed out.
    The caller converts it with ``synthesised_speech.numpy()``, so this
    function must NOT return a numpy array itself — ``ndarray`` has no
    ``.numpy()`` method and the original ``.cpu().numpy().squeeze()``
    return value made the caller raise AttributeError.

    NOTE(review): bark_model is moved to ``device`` at load time, but these
    inputs stay on the CPU (the processor output nests a ``history_prompt``
    dict, so a blind ``.to(device)`` may miss tensors) — confirm device
    handling before running on CUDA.
    NOTE(review): Bark natively generates ~24 kHz audio while the app
    returns it labelled as 16 kHz — confirm the intended sample rate.
    """
    inputs = bark_processor(text=text, voice_preset="v2/it_speaker_4", return_tensors="pt")
    speech = bark_model.generate(**inputs, do_sample=True)
    # Bring the waveform back to the CPU and drop the batch dimension,
    # keeping it a tensor so the caller's .numpy() conversion works.
    return speech.cpu().squeeze()
39
 
40
 
 
42
  translated_text = translate(audio)
43
  synthesised_speech = synthesise(translated_text)
44
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
45
+ return 16000, synthesised_speech
46
 
47
 
48
  title = "Cascaded STST"
 
57
 
58
  mic_translate = gr.Interface(
59
  fn=speech_to_speech_translation,
60
+ inputs=gr.Audio(sources="microphone", type="filepath"),
61
  outputs=gr.Audio(label="Generated Speech", type="numpy"),
62
  title=title,
63
  description=description,
 
65
 
66
  file_translate = gr.Interface(
67
  fn=speech_to_speech_translation,
68
+ inputs=gr.Audio(sources="upload", type="filepath"),
69
  outputs=gr.Audio(label="Generated Speech", type="numpy"),
70
  examples=[["./example.wav"]],
71
  title=title,