Spaces:
Runtime error
Runtime error
Raghavan1988
commited on
Commit
·
b67fe1a
1
Parent(s):
f86940b
Adding the predict method from facebook/seamless_m4t
Browse files
app.py
CHANGED
@@ -24,6 +24,54 @@ DEFAULT_TARGET_LANGUAGE = "English"
|
|
24 |
AUDIO_SAMPLE_RATE = 16000.0
|
25 |
MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def process_image_with_openai(image):
|
28 |
image_data = convert_image_to_required_format(image)
|
29 |
openai_api_key = config('OPENAI_API_KEY') # Make sure to have this in your .env file
|
|
|
24 |
AUDIO_SAMPLE_RATE = 16000.0
|
25 |
MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
|
26 |
|
27 |
+
|
28 |
+
def predict(
|
29 |
+
task_name: str,
|
30 |
+
audio_source: str,
|
31 |
+
input_audio_mic: str | None,
|
32 |
+
input_audio_file: str | None,
|
33 |
+
input_text: str | None,
|
34 |
+
source_language: str | None,
|
35 |
+
target_language: str,
|
36 |
+
) -> tuple[tuple[int, np.ndarray] | None, str]:
|
37 |
+
task_name = task_name.split()[0]
|
38 |
+
source_language_code = LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
|
39 |
+
target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
|
40 |
+
|
41 |
+
if task_name in ["S2ST", "S2TT", "ASR"]:
|
42 |
+
if audio_source == "microphone":
|
43 |
+
input_data = input_audio_mic
|
44 |
+
else:
|
45 |
+
input_data = input_audio_file
|
46 |
+
|
47 |
+
arr, org_sr = torchaudio.load(input_data)
|
48 |
+
new_arr = torchaudio.functional.resample(arr, orig_freq=org_sr, new_freq=AUDIO_SAMPLE_RATE)
|
49 |
+
max_length = int(MAX_INPUT_AUDIO_LENGTH * AUDIO_SAMPLE_RATE)
|
50 |
+
if new_arr.shape[1] > max_length:
|
51 |
+
new_arr = new_arr[:, :max_length]
|
52 |
+
gr.Warning(f"Input audio is too long. Only the first {MAX_INPUT_AUDIO_LENGTH} seconds is used.")
|
53 |
+
|
54 |
+
|
55 |
+
input_data = processor(audios = new_arr, sampling_rate=AUDIO_SAMPLE_RATE, return_tensors="pt").to(device)
|
56 |
+
else:
|
57 |
+
input_data = processor(text = input_text, src_lang=source_language_code, return_tensors="pt").to(device)
|
58 |
+
|
59 |
+
|
60 |
+
if task_name in ["S2TT", "T2TT"]:
|
61 |
+
tokens_ids = model.generate(**input_data, generate_speech=False, tgt_lang=target_language_code, num_beams=5, do_sample=True)[0].cpu().squeeze().detach().tolist()
|
62 |
+
else:
|
63 |
+
output = model.generate(**input_data, return_intermediate_token_ids=True, tgt_lang=target_language_code, num_beams=5, do_sample=True, spkr_id=LANG_TO_SPKR_ID[target_language_code][0])
|
64 |
+
|
65 |
+
waveform = output.waveform.cpu().squeeze().detach().numpy()
|
66 |
+
tokens_ids = output.sequences.cpu().squeeze().detach().tolist()
|
67 |
+
|
68 |
+
text_out = processor.decode(tokens_ids, skip_special_tokens=True)
|
69 |
+
|
70 |
+
if task_name in ["S2ST", "T2ST"]:
|
71 |
+
return (AUDIO_SAMPLE_RATE, waveform), text_out
|
72 |
+
else:
|
73 |
+
return None, text_out
|
74 |
+
|
75 |
def process_image_with_openai(image):
|
76 |
image_data = convert_image_to_required_format(image)
|
77 |
openai_api_key = config('OPENAI_API_KEY') # Make sure to have this in your .env file
|