voiced / pipeline.py
mostafaashahin's picture
Update pipeline.py
d914b65
raw
history blame contribute delete
No virus
3.53 kB
from typing import Dict
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
# Phonological feature groups.
# Every feature contributes one group holding its "present" token (p_<name>)
# followed by its "absent" token (n_<name>).  When editing this table, make
# sure that all phonemes are still covered in each group.
_FEATURE_NAMES = (
    'alveolar', 'palatal', 'dental', 'glottal', 'labial',
    'velar', 'anterior', 'posterior', 'retroflex', 'mid',
    'high_v', 'low', 'front', 'back', 'central',
    'consonant', 'sonorant', 'long', 'short', 'vowel',
    'semivowel', 'fricative', 'nasal', 'stop', 'approximant',
    'affricate', 'liquid', 'continuant', 'monophthong', 'diphthong',
    'round', 'voiced', 'bilabial', 'coronal', 'dorsal',
)

# One [present, absent] pair per feature, in the order listed above.
groups = [['p_' + feature, 'n_' + feature] for feature in _FEATURE_NAMES]

# Individual names g1..g35 refer to the very same list objects held in `groups`.
(
    g1, g2, g3, g4, g5, g6, g7,
    g8, g9, g10, g11, g12, g13, g14,
    g15, g16, g17, g18, g19, g20, g21,
    g22, g23, g24, g25, g26, g27, g28,
    g29, g30, g31, g32, g33, g34, g35,
) = groups
class PreTrainedPipeline():
    """Inference pipeline that decodes the voicing feature from raw audio.

    A Wav2Vec2 CTC model is run on the waveform, the per-frame prediction is
    restricted to the tokens of the voicing group plus token id 0, and the
    decoded token sequence is returned with ``p_``/``n_`` rewritten as
    ``+``/``-``.
    """

    # Position of the ['p_voiced', 'n_voiced'] pair in the module-level `groups` list.
    VOICED_GROUP = 31
    # Token id that always stays a candidate next to the group's own tokens.
    # NOTE(review): presumably the CTC blank / padding token -- confirm against the vocabulary.
    BLANK_ID = 0

    def __init__(self, path=""):
        """Load everything needed at inference time (called only once).

        Args:
            path: Location handed to ``from_pretrained`` for both the
                processor and the model.
        """
        print('Init')
        self.sampling_rate = 16000
        self.processor = Wav2Vec2Processor.from_pretrained(path)
        self.model = Wav2Vec2ForCTC.from_pretrained(path)

        # Choose the device once and move the model here instead of on every call.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        tokenizer = self.processor.tokenizer
        sorted_ids = [sorted(tokenizer.convert_tokens_to_ids(group)) for group in groups]
        # For each group: 1-based position within the sorted ids -> vocabulary id.
        # This is the inversion of the mapping used in training, as here we
        # need to map predictions back to the original tokens.
        self.group_ids = [
            {position: token_id for position, token_id in enumerate(ids, start=1)}
            for ids in sorted_ids
        ]

    def __call__(self, inputs: np.ndarray) -> Dict[str, str]:
        """
        Args:
            inputs (:obj:`np.ndarray`):
                The raw waveform of audio received. By default at 16KHz.
        Return:
            A :obj:`dict`:. The object returned looks like {"text": "XXX"},
            where the text is the decoded voicing token sequence with the
            ``p_`` prefix shown as ``+`` and the ``n_`` prefix shown as ``-``.
        """
        features = self.processor(
            audio=inputs, sampling_rate=self.sampling_rate, return_tensors="pt"
        )
        input_values = features.input_values.to(self.device)

        with torch.no_grad():
            logits = self.model(input_values).logits

        # Candidate vocabulary ids in ascending order: BLANK_ID plus the voicing tokens.
        voiced_ids = self.group_ids[self.VOICED_GROUP].values()
        kept_ids = torch.tensor(
            sorted({self.BLANK_ID, *voiced_ids}),
            dtype=torch.long,
            device=logits.device,
        )

        # Arg-max over the candidates only, then translate the compact
        # candidate index back to the real vocabulary id.
        best_candidate = torch.argmax(logits[:, :, kept_ids], dim=-1)
        pred_ids = kept_ids[best_candidate].cpu()

        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
        pred = pred.replace('p_', '+').replace('n_', '-')
        return {"text": pred}