TieIncred commited on
Commit
05ee4a8
1 Parent(s): ea819b8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -17
app.py CHANGED
@@ -2,19 +2,31 @@ import gradio as gr
2
  import numpy as np
3
  import torch
4
  from datasets import load_dataset
5
-
6
  from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline
 
7
 
8
 
9
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
 
11
  # load speech translation checkpoint
12
- asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
 
 
 
 
 
13
 
14
  # load text-to-speech checkpoint and speaker embeddings
15
- processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
 
 
 
 
 
 
 
16
 
17
- model = SpeechT5ForTextToSpeech.from_pretrained("Matthijs/mms-tts-fra").to(device)
18
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
19
 
20
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
@@ -22,23 +34,40 @@ speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze
22
 
23
 
24
  def translate(audio):
25
- outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": "fr"})
26
- return outputs["text"]
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def synthesise(text):
30
- inputs = processor(text=text, return_tensors="pt")
31
- try:
32
- speech = model.generate_speech(inputs["input_ids"].to(device), speaker_embeddings.to(device))
33
- except Exception as e:
34
- print(f"Error occurred with the following input_ids: {inputs['input_ids']}")
35
- print(traceback.format_exc())
36
- raise e
37
- return speech.cpu()
 
 
 
 
 
 
38
 
39
 
40
  def speech_to_speech_translation(audio):
41
  translated_text = translate(audio)
 
42
  synthesised_speech = synthesise(translated_text)
43
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
44
  return 16000, synthesised_speech
@@ -46,9 +75,8 @@ def speech_to_speech_translation(audio):
46
 
47
  title = "Cascaded STST"
48
  description = """
49
- Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in French. Demo uses OpenAI's [Whisper Base](https://huggingface.co/openai/whisper-base) model for speech translation, and Microsoft's
50
  [SpeechT5 TTS](https://huggingface.co/microsoft/speecht5_tts) model for text-to-speech:
51
-
52
  ![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
53
  """
54
 
@@ -75,4 +103,4 @@ with demo:
75
  gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])
76
 
77
 
78
- demo.launch()
 
2
  import numpy as np
3
  import torch
4
  from datasets import load_dataset
5
+ import librosa
6
  from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor, pipeline
7
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
 
9
 
10
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
 
12
  # load speech translation checkpoint
13
+ # asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
14
+ asr_processor = WhisperProcessor.from_pretrained("openai/whisper-base")
15
+ asr_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").to(device)
16
+ asr_forced_decoder_ids = asr_processor.get_decoder_prompt_ids(language="dutch", task="transcribe")
17
+
18
+
19
 
20
  # load text-to-speech checkpoint and speaker embeddings
21
+ if 0:
22
+ processor = SpeechT5Processor.from_pretrained("sanchit-gandhi/speecht5_tts_vox_nl")
23
+
24
+ model = SpeechT5ForTextToSpeech.from_pretrained("sanchit-gandhi/speecht5_tts_vox_nl").to(device)
25
+ if 1:
26
+ from transformers import VitsModel, VitsTokenizer
27
+ model = VitsModel.from_pretrained("Matthijs/mms-tts-fra")
28
+ tokenizer = VitsTokenizer.from_pretrained("Matthijs/mms-tts-fra")
29
 
 
30
  vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
31
 
32
  embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
 
34
 
35
 
36
  def translate(audio):
37
+ if 0:
38
+ outputs = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"language":"dutch", "task":"transcribe"})
39
+ return outputs["text"]
40
+ else:
41
+
42
+ x, sr = librosa.load(audio)
43
+ input_features = asr_processor(x, sampling_rate=16000, return_tensors="pt").input_features
44
+ predicted_ids = asr_model.generate(input_features, forced_decoder_ids=asr_forced_decoder_ids)
45
+ # decode token ids to text
46
+ transcription = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)
47
+ return transcription
48
+
49
 
50
 
51
  def synthesise(text):
52
+ if 0:
53
+ inputs = processor(text=text, return_tensors="pt")
54
+ speech = model.generate_speech(inputs["input_ids"].to(device), speaker_embeddings.to(device), vocoder=vocoder)
55
+ return speech.cpu()
56
+ if 1:
57
+ inputs = tokenizer(text, return_tensors="pt")
58
+ input_ids = inputs["input_ids"]
59
+
60
+
61
+ with torch.no_grad():
62
+ outputs = model(input_ids)
63
+
64
+ speech = outputs.audio[0]
65
+ return speech.cpu()
66
 
67
 
68
  def speech_to_speech_translation(audio):
69
  translated_text = translate(audio)
70
+ print(translated_text)
71
  synthesised_speech = synthesise(translated_text)
72
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
73
  return 16000, synthesised_speech
 
75
 
76
  title = "Cascaded STST"
77
  description = """
78
+ Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Dutch. Demo uses OpenAI's [Whisper Base](https://huggingface.co/openai/whisper-base) model for speech translation, and Microsoft's
79
  [SpeechT5 TTS](https://huggingface.co/microsoft/speecht5_tts) model for text-to-speech:
 
80
  ![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
81
  """
82
 
 
103
  gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])
104
 
105
 
106
+ demo.launch()