# File size: 4,063 Bytes
# 23dd469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model

class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.
    Handles both text-to-speech and speech-to-speech translation.

    Attributes:
        processor: Processor loaded with ``AutoProcessor.from_pretrained``.
        model: Model loaded with ``SeamlessM4Tv2Model.from_pretrained``.
        sample_rate: ``model.config.sampling_rate``; used as the rate when
            ``save_audio`` writes a WAV file.
    """

    # Rate (Hz) that input audio is resampled to before it reaches the processor.
    INPUT_SAMPLE_RATE = 16_000

    def __init__(self, model_name: str = "facebook/seamless-m4t-v2-large"):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name of the model to use

        Raises:
            RuntimeError: If the processor or model cannot be loaded. The
                original exception is attached as ``__cause__``.
        """
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {e}") from e

    def _generate_speech(self, inputs, tgt_lang: str) -> numpy.ndarray:
        """
        Run generation on already-processed inputs and return the waveform.

        Args:
            inputs: Mapping returned by ``self.processor``; unpacked as
                keyword arguments to ``self.model.generate``.
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: First element of the generation output, moved to
            the CPU, converted to NumPy, with size-1 dimensions removed.
        """
        generated = self.model.generate(**inputs, tgt_lang=tgt_lang)
        return generated[0].cpu().numpy().squeeze()

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate
            src_lang (str): Source language code (e.g., 'eng')
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If processing or generation fails. The original
                exception is attached as ``__cause__``.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {e}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        Args:
            audio_path (str or Path): Path to input audio file
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If loading, resampling, processing or generation
                fails. The original exception is attached as ``__cause__``.
        """
        try:
            # Load the file and bring it to the rate the processor is given.
            audio, orig_freq = torchaudio.load(audio_path)
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self.INPUT_SAMPLE_RATE,
            )

            # State the rate explicitly so the processor does not have to
            # assume what rate the waveform is at.
            inputs = self.processor(
                audios=audio,
                sampling_rate=self.INPUT_SAMPLE_RATE,
                return_tensors="pt",
            )
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {e}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file.

        The file is written at ``self.sample_rate``.

        Args:
            audio_array (numpy.ndarray): Audio data to save
            output_path (str or Path): Path where to save the WAV file

        Raises:
            RuntimeError: If the file cannot be written. The original
                exception is attached as ``__cause__``.
        """
        try:
            scipy.io.wavfile.write(
                output_path,
                rate=self.sample_rate,
                data=audio_array,
            )
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {e}") from e

def main():
    """
    Demonstrate SeamlessTranslator end to end.

    Builds a translator, synthesizes Russian speech from an English sentence
    and from the file ``input_audio.wav``, and writes each result to a WAV
    file. Any exception raised along the way is caught and reported on
    standard output instead of propagating.
    """
    try:
        translator = SeamlessTranslator()

        # Text in, speech out.
        speech_from_text = translator.translate_text(
            text="Hello, my dog is cute", src_lang="eng", tgt_lang="rus"
        )
        translator.save_audio(speech_from_text, "output_from_text.wav")

        # Speech in, speech out.
        speech_from_audio = translator.translate_audio(
            audio_path="input_audio.wav", tgt_lang="rus"
        )
        translator.save_audio(speech_from_audio, "output_from_audio.wav")
    except Exception as exc:
        message = str(exc)
        print(f"Translation failed: {message}")

# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()