Spaces:
Running
Running
File size: 4,063 Bytes
23dd469 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model
class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.

    Handles both text-to-speech and speech-to-speech translation.

    Attributes:
        processor: Object returned by ``AutoProcessor.from_pretrained``.
        model: Object returned by ``SeamlessM4Tv2Model.from_pretrained``.
        sample_rate: ``model.config.sampling_rate``; used as the sample rate
            of the WAV files written by :meth:`save_audio`.
    """

    # Rate (Hz) that input audio is resampled to before it is handed to the
    # processor in ``translate_audio``.
    INPUT_SAMPLE_RATE = 16_000

    def __init__(self, model_name: str = "facebook/seamless-m4t-v2-large"):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name of the model to use

        Raises:
            RuntimeError: If the processor or the model cannot be loaded.
                The underlying error is attached as ``__cause__``.
        """
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            # Chain the original exception so its traceback stays visible.
            raise RuntimeError(f"Failed to initialize model: {e}") from e

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate
            src_lang (str): Source language code (e.g., 'eng')
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If processing or generation fails. The underlying
                error is attached as ``__cause__``.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            # The first element of the generate() output is used as the
            # waveform; squeeze() drops the size-1 dimensions.
            generated = self.model.generate(**inputs, tgt_lang=tgt_lang)
            return generated[0].cpu().numpy().squeeze()
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {e}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        Args:
            audio_path (str or Path): Path to input audio file
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If loading, processing or generation fails. The
                underlying error is attached as ``__cause__``.
        """
        try:
            # Load and resample audio
            audio, orig_freq = torchaudio.load(str(audio_path))

            # torchaudio returns (channels, frames). Per the Hugging Face
            # feature-extractor docs a 2-D input is read as a batch of mono
            # clips, so a stereo file would become two separate utterances.
            # Downmix to a single channel to keep one clip per file.
            if audio.dim() == 2 and audio.size(0) > 1:
                audio = audio.mean(dim=0, keepdim=True)

            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self.INPUT_SAMPLE_RATE,
            )

            # Process and generate translation. Passing sampling_rate lets the
            # feature extractor check it instead of silently assuming one.
            inputs = self.processor(
                audios=audio,
                sampling_rate=self.INPUT_SAMPLE_RATE,
                return_tensors="pt",
            )
            generated = self.model.generate(**inputs, tgt_lang=tgt_lang)
            return generated[0].cpu().numpy().squeeze()
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {e}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file.

        The file is written with ``self.sample_rate`` as its sample rate.

        Args:
            audio_array (numpy.ndarray): Audio data to save
            output_path (str or Path): Path where to save the WAV file

        Raises:
            RuntimeError: If the file cannot be written. The underlying
                error is attached as ``__cause__``.
        """
        try:
            scipy.io.wavfile.write(
                output_path,
                rate=self.sample_rate,
                data=audio_array,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {e}") from e
def main():
    """
    Demonstrate the SeamlessTranslator class end to end.

    Builds a translator with its default model, translates a fixed English
    sentence to Russian speech and writes it to ``output_from_text.wav``,
    then translates ``input_audio.wav`` to Russian speech and writes it to
    ``output_from_audio.wav``. Any exception raised along the way is caught
    and reported on standard output; nothing is re-raised.
    """
    try:
        engine = SeamlessTranslator()

        # Text -> speech
        speech_from_text = engine.translate_text(
            text="Hello, my dog is cute", src_lang="eng", tgt_lang="rus"
        )
        engine.save_audio(speech_from_text, "output_from_text.wav")

        # Speech -> speech
        speech_from_audio = engine.translate_audio(
            audio_path="input_audio.wav", tgt_lang="rus"
        )
        engine.save_audio(speech_from_audio, "output_from_audio.wav")
    except Exception as err:
        print(f"Translation failed: {str(err)}")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()