FrexG's picture
Create asr.py
b82438a
raw
history blame
No virus
896 Bytes
import logging

import torch
import torchaudio
import torchaudio.functional as AF
from transformers import Wav2Vec2ForCTC, AutoProcessor
from pydub import AudioSegment
from pydub.silence import split_on_silence
class Transcribe:
    """Speech-to-text wrapper around the multilingual MMS CTC model.

    The ``facebook/mms-1b-fl102`` model and its processor are loaded once at
    construction. Each call switches the language adapter and tokenizer
    vocabulary when the requested language changes, runs the model, and
    greedily (argmax) decodes the first element of the batch.

    Args:
        freq: Sample rate stored on the instance. NOTE(review): nothing in
            this class reads it; presumably callers resample audio to this
            rate before calling. Confirm.
    """

    def __init__(self, freq: float = 16000.0) -> None:
        self.freq = freq
        self.model_id = "facebook/mms-1b-fl102"
        self.processor = AutoProcessor.from_pretrained(self.model_id)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_id)
        # Language whose adapter/vocabulary is currently active. None forces
        # the first call to switch explicitly.
        self._current_lang = None

    @torch.inference_mode()
    def __call__(self, audio_tensor: torch.Tensor, lang: str = "amh") -> str:
        """Transcribe ``audio_tensor`` using the adapter for ``lang``.

        Args:
            audio_tensor: Waveform batch of shape ``(batch, samples)``, or a
                single ``(samples,)`` waveform, which is treated as a batch of
                one. Assumed to already be at the model's expected sample
                rate (TODO confirm with callers).
            lang: Language code passed to the tokenizer's
                ``set_target_lang`` and the model's ``load_adapter``.

        Returns:
            The greedy CTC transcription of the first batch element.
        """
        logging.getLogger(__name__).debug("Transcribing with lang=%s", lang)
        # load_adapter swaps adapter weights, so only switch when the
        # requested language differs from the one already active.
        # getattr keeps instances created without __init__ working.
        if lang != getattr(self, "_current_lang", None):
            # Invalidate first: if either step below raises, the next call
            # redoes the full switch instead of trusting a half-switched
            # tokenizer/adapter pair.
            self._current_lang = None
            self.processor.tokenizer.set_target_lang(lang)
            self.model.load_adapter(lang)
            self._current_lang = lang
        if audio_tensor.dim() == 1:
            # The model expects a leading batch dimension.
            audio_tensor = audio_tensor.unsqueeze(0)
        # NOTE(review): the tensor goes straight to the model, bypassing the
        # processor's feature extraction (which may normalize the waveform).
        # Confirm callers normalize upstream.
        logits = self.model(audio_tensor).logits
        # Argmax over the vocabulary axis, then keep the first batch item.
        ids = torch.argmax(logits, dim=-1)[0]
        return self.processor.decode(ids)