import fire
import logging
import sys, os
import yaml
import json
import torch
import librosa
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2Processor, Wav2Vec2ForCTC
import transformers
import pandas as pd
logger = logging.getLogger(__name__)
# Setup logging
logger.setLevel(logging.ERROR)
console_handler = logging.StreamHandler()
formatter = logging.Formatter(fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                              datefmt="%m/%d/%Y %H:%M:%S",)
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.ERROR)
logger.addHandler(console_handler)
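
# This script loads a Wav2Vec2-CTC model whose vocabulary consists of paired
# p_<attribute>/n_<attribute> tokens and transcribes an audio file into a presence/absence
# pattern for each phonological attribute (optionally also into phonemes via a
# phoneme-to-attribute matrix). The class below is exposed as a command line tool through
# python-fire; see the usage examples at the end of the file.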


class transcribe_SA():

    def __init__(self, model_path, verbose=0):
        if verbose == 0:
            logger.setLevel(logging.ERROR)
            transformers.logging.set_verbosity_error()
            #console_handler.setLevel(logging.ERROR)
        elif verbose == 1:
            logger.setLevel(logging.WARNING)
            transformers.logging.set_verbosity_warning()
            #console_handler.setLevel(logging.WARNING)
        else:
            logger.setLevel(logging.INFO)
            transformers.logging.set_verbosity_info()
            #console_handler.setLevel(logging.INFO)
        logger.info('Init Object')
        if torch.cuda.is_available():
            self.accelerate = True
            self.device = torch.device('cuda')
            self.n_devices = torch.cuda.device_count()
            # Currently only a single GPU is supported
            assert self.n_devices == 1, 'Only a single GPU is supported. Use CUDA_VISIBLE_DEVICES=gpu_index if you have multiple GPUs.'
        else:
            self.device = torch.device('cpu')
            self.n_devices = 1
        self.model_path = model_path
        self.load_model()
        self.get_available_attributes()
        self.get_att_binary_group_indexs()

    def load_model(self):
        if not os.path.exists(self.model_path):
            logger.error(f'Model file {self.model_path} does not exist')
            raise FileNotFoundError(self.model_path)
        self.processor = Wav2Vec2Processor.from_pretrained(self.model_path)
        self.model = Wav2Vec2ForCTC.from_pretrained(self.model_path)
        self.pad_token_id = self.processor.tokenizer.pad_token_id
        self.sampling_rate = self.processor.feature_extractor.sampling_rate

    def get_available_attributes(self):
        if not hasattr(self, 'model'):
            logger.error('Model not loaded, call load_model first!')
            raise AttributeError("model not defined")
        # Attribute tokens come in p_/n_ pairs; strip the 'p_' prefix to list the attribute names
        att_list = set(self.processor.tokenizer.get_vocab().keys()) - set(self.processor.tokenizer.all_special_tokens)
        att_list = [p.replace('p_', '') for p in att_list if p[0] == 'p']
        self.att_list = att_list

    def print_available_attributes(self):
        print(self.att_list)

    def get_att_binary_group_indexs(self):
        # Each group holds the token ids of [<PAD>, n_att, p_att] for one attribute, sorted by token id
        self.group_ids = []
        for i, att in enumerate(self.att_list):
            n_indx = self.processor.tokenizer.convert_tokens_to_ids(f'n_{att}')
            p_indx = self.processor.tokenizer.convert_tokens_to_ids(f'p_{att}')
            self.group_ids.append(sorted([self.pad_token_id, n_indx, p_indx]))
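
    # For example (hypothetical token ids): if the tokenizer maps <pad> -> 0, n_voiced -> 7 and
    # p_voiced -> 8, the group for the attribute 'voiced' would be [0, 7, 8]. decode_att and
    # decode_phoneme below use these groups to restrict the logits to one attribute at a time.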

    def decode_att(self, logits, att):  # Attribute names need to be lowercased when first read from the user
        mask = torch.zeros(logits.size()[2], dtype=torch.bool)
        try:
            i = self.att_list.index(att)
        except ValueError:
            logger.error(f'The given attribute {att} is not supported by the given model {self.model_path}')
            raise
        # Keep only the three columns of this attribute's group, greedy-decode, then map the
        # positions back to the original token ids before CTC decoding
        mask[self.group_ids[i]] = True
        logits_g = logits[:, :, mask]
        pred_ids = torch.argmax(logits_g, dim=-1)
        pred_ids = pred_ids.cpu().apply_(lambda x: self.group_ids[i][x])
        pred = self.processor.batch_decode(pred_ids, spaces_between_special_tokens=True)[0].split()
        return list(map(lambda x: {f'p_{att}': '+', f'n_{att}': '-'}[x], pred))
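
    # The returned pattern is a list with one '+' or '-' per decoded CTC segment, e.g. a
    # hypothetical result ['+', '-', '+'] means the attribute is detected, absent, then
    # detected again across the utterance.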

    def read_audio_file(self, audio_file):
        if not os.path.exists(audio_file):
            logger.error(f'Audio file {audio_file} does not exist')
            raise FileNotFoundError(audio_file)
        y, _ = librosa.load(audio_file, sr=self.sampling_rate)
        return y

    def get_logits(self, y):
        input_values = self.processor(audio=y, sampling_rate=self.sampling_rate, return_tensors="pt").input_values
        with torch.no_grad():
            logits = self.model(input_values).logits
        return logits
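
    # The logits returned by Wav2Vec2ForCTC have shape (batch, frames, vocab_size); here the
    # batch dimension is always 1 since a single audio file is processed at a time.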

    def check_identical_phonemes(self, df_p2att):
        identical_phonemes = []
        for index, row in df_p2att.iterrows():
            mask = df_p2att.eq(row).all(axis=1)
            indexes = df_p2att[mask].index.values
            if len(indexes) > 1:
                identical_phonemes.append(tuple(indexes))
        if identical_phonemes:
            logger.warning('The following phonemes have identical phonological features given the features used in the model. If a fixed weight layer is used, these phonemes will be confused with each other')
            identical_phonemes = set(identical_phonemes)
            for x in identical_phonemes:
                logger.warning(f"{','.join(x)}")

    def read_phoneme2att(self, p2att_file):
        if not os.path.exists(p2att_file):
            logger.error(f'Phonological matrix file {p2att_file} does not exist')
            raise FileNotFoundError(f'{p2att_file}')
        df_p2att = pd.read_csv(p2att_file, index_col=0)
        self.check_identical_phonemes(df_p2att)
        not_supported = set(df_p2att.columns) - set(self.att_list)
        if not_supported:
            logger.warning(f"Attribute/s {','.join(not_supported)} are not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model print_available_attributes")
            df_p2att = df_p2att.drop(columns=not_supported)
        self.phoneme_list = df_p2att.index.values
        self.p2att_map = {}
        for i, r in df_p2att.iterrows():
            phoneme = i
            self.p2att_map[phoneme] = []
            for att in r.index.values:
                if f'p_{att}' not in self.processor.tokenizer.vocab:
                    logger.warning(f'Attribute {att} is not supported by the model {self.model_path} and will be ignored. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model print_available_attributes')
                    continue
                value = r[att]
                if value == 0:
                    self.p2att_map[phoneme].append(f'n_{att}')
                elif value == 1:
                    self.p2att_map[phoneme].append(f'p_{att}')
                else:
                    logger.error(f'Invalid value of {value} for attribute {att} of phoneme {phoneme}. Values in the phoneme-to-attribute map should be either 0 or 1')
                    raise ValueError(f'{value} should be 0 or 1')
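
    # The phonological matrix file is a CSV with phoneme symbols in the first column (used as the
    # index) and one 0/1 column per attribute. A hypothetical example:
    #
    #     phoneme,voiced,nasal
    #     b,1,0
    #     p,0,0
    #     m,1,1
    #
    # Column names must match attribute names known to the model; the symbols 'voiced', 'nasal',
    # 'b', 'p' and 'm' above are placeholders, not values required by the code.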

    def create_phoneme_tokenizer(self):
        vocab_list = self.phoneme_list
        vocab_dict = {v: k + 1 for k, v in enumerate(vocab_list)}
        vocab_dict['<pad>'] = 0
        vocab_dict = dict(sorted(vocab_dict.items(), key=lambda x: x[1]))
        vocab_file = 'phoneme_vocab.json'
        with open(vocab_file, 'w') as f:
            json.dump(vocab_dict, f)
        # Build a CTC tokenizer over the phoneme vocabulary
        self.phoneme_tokenizer = Wav2Vec2CTCTokenizer(vocab_file, pad_token="<pad>", word_delimiter_token="")

    def create_phonological_matrix(self):
        self.phonological_matrix = torch.zeros((self.phoneme_tokenizer.vocab_size, self.processor.tokenizer.vocab_size)).type(torch.FloatTensor)
        self.phonological_matrix[self.phoneme_tokenizer.pad_token_id, self.processor.tokenizer.pad_token_id] = 1
        for p in self.phoneme_list:
            for att in self.p2att_map[p]:
                self.phonological_matrix[self.phoneme_tokenizer.convert_tokens_to_ids(p), self.processor.tokenizer.convert_tokens_to_ids(att)] = 1
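
    # The resulting matrix has one row per phoneme token and one column per attribute token, with a
    # 1 wherever the phoneme carries that attribute value. decode_phoneme multiplies it with the
    # per-attribute log-probabilities, so each phoneme's score is the sum of the log-probabilities
    # of exactly the attribute tokens assigned to it.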

    # This function takes the attribute logits from the output layer and converts them to phonemes.
    # The input is a sequence of logits (one vector per frame) and the output is a phoneme sequence.
    # Note that this is CTC, so the number of output phonemes does not equal the number of input frames.
    def decode_phoneme(self, logits):
        def masked_log_softmax(vector: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
            if mask is not None:
                mask = mask.float()
                while mask.dim() < vector.dim():
                    mask = mask.unsqueeze(1)
                # vector + mask.log() is an easy way to zero out masked elements in log space, but it
                # results in nans when the whole vector is masked. We need a very small value instead of a
                # zero in the mask for these cases. log(1 + 1e-45) is still basically 0, so we can safely
                # just add 1e-45 before calling mask.log(). We use 1e-45 because 1e-46 is so small it
                # becomes 0 - this is just the smallest value we can actually use.
                vector = vector + (mask + 1e-45).log()
            return torch.nn.functional.log_softmax(vector, dim=dim)

        # Compute a masked log-softmax per attribute group, sum them, then project to phoneme scores
        # through the phonological matrix and greedy-decode with the phoneme tokenizer
        log_props_all_masked = []
        for i in range(len(self.att_list)):
            mask = torch.zeros(logits.size()[2], dtype=torch.bool)
            mask[self.group_ids[i]] = True
            mask.unsqueeze_(0).unsqueeze_(0)
            log_probs = masked_log_softmax(vector=logits, mask=mask, dim=-1).masked_fill(~mask, 0)
            log_props_all_masked.append(log_probs)
        log_probs_cat = torch.stack(log_props_all_masked, dim=0).sum(dim=0)
        log_probs_phoneme = torch.matmul(self.phonological_matrix, log_probs_cat.transpose(1, 2)).transpose(1, 2).type(torch.FloatTensor)
        pred_ids = torch.argmax(log_probs_phoneme, dim=-1)
        pred = self.phoneme_tokenizer.batch_decode(pred_ids, spaces_between_special_tokens=True)[0]
        return pred
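
    # decode_phoneme returns a single space-separated string of phoneme symbols; transcribe() below
    # splits it into a list before placing it in the output dictionary.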

    def print_human_readable(self, output, with_phoneme=False):
        column_widths = []
        rows = []
        if with_phoneme:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']] + [len('Phoneme')]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']] + [len(output['Phoneme']['symbols'])]))
            rows.append(('Phoneme'.center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(output['Phoneme']['symbols'])]))
        else:
            column_widths.append(max([len(att['Name']) for att in output['Attributes']]))
            column_widths.extend([5] * max([len(att['Pattern']) for att in output['Attributes']]))
        for i in range(len(output['Attributes'])):
            att = output['Attributes'][i]
            rows.append((att['Name'].center(column_widths[0]), *[s.center(column_widths[j + 1]) for j, s in enumerate(att['Pattern'])]))
        out_string = ''
        for row in rows:
            out_string += '|'.join(row)
            out_string += '\n'
        return out_string
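
    # The human-readable output is a pipe-separated table with one row per attribute (plus an
    # optional phoneme row), roughly like this hypothetical two-attribute result:
    #
    #     voiced|  +  |  -  |  +
    #     nasal |  -  |  -  |  +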

    def transcribe(self, audio_file,
                   attributes='all',
                   phonological_matrix_file=None,
                   human_readable=True):
        output = {}
        output['wav_file_path'] = audio_file
        output['Attributes'] = []
        output['Phoneme'] = {}
        # The model, attribute list and group indexes are already initialised in __init__
        if attributes == 'all':
            target_attributes = self.att_list
        else:
            attributes = attributes if isinstance(attributes, tuple) else (attributes,)
            target_attributes = [att.lower() for att in attributes if att.lower() in self.att_list]
            if not target_attributes:
                logger.error(f'None of the given attributes is supported by model {self.model_path}. To get the available attributes of the selected model run: transcribe --model_path=/path/to/model get_available_attributes')
                raise ValueError("Invalid attributes")
        # Process audio
        y = self.read_audio_file(audio_file)
        self.logits = self.get_logits(y)
        for att in target_attributes:
            output['Attributes'].append({'Name': att, 'Pattern': self.decode_att(self.logits, att)})
        if phonological_matrix_file:
            self.read_phoneme2att(phonological_matrix_file)
            self.create_phoneme_tokenizer()
            self.create_phonological_matrix()
            output['Phoneme']['symbols'] = self.decode_phoneme(self.logits).split()
        json_string = json.dumps(output, indent=4)
        if human_readable:
            return self.print_human_readable(output, phonological_matrix_file is not None)
        else:
            return json_string
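
    # With human_readable=False the method returns a JSON string shaped like the following
    # (attribute names, patterns and phoneme symbols are hypothetical):
    #
    #     {
    #         "wav_file_path": "sample.wav",
    #         "Attributes": [
    #             {"Name": "voiced", "Pattern": ["+", "-", "+"]}
    #         ],
    #         "Phoneme": {"symbols": ["b", "a"]}
    #     }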


def main():
    fire.Fire(transcribe_SA)


if __name__ == '__main__':
    main()
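
# Example command lines, assuming this file is saved as transcribe.py (the script name and all
# paths below are hypothetical). python-fire maps constructor flags and method names to the CLI:
#
#   python transcribe.py --model_path=/path/to/model print_available_attributes
#   python transcribe.py --model_path=/path/to/model transcribe --audio_file=/path/to/audio.wav
#   python transcribe.py --model_path=/path/to/model transcribe --audio_file=/path/to/audio.wav \
#       --phonological_matrix_file=/path/to/p2att.csv --human_readable=False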