# Spaces: Running
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model
class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.

    Handles both text-to-speech and speech-to-speech translation, and can
    write the generated waveform to a WAV file.

    Instance attributes (all assigned in ``__init__``):
        processor: Result of ``AutoProcessor.from_pretrained(model_name)``.
        model: Result of ``SeamlessM4Tv2Model.from_pretrained(model_name)``.
        sample_rate: ``model.config.sampling_rate``; the rate written into
            WAV files by :meth:`save_audio`.
    """

    # Rate (Hz) that input audio is resampled to before it is handed to the
    # processor in translate_audio.
    INPUT_SAMPLE_RATE = 16_000

    def __init__(self, model_name: str = "facebook/seamless-m4t-v2-large"):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name of the model to use

        Raises:
            RuntimeError: If the processor or model cannot be loaded. The
                underlying exception is attached as ``__cause__``.
        """
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            # Chain explicitly so the original traceback stays reachable.
            raise RuntimeError(f"Failed to initialize model: {e}") from e

    def _generate_speech(self, inputs, tgt_lang: str) -> numpy.ndarray:
        """Run ``model.generate`` on processed inputs and return a NumPy array."""
        generated = self.model.generate(**inputs, tgt_lang=tgt_lang)
        # Only the first element of generate()'s return value is used; it is
        # moved to the CPU, converted to NumPy, and stripped of size-1 axes.
        return generated[0].cpu().numpy().squeeze()

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate
            src_lang (str): Source language code (e.g., 'eng')
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If processing or generation fails. The underlying
                exception is attached as ``__cause__``.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {e}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        Args:
            audio_path (str or Path): Path to input audio file
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If loading, resampling, processing or generation
                fails. The underlying exception is attached as ``__cause__``.
        """
        try:
            # Load the file, then bring it to the fixed input rate.
            audio, orig_freq = torchaudio.load(audio_path)
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self.INPUT_SAMPLE_RATE,
            )
            # Process and generate translation
            inputs = self.processor(audios=audio, return_tensors="pt")
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {e}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file.

        The file is written at ``self.sample_rate``.

        Args:
            audio_array (numpy.ndarray): Audio data to save
            output_path (str or Path): Path where to save the WAV file

        Raises:
            RuntimeError: If the file cannot be written. The underlying
                exception is attached as ``__cause__``.
        """
        try:
            scipy.io.wavfile.write(
                output_path,
                rate=self.sample_rate,
                data=audio_array,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {e}") from e
def main():
    """Run a small demo of SeamlessTranslator on one sentence and one file.

    A fixed English sentence and the file ``input_audio.wav`` are each
    translated into Russian speech and written to ``output_from_text.wav``
    and ``output_from_audio.wav`` respectively. Any exception raised along
    the way is caught and reported with ``print`` rather than propagated.
    """
    target_language = "rus"
    try:
        engine = SeamlessTranslator()

        # Text in, speech out.
        speech_from_text = engine.translate_text(
            text="Hello, my dog is cute",
            src_lang="eng",
            tgt_lang=target_language,
        )
        engine.save_audio(speech_from_text, "output_from_text.wav")

        # Speech in, speech out.
        speech_from_audio = engine.translate_audio(
            audio_path="input_audio.wav",
            tgt_lang=target_language,
        )
        engine.save_audio(speech_from_audio, "output_from_audio.wav")
    except Exception as err:
        print(f"Translation failed: {str(err)}")
# Script entry point: run the example only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()