import fire
import logging
import os
import json
import torch
import librosa
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
import transformers
import pandas as pd

# Set up logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.ERROR)
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                              datefmt="%m/%d/%Y %H:%M:%S")
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.ERROR)
logger.addHandler(console_handler)

class transcribe_SA():

    def __init__(self, model_path, verbose=0):
        if verbose == 0:
            logger.setLevel(logging.ERROR)
            transformers.logging.set_verbosity_error()
        elif verbose == 1:
            logger.setLevel(logging.WARNING)
            transformers.logging.set_verbosity_warning()
        else:
            logger.setLevel(logging.INFO)
            transformers.logging.set_verbosity_info()
        logger.info('Init Object')
        if torch.cuda.is_available():
            self.accelerate = True
            self.device = torch.device('cuda')
            self.n_devices = torch.cuda.device_count()
            assert self.n_devices == 1, 'Only a single GPU is supported. Set CUDA_VISIBLE_DEVICES=<gpu_index> if you have multiple GPUs.'
        else:
            self.accelerate = False
            self.device = torch.device('cpu')
            self.n_devices = 1
        self.model_path = model_path
        self.load_model()
        self.get_available_attributes()
        self.get_att_binary_group_indexes()

    def load_model(self):
        if not os.path.exists(self.model_path):
            logger.error(f'Model path {self.model_path} does not exist')
            raise FileNotFoundError(self.model_path)
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_path)
        self.model.to(self.device)
        self.model.eval()
        self.pad_token_id = self.processor.tokenizer.pad_token_id
        self.sampling_rate = self.processor.feature_extractor.sampling_rate

    def get_available_attributes(self):
        if not hasattr(self, 'model'):
            logger.error('Model not loaded; call load_model first!')
            raise AttributeError('model not defined')
        # Attribute tokens come in p_<att>/n_<att> pairs; strip the p_ prefix to get the attribute names
        att_list = set(self.processor.tokenizer.get_vocab().keys()) - set(self.processor.tokenizer.all_special_tokens)
        att_list = [p.replace('p_', '') for p in att_list if p.startswith('p_')]
        self.att_list = att_list

    def print_available_attributes(self):
        print(self.att_list)

    def get_att_binary_group_indexes(self):
        # Each group holds the token ids of [<pad>, n_att, p_att], sorted by token id
        self.group_ids = []
        for att in self.att_list:
            n_indx = self.processor.tokenizer.convert_tokens_to_ids(f'n_{att}')
            p_indx = self.processor.tokenizer.convert_tokens_to_ids(f'p_{att}')
            self.group_ids.append(sorted([self.pad_token_id, n_indx, p_indx]))
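        # Illustrative example (hypothetical token ids): with pad=0, n_voiced=7 and
        # p_voiced=8, the group for 'voiced' is [0, 7, 8]. Boolean masking preserves
        # this order, so decode_att can map an argmax position back via group_ids[i][x].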

    def decode_att(self, logits, att):
        # Attribute names are lowercased in transcribe() before reaching this method
        mask = torch.zeros(logits.size(2), dtype=torch.bool)
        try:
            i = self.att_list.index(att)
        except ValueError:
            logger.error(f'The given attribute {att} is not supported by the given model {self.model_path}')
            raise
        mask[self.group_ids[i]] = True
        logits_g = logits[:, :, mask]
        pred_ids = torch.argmax(logits_g, dim=-1)
        pred_ids = pred_ids.cpu().apply_(lambda x: self.group_ids[i][x])
        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0].split()
        return list(map(lambda x: {f'p_{att}': '+', f'n_{att}': '-'}[x], pred))
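        # Example return value (hypothetical): ['+', '-', '+'] means the attribute is
        # detected in the first and third decoded segments but not in the second.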

    def read_audio_file(self, audio_file):
        if not os.path.exists(audio_file):
            logger.error(f'Audio file {audio_file} does not exist')
            raise FileNotFoundError(audio_file)
        y, _ = librosa.load(audio_file, sr=self.sampling_rate)
        return y

    def get_logits(self, y):
        input_values = self.processor(audio=y, sampling_rate=self.sampling_rate, return_tensors="pt").input_values
        with torch.no_grad():
            logits = self.model(input_values.to(self.device)).logits
        # Downstream decoding (masks, phonological matrix) runs on the CPU
        return logits.cpu()

    def check_identical_phonemes(self, df_p2att):
        identical_phonemes = []
        for index, row in df_p2att.iterrows():
            mask = df_p2att.eq(row).all(axis=1)
            indexes = df_p2att[mask].index.values
            if len(indexes) > 1:
                identical_phonemes.append(tuple(indexes))
        if identical_phonemes:
            logger.warning('The following phonemes have identical phonological features given the features used in the model. If a fixed-weight layer is used, these phonemes will be confused with each other')
            for x in set(identical_phonemes):
                logger.warning(','.join(x))

    def read_phoneme2att(self, p2att_file):
        if not os.path.exists(p2att_file):
            logger.error(f'Phonological matrix file {p2att_file} does not exist')
            raise FileNotFoundError(f'{p2att_file}')
        df_p2att = pd.read_csv(p2att_file, index_col=0)
        self.check_identical_phonemes(df_p2att)
        not_supported = set(df_p2att.columns) - set(self.att_list)
        if not_supported:
            logger.warning(f"Attribute/s {','.join(not_supported)} are not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model, run: transcribe --model_path=/path/to/model print_available_attributes")
            df_p2att = df_p2att.drop(columns=not_supported)
        self.phoneme_list = df_p2att.index.values
        self.p2att_map = {}
        for phoneme, r in df_p2att.iterrows():
            self.p2att_map[phoneme] = []
            for att in r.index.values:
                if f'p_{att}' not in self.processor.tokenizer.vocab:
                    logger.warning(f'Attribute {att} is not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model, run: transcribe --model_path=/path/to/model print_available_attributes')
                    continue
                value = r[att]
                if value == 0:
                    self.p2att_map[phoneme].append(f'n_{att}')
                elif value == 1:
                    self.p2att_map[phoneme].append(f'p_{att}')
                else:
                    logger.error(f'Invalid value {value} for attribute {att} of phoneme {phoneme}. Values in the phoneme-to-attribute map must be either 0 or 1')
                    raise ValueError(f'{value} should be 0 or 1')
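    # Expected CSV layout for p2att_file (hypothetical symbols): the index column
    # holds phoneme symbols and every other column is a binary attribute, e.g.
    #   ,voiced,nasal,fricative
    #   b,1,0,0
    #   m,1,1,0
    #   s,0,0,1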

    def create_phoneme_tokenizer(self):
        vocab_list = self.phoneme_list
        vocab_dict = {v: k + 1 for k, v in enumerate(vocab_list)}
        vocab_dict['<pad>'] = 0
        vocab_dict = dict(sorted(vocab_dict.items(), key=lambda x: x[1]))
        vocab_file = 'phoneme_vocab.json'
        with open(vocab_file, 'w') as f:
            json.dump(vocab_dict, f)
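        # Resulting file contents (hypothetical symbols): {"<pad>": 0, "b": 1, "m": 2, ...},
        # with ids matching the rows of the phonological matrix built below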
        # Build a CTC tokenizer over the phoneme vocabulary
        self.phoneme_tokenizer = Wav2Vec2CTCTokenizer(vocab_file, pad_token="<pad>", word_delimiter_token="")

    def create_phonological_matrix(self):
        self.phonological_matrix = torch.zeros((self.phoneme_tokenizer.vocab_size, self.processor.tokenizer.vocab_size), dtype=torch.float)
        self.phonological_matrix[self.phoneme_tokenizer.pad_token_id, self.processor.tokenizer.pad_token_id] = 1
        for p in self.phoneme_list:
            for att in self.p2att_map[p]:
                self.phonological_matrix[self.phoneme_tokenizer.convert_tokens_to_ids(p), self.processor.tokenizer.convert_tokens_to_ids(att)] = 1
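        # Illustrative: if phoneme 'b' maps to ['p_voiced', 'n_nasal'] (hypothetical),
        # its row holds 1s in exactly those token columns, so the matmul in
        # decode_phoneme sums that phoneme's attribute log-probabilities per frame.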

    # Converts the attribute logits from the output layer into a phoneme sequence.
    # Input is a sequence of logit vectors (one per frame); output is a phoneme string.
    # Note that decoding is CTC, so the number of output phonemes is not equal to the number of input frames.
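    # For example (hypothetical), 200 input frames might decode to the 4-phoneme
    # string 'b aa m s' after CTC collapsing of repeated tokens and pads.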
    def decode_phoneme(self, logits):

        def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
            if mask is not None:
                mask = mask.float()
                while mask.dim() < vector.dim():
                    mask = mask.unsqueeze(1)
                # vector + mask.log() is an easy way to zero out masked elements in log space, but it
                # results in NaNs when the whole vector is masked. We need a very small value instead of
                # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
                # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
                # becomes 0 - this is just the smallest value we can actually use.
                vector = vector + (mask + 1e-45).log()
            return torch.nn.functional.log_softmax(vector, dim=dim)

        log_probs_all_masked = []
        for i in range(len(self.att_list)):
            # Restrict the softmax to this attribute's [<pad>, n_att, p_att] group
            mask = torch.zeros(logits.size(2), dtype=torch.bool)
            mask[self.group_ids[i]] = True
            mask.unsqueeze_(0).unsqueeze_(0)
            log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
            log_probs_all_masked.append(log_probs)
        log_probs_cat = torch.stack(log_probs_all_masked, dim=0).sum(dim=0)
        log_probs_phoneme = torch.matmul(self.phonological_matrix, log_probs_cat.transpose(1, 2)).transpose(1, 2).float()
        pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
        pred = self.phoneme_tokenizer.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
        return pred

    def print_human_readable(self, output, with_phoneme=False):
        column_widths = []
        rows = []
        if with_phoneme:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']] + [len('Phoneme')]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']] + [len(output['Phoneme']['symbols'])]))
            rows.append(('Phoneme'.center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(output['Phoneme']['symbols'])]))
        else:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']]))
        for att in output['Attributes']:
            rows.append((att['Name'].center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(att['Pattern'])]))
        out_string = ''
        for row in rows:
            out_string += '|'.join(row) + '\n'
        return out_string
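        # Example rendered output (hypothetical, two attributes over three segments):
        #   voiced|  +  |  -  |  +
        #   nasal |  -  |  -  |  +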

    def transcribe(self, audio_file,
                   attributes='all',
                   phonological_matrix_file=None,
                   human_readable=True):
        output = {}
        output['wav_file_path'] = audio_file
        output['Attributes'] = []
        output['Phoneme'] = {}
        if attributes == 'all':
            target_attributes = self.att_list
        else:
            attributes = attributes if isinstance(attributes, tuple) else (attributes,)
            target_attributes = [att.lower() for att in attributes if att.lower() in self.att_list]
        if not target_attributes:
            logger.error(f'None of the given attributes is supported by model {self.model_path}. To get the available attributes of the selected model, run: transcribe --model_path=/path/to/model print_available_attributes')
            raise ValueError('Invalid attributes')
        # Process the audio
        y = self.read_audio_file(audio_file)
        self.logits = self.get_logits(y)
        for att in target_attributes:
            output['Attributes'].append({'Name': att, 'Pattern': self.decode_att(self.logits, att)})
        if phonological_matrix_file:
            self.read_phoneme2att(phonological_matrix_file)
            self.create_phoneme_tokenizer()
            self.create_phonological_matrix()
            output['Phoneme']['symbols'] = self.decode_phoneme(self.logits).split()
        if human_readable:
            return self.print_human_readable(output, phonological_matrix_file is not None)
        return json.dumps(output, indent=4)

def main():
    fire.Fire(transcribe_SA)


if __name__ == '__main__':
    main()
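
# Example CLI usage via python-fire (hypothetical paths; the model directory must
# contain a Wav2Vec2ForCTC checkpoint whose vocabulary includes p_/n_ attribute tokens):
#   python transcribe.py --model_path=./models/sa_model print_available_attributes
#   python transcribe.py --model_path=./models/sa_model transcribe --audio_file=sample.wav
#   python transcribe.py --model_path=./models/sa_model transcribe --audio_file=sample.wav \
#       --attributes='(voiced,nasal)' --phonological_matrix_file=p2att.csv --human_readable=False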