leofltt committed on
Commit
fd6ca3f
1 Parent(s): 9ff0018

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -8,21 +8,27 @@ from transformers import BarkModel, BarkProcessor
8
 
9
  from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
10
 
 
 
11
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
12
 
13
- asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
14
- asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
 
 
 
15
 
16
  bark_model = BarkModel.from_pretrained("suno/bark")
17
  bark_processor = BarkProcessor.from_pretrained("suno/bark")
18
 
19
 
20
  def translate(audio):
21
- inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
22
- generated_ids = asr_model.generate(inputs["input_features"],attention_mask=inputs["attention_mask"],
23
- forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id["it"],)
24
- translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
25
- return translation
 
26
 
27
 
28
  def synthesise(text):
@@ -35,7 +41,7 @@ def speech_to_speech_translation(audio):
35
  translated_text = translate(audio)
36
  synthesised_speech = synthesise(translated_text)
37
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
38
- return 16000, synthesised_speech
39
 
40
 
41
  title = "Cascaded STST"
 
8
 
9
  from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
10
 
11
+ SAMPLE_RATE = 16000
12
+
13
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
14
 
15
+ # asr_model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
16
+ # asr_processor = Speech2TextProcessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
17
+
18
+ asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
19
+
20
 
21
  bark_model = BarkModel.from_pretrained("suno/bark")
22
  bark_processor = BarkProcessor.from_pretrained("suno/bark")
23
 
24
 
25
  def translate(audio):
26
+ # inputs = asr_processor(audio, sampling_rate=16000, return_tensors="pt")
27
+ # generated_ids = asr_model.generate(inputs["input_features"],attention_mask=inputs["attention_mask"],
28
+ # forced_bos_token_id=asr_processor.tokenizer.lang_code_to_id["it"],)
29
+ # translation = asr_processor.batch_decode(generated_ids, skip_special_tokens=True)
30
+ translation = asr_pipe(audio, max_new_tokens=256, generate_kwargs={"task": "transcribe", "language": "it"})
31
+ return translation["text"]
32
 
33
 
34
  def synthesise(text):
 
41
  translated_text = translate(audio)
42
  synthesised_speech = synthesise(translated_text)
43
  synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
44
+ return SAMPLE_RATE, synthesised_speech
45
 
46
 
47
  title = "Cascaded STST"