#!/usr/bin/env python
# coding: utf-8
import torch
import model_handling
from data_handling import DataCollatorForNormSeq2Seq
from model_handling import EncoderDecoderSpokenNorm
import os
import random
import data_handling
from transformers.generation_logits_process import LogitsProcessorList
from transformers.generation_stopping_criteria import StoppingCriteriaList
from transformers.generation_beam_search import BeamSearchScorer
from dataclasses import dataclass
from transformers.file_utils import ModelOutput
import utils

# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
use_gpu = False
if use_gpu:
    if not torch.cuda.is_available():
        use_gpu = False

tokenizer = model_handling.init_tokenizer()
model = EncoderDecoderSpokenNorm.from_pretrained('nguyenvulebinh/spoken-norm-taggen-v2').eval()
data_collator = DataCollatorForNormSeq2Seq(tokenizer)
if use_gpu:
    model = model.cuda()


def make_batch_input(text_input_list):
    """Tokenize each input word by word and batch the subword ids plus per-word lengths."""
    batch_src_ids, batch_src_lengths = [], []
    for text_input in text_input_list:
        src_ids, src_lengths = [], []
        for src in text_input.split():
            src_tokenized = tokenizer(src)
            ids = src_tokenized["input_ids"][1:-1]
            src_ids.extend(ids)
            src_lengths.append(len(ids))
        # Wrap with <s> (id 0) and </s> (id 2). Lengths are shifted +1 so that,
        # after padding with the pad id (1), subtracting 1 zeroes the padded slots.
        src_ids = torch.tensor([0] + src_ids + [2])
        src_lengths = torch.tensor([1] + src_lengths + [1]) + 1
        batch_src_ids.append(src_ids)
        batch_src_lengths.append(src_lengths)
        assert sum(src_lengths - 1) == len(src_ids), "{} vs {}".format(sum(src_lengths - 1), len(src_ids))
    input_tokenized = tokenizer.pad({"input_ids": batch_src_ids}, padding=True)
    input_word_length = tokenizer.pad({"input_ids": batch_src_lengths}, padding=True)["input_ids"] - 1
    return input_tokenized['input_ids'], input_tokenized['attention_mask'], input_word_length


def make_batch_bias_list(bias_list):
    if len(bias_list) > 0:
        bias = data_collator.encode_list_string(bias_list)
        bias_input_ids = bias['input_ids']
        bias_attention_mask = bias['attention_mask']
    else:
        bias_input_ids = None
        bias_attention_mask = None
    return bias_input_ids, bias_attention_mask


def build_spoken_pronounce_mapping(bias_list):
    """Map each subword id to every bias-list pronounce (subword id sequence) containing it."""
    list_pronounce = []
    for item in bias_list:
        pronounces = item.split(' | ')[1:]
        pronounces = [tokenizer(p)['input_ids'][1:-1] for p in pronounces]
        list_pronounce.extend(pronounces)
    subword_ids = list(set([item for sublist in list_pronounce for item in sublist]))
    mapping = {item: [] for item in subword_ids}
    for item in list_pronounce:
        for wid in subword_ids:
            if wid in item:
                mapping[wid].append(item)
    return mapping


def find_pivot(seq, subseq):
    """Return every start index at which subseq occurs contiguously in seq."""
    n = len(seq)
    m = len(subseq)
    result = []
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            result.append(i)
    return result
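# A minimal sketch of find_pivot on plain Python lists (the token ids below are
# invented for illustration): every start offset of the sub-sequence is returned.
#
#   find_pivot([5, 8, 9, 8, 9], [8, 9])  ->  [1, 3]
#   find_pivot([5, 8, 9], [9, 8])        ->  []
#
# revise_spoken_tagging below relies on this to locate every occurrence of a
# bias-list pronounce inside a tagged subword-id sequence.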
def revise_spoken_tagging(list_tags, list_words, pronounce_mapping):
    """Force-tag any span matching a bias-list pronounce, even if the tagger missed it."""
    if len(pronounce_mapping) == 0:
        return list_tags
    result = []
    for tags_tensor, sen in zip(list_tags, list_words):
        tags = tags_tensor.detach().cpu().numpy().tolist()
        sen = sen.detach().cpu().numpy().tolist()
        candidate_pronounce = dict({})
        for idx in range(len(tags)):
            if tags[idx] != 0 and sen[idx] in pronounce_mapping:
                for pronounce in pronounce_mapping[sen[idx]]:
                    start_find_idx = max(0, idx - len(pronounce))
                    end_find_idx = idx + len(pronounce)
                    find_idx = find_pivot(sen[start_find_idx: end_find_idx], pronounce)
                    if len(find_idx) > 0:
                        find_idx = [item + start_find_idx for item in find_idx]
                        for map_idx in find_idx:
                            if candidate_pronounce.get(map_idx, None) is None:
                                candidate_pronounce[map_idx] = len(pronounce)
                            else:
                                candidate_pronounce[map_idx] = max(candidate_pronounce[map_idx], len(pronounce))
        # Tag 1 marks the start of a spoken span, tag 2 marks its continuation.
        for idx, len_word in candidate_pronounce.items():
            tags_tensor[idx] = 1
            for i in range(1, len_word):
                tags_tensor[idx + i] = 2
        result.append(tags_tensor)
    return result


def make_spoken_feature(input_features, text_input_list, pronounce_mapping=dict({})):
    features = {
        "input_ids": input_features[0],
        "word_src_lengths": input_features[2],
        "attention_mask": input_features[1],
        # "bias_input_ids": bias_features[0],
        # "bias_attention_mask": bias_features[1],
        "bias_input_ids": None,
        "bias_attention_mask": None,
    }
    if use_gpu:
        for key in features.keys():
            if features[key] is not None:
                features[key] = features[key].cuda()
    encoder_output = model.get_encoder()(**features)
    spoken_tagging_output = torch.argmax(encoder_output[0].spoken_tagging_output, dim=-1)
    spoken_tagging_output = revise_spoken_tagging(spoken_tagging_output, features['input_ids'], pronounce_mapping)
    # print(spoken_tagging_output)
    # print(features['input_ids'])
    word_src_lengths = features['word_src_lengths']
    encoder_features = encoder_output[0][0]
    list_spoken_features = []
    list_pre_norm = []
    for tagging_sample, sample_word_length, text_input_features, sample_text in zip(spoken_tagging_output,
                                                                                    word_src_lengths,
                                                                                    encoder_features,
                                                                                    text_input_list):
        sample_words = ['<s>'] + sample_text.split() + ['</s>']
        norm_words = []
        spoken_phrase = []
        spoken_features = []
        if tagging_sample.sum() == 0:
            list_pre_norm.append(sample_words)
            continue
        for idx, word_length in enumerate(sample_word_length):
            if word_length > 0:
                start = sample_word_length[:idx].sum()
                end = start + word_length
                if tagging_sample[start: end].sum() > 0 and sample_words[idx] not in ['<s>', '</s>']:
                    # Word has start tag: flush any open phrase and begin a new one.
                    if (tagging_sample[start: end] == 1).sum():
                        if len(spoken_phrase) > 0:
                            norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features),
                                                                      ' '.join(spoken_phrase)))
                            spoken_phrase = []
                            list_spoken_features.append(torch.cat(spoken_features))
                            spoken_features = []
                    spoken_phrase.append(sample_words[idx])
                    spoken_features.append(text_input_features[start: end])
                else:
                    if len(spoken_phrase) > 0:
                        norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features),
                                                                  ' '.join(spoken_phrase)))
                        spoken_phrase = []
                        list_spoken_features.append(torch.cat(spoken_features))
                        spoken_features = []
                    norm_words.append(sample_words[idx])
        if len(spoken_phrase) > 0:
            norm_words.append('<mask>[{}]({})'.format(len(list_spoken_features), ' '.join(spoken_phrase)))
            spoken_phrase = []
            list_spoken_features.append(torch.cat(spoken_features))
            spoken_features = []
        list_pre_norm.append(norm_words)

    # Pad the variable-length spoken-phrase features into one batch tensor plus a mask.
    list_features_mask = []
    if len(list_spoken_features) > 0:
        feature_pad = torch.zeros_like(list_spoken_features[0][:1, :])
        max_length = max([len(item) for item in list_spoken_features])
        for i in range(len(list_spoken_features)):
            spoken_length = len(list_spoken_features[i])
            remain_length = max_length - spoken_length
            device = list_spoken_features[i].device
            list_spoken_features[i] = torch.cat([list_spoken_features[i],
                                                 feature_pad.expand(remain_length,
                                                                    feature_pad.size(-1))]).unsqueeze(0)
            list_features_mask.append(torch.cat([torch.ones(spoken_length, device=device, dtype=torch.int64),
                                                 torch.zeros(remain_length, device=device,
                                                             dtype=torch.int64)]).unsqueeze(0))
        list_spoken_features = torch.cat(list_spoken_features)
        list_features_mask = torch.cat(list_features_mask)
    return list_spoken_features, list_features_mask, list_pre_norm
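# Sketch of the pre-norm placeholder format produced above (phrase invented):
# a tagged spoken phrase occupying spoken-feature slot 0 is replaced by
#     '<mask>[0](hai mươi tháng tư)'
# i.e. '<mask>[<feature index>](<original spoken phrase>)'. The decoder later
# rewrites slot 0 and reformat_normed_term splices the result back in, keeping
# the original phrase as a fallback.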
def make_bias_feature(bias_raw_features):
    features = {
        "bias_input_ids": bias_raw_features[0],
        "bias_attention_mask": bias_raw_features[1]
    }
    if use_gpu:
        for key in features.keys():
            if features[key] is not None:
                features[key] = features[key].cuda()
    return model.forward_bias(**features)


def decode_plain_output(decoder_output):
    plain_output = [item.split()[1:] for item in
                    tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=False)]
    scores = torch.stack(list(decoder_output['scores'])).transpose(1, 0)
    logit_output = torch.gather(scores, -1, decoder_output['sequences'][:, 1:].unsqueeze(-1)).squeeze(-1)
    special_tokens = list(tokenizer.special_tokens_map.values())
    generated_output = []
    generated_scores = []
    # Filter special tokens and average the per-token logits into one score per sequence.
    for out_text, out_score in zip(plain_output, logit_output):
        temp_str, tmp_score = [], []
        for piece, score in zip(out_text, out_score):
            if piece not in special_tokens:
                temp_str.append(piece)
                tmp_score.append(score)
        if len(temp_str) > 0:
            generated_output.append(' '.join(temp_str).replace('▁', '|').replace(' ', '')
                                    .replace('|', ' ').strip())
            generated_scores.append((sum(tmp_score) / len(tmp_score)).cpu().detach().numpy().tolist())
        else:
            generated_output.append("")
            generated_scores.append(0)
    return generated_output, generated_scores


def generate_spoken_norm(list_spoken_features, list_features_mask, bias_features):
    @dataclass
    class EncoderOutputs(ModelOutput):
        last_hidden_state: torch.FloatTensor = None
        hidden_states: torch.FloatTensor = None
        attentions: torch.FloatTensor = None

    batch_size = list_spoken_features.size(0)
    max_length = 50
    device = list_spoken_features.device
    decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
    stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
                                                     stopping_criteria=StoppingCriteriaList())
    model_kwargs = {
        "encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
        "encoder_bias_outputs": bias_features,
        "attention_mask": list_features_mask
    }
    decoder_output = model.greedy_search(
        decoder_input_ids,
        logits_processor=LogitsProcessorList(),
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        output_scores=True,
        return_dict_in_generate=True,
        **model_kwargs,
    )
    plain_output, plain_score = decode_plain_output(decoder_output)
    # plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
    # # print(decoder_output)
    # plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
    return plain_output, plain_score
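# Sketch of the de-tokenization trick used in decode_plain_output (pieces invented):
#
#   ['▁ngày', '▁20']  -> ' '.join -> '▁ngày ▁20'
#       .replace('▁', '|').replace(' ', '')  -> '|ngày|20'
#       .replace('|', ' ').strip()           -> 'ngày 20'
#
#   ['▁co', 'vid']   -> '▁co vid' -> '|covid' -> 'covid'
#
# i.e. SentencePiece word boundaries ('▁') become spaces while intra-word piece
# boundaries are removed.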
def generate_beam_spoken_norm(list_spoken_features, list_features_mask, bias_features, num_beams=3):
    @dataclass
    class EncoderOutputs(ModelOutput):
        last_hidden_state: torch.FloatTensor = None

    batch_size = list_spoken_features.size(0)
    max_length = 50
    num_return_sequences = 1
    device = list_spoken_features.device
    decoder_input_ids = torch.zeros((batch_size, 1), device=device, dtype=torch.int64)
    stopping_criteria = model._get_stopping_criteria(max_length=max_length, max_time=None,
                                                     stopping_criteria=StoppingCriteriaList())
    model_kwargs = {
        "encoder_outputs": EncoderOutputs(last_hidden_state=list_spoken_features),
        "encoder_bias_outputs": bias_features,
        "attention_mask": list_features_mask
    }
    beam_scorer = BeamSearchScorer(
        batch_size=batch_size,
        num_beams=num_beams,
        device=device,
        do_early_stopping=True,
        num_beam_hyps_to_keep=num_return_sequences,
    )
    # Duplicate the encoder outputs num_beams times, as beam search expects.
    decoder_input_ids, model_kwargs = model._expand_inputs_for_generation(
        decoder_input_ids, expand_size=num_beams, is_encoder_decoder=True, **model_kwargs
    )
    decoder_output = model.beam_search(
        decoder_input_ids,
        beam_scorer,
        logits_processor=LogitsProcessorList(),
        stopping_criteria=stopping_criteria,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        output_scores=None,
        return_dict_in_generate=True,
        **model_kwargs,
    )
    plain_output = tokenizer.batch_decode(decoder_output['sequences'], skip_special_tokens=True)
    plain_output = [word.replace('▁', '|').replace(' ', '').replace('|', ' ').strip() for word in plain_output]
    return plain_output, None


def reformat_normed_term(list_pre_norm, spoken_norm_output, spoken_norm_output_score=None,
                         threshold=None, debug=False):
    output = []
    for pre_norm in list_pre_norm:
        normed_words = []
        for w in pre_norm:
            if w.startswith('<mask>'):
                # Placeholder format: '<mask>[<idx>](<spoken phrase>)'; w[7:] drops '<mask>['.
                term = w[7:].split('](')
                # print(w)
                # print(term)
                term_idx = int(term[0])
                norm_val = spoken_norm_output[term_idx]
                norm_val_score = None if (spoken_norm_output_score is None or threshold is None) \
                    else spoken_norm_output_score[term_idx]
                pre_norm_val = term[1][:-1]
                if debug:
                    if norm_val_score is not None:
                        normed_words.append("({})({:.2f})[{}]".format(norm_val, norm_val_score, pre_norm_val))
                    else:
                        normed_words.append("({})[{}]".format(norm_val, pre_norm_val))
                else:
                    if threshold is not None and norm_val_score is not None:
                        # Keep the model's rewrite only when its mean logit clears the threshold.
                        if norm_val_score > threshold:
                            normed_words.append(norm_val)
                        else:
                            normed_words.append(pre_norm_val)
                    else:
                        normed_words.append(norm_val)
            else:
                normed_words.append(w)
        output.append(" ".join(normed_words))
    return output


def infer(text_input_list, bias_list):
    # Extract bias features
    bias_raw_features = make_batch_bias_list(bias_list)
    bias_features = make_bias_feature(bias_raw_features)
    pronounce_mapping = build_spoken_pronounce_mapping(bias_list)

    # Split inputs into overlapping chunks and build batch features
    text_input_chunk_list = [utils.split_chunk_input(item, chunk_size=60, overlap=20) for item in text_input_list]
    num_chunks = [len(i) for i in text_input_chunk_list]
    flatten_list = [y for x in text_input_chunk_list for y in x]
    input_raw_features = make_batch_input(flatten_list)

    # Extract norm terms and spoken features
    list_spoken_features, list_features_mask, list_pre_norm = make_spoken_feature(input_raw_features,
                                                                                  flatten_list,
                                                                                  pronounce_mapping)

    # Merge overlapping chunks back into one pre-norm sequence per input
    list_pre_norm_by_input = []
    for idx, input_num in enumerate(num_chunks):
        start = sum(num_chunks[:idx])
        end = start + num_chunks[idx]
        list_pre_norm_by_input.append(list_pre_norm[start:end])
    text_input_list_pre_norm = [utils.merge_chunk_pre_norm(list_chunks, overlap=20, debug=False)
                                for list_chunks in list_pre_norm_by_input]

    if len(list_spoken_features) > 0:
        spoken_norm_output, spoken_norm_score = generate_spoken_norm(list_spoken_features,
                                                                     list_features_mask, bias_features)
    else:
        spoken_norm_output, spoken_norm_score = [], None

    return reformat_normed_term(text_input_list_pre_norm, spoken_norm_output, spoken_norm_score,
                                threshold=15, debug=False)
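

if __name__ == '__main__':
    # Minimal usage sketch. The sentence and bias entries below are invented
    # examples; a bias entry follows the '<written form> | <spoken variant> | ...'
    # convention consumed by build_spoken_pronounce_mapping above.
    example_inputs = ['tôi sinh ngày hai mươi tháng tư năm một chín chín bảy']
    example_bias = ['covid | cô vít', 'facebook | phây búc']
    print(infer(example_inputs, example_bias))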