|
|
|
import whisper |
|
from transformers import WhisperProcessor, WhisperForConditionalGeneration |
|
from googletrans import Translator |
|
import torch |
|
|
|
def load_models():
    """Load every model the speech pipeline needs.

    The same openai-whisper "small" checkpoint serves both as the language
    detector and as the fallback (non-Tamil/Sinhala) transcriber.  It is
    loaded once and returned in both slots rather than loaded twice, which
    halves its load time and memory footprint.

    Returns:
        A 6-tuple ``(lang_detector, tamil_processor, tamil_model,
        sinhala_processor, sinhala_model, english_model)``.  ``lang_detector``
        and ``english_model`` are the same object.
    """
    whisper_small = whisper.load_model("small")

    tamil_checkpoint = "Lingalingeswaran/whisper-small-ta"
    tamil_processor = WhisperProcessor.from_pretrained(tamil_checkpoint)
    tamil_model = WhisperForConditionalGeneration.from_pretrained(tamil_checkpoint)

    sinhala_checkpoint = "Lingalingeswaran/whisper-small-sinhala"
    sinhala_processor = WhisperProcessor.from_pretrained(sinhala_checkpoint)
    sinhala_model = WhisperForConditionalGeneration.from_pretrained(sinhala_checkpoint)

    # Same model object in the detector slot and the English-transcriber slot.
    return whisper_small, tamil_processor, tamil_model, sinhala_processor, sinhala_model, whisper_small
|
|
|
def detect_language(audio_file, lang_detector):
    """Return the language code the Whisper model rates most likely.

    The waveform is padded or trimmed to Whisper's fixed input window before
    the log-Mel spectrogram is computed.  Only that window is examined.

    Args:
        audio_file: Path to an audio file readable by ``whisper.load_audio``.
        lang_detector: An openai-whisper model exposing ``detect_language``
            and ``device``.

    Returns:
        The key with the highest probability in the mapping returned by
        ``lang_detector.detect_language`` (for example ``"ta"``).
    """
    clip = whisper.pad_or_trim(whisper.load_audio(audio_file))
    spectrogram = whisper.log_mel_spectrogram(clip)
    spectrogram = spectrogram.to(lang_detector.device)
    language_probs = lang_detector.detect_language(spectrogram)[1]
    # On ties, max() keeps the first key in iteration order.
    best_lang, _ = max(language_probs.items(), key=lambda item: item[1])
    return best_lang
|
|
|
def transcribe_audio(audio_file, detected_lang, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model):
    """Transcribe *audio_file* with the model that matches *detected_lang*.

    Routing:
      - Tamil (``"ta"``) and Sinhala (``"si"``) use the fine-tuned Hugging
        Face Whisper processor/model pairs.
      - Any other language falls back to ``english_model.transcribe``.

    The Hugging Face feature extractor only keeps a single 30-second window
    per input.  Passing the whole file at once would silently drop everything
    after the first 30 seconds.  The Tamil/Sinhala path therefore splits the
    waveform into consecutive 30-second chunks, transcribes each, and joins
    the pieces with single spaces.  A word that straddles a chunk boundary
    may be split.

    Returns:
        The transcription text.  A recording that fits in one chunk yields
        exactly the decoder's output for that chunk.
    """
    hf_models = {
        "ta": (tamil_processor, tamil_model),
        "si": (sinhala_processor, sinhala_model),
    }
    if detected_lang not in hf_models:
        return english_model.transcribe(audio_file)["text"]

    processor, model = hf_models[detected_lang]
    sampling_rate = 16000  # rate passed to the processor; matches whisper.load_audio's output
    window = 30 * sampling_rate  # samples per 30-second Whisper input window

    audio = whisper.load_audio(audio_file)
    # `or [audio]` keeps an empty recording on the original single-call path.
    chunks = [audio[start:start + window] for start in range(0, len(audio), window)] or [audio]

    pieces = []
    for chunk in chunks:
        inputs = processor(chunk, return_tensors="pt", sampling_rate=sampling_rate)
        with torch.no_grad():
            predicted_ids = model.generate(**inputs)
        pieces.append(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

    if len(pieces) == 1:
        return pieces[0]
    return " ".join(piece.strip() for piece in pieces if piece.strip())
|
|
|
def translate_to_english(text, src="auto"):
    """Translate *text* into English with googletrans.

    Args:
        text: Text to translate.
        src: Source-language code forwarded to googletrans.  ``"auto"`` (the
            default) lets the service detect the language, as before.

    Returns:
        The English translation.  Empty, whitespace-only, or ``None`` input is
        returned unchanged without contacting the translation service.
        Translating it would be meaningless and risks a network error.
    """
    if not text or not text.strip():
        return text
    return Translator().translate(text, dest="en", src=src).text
|
|
|
def full_pipeline(audio_file, models=None):
    """Detect the spoken language of *audio_file*, transcribe it, and translate the text to English.

    Args:
        audio_file: Path to the audio file to process.
        models: Optional 6-tuple in the order returned by ``load_models()``.
            Pass it when processing several files so the models are loaded
            once and reused.  When omitted, the models are loaded on this
            call, matching the original behaviour.

    Returns:
        The English translation of the transcription.
    """
    if models is None:
        models = load_models()
    lang_detector, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model = models

    detected_lang = detect_language(audio_file, lang_detector)
    transcription = transcribe_audio(
        audio_file,
        detected_lang,
        tamil_processor,
        tamil_model,
        sinhala_processor,
        sinhala_model,
        english_model,
    )
    return translate_to_english(transcription)
|
|