"""Inductors that turn a masked premise sentence into hypothesis templates.

``BartInductor`` drives two BART checkpoints (an instance generator and a
hypothesis generator); ``CometInductor`` queries a COMET checkpoint once per
relation in ``RELATIONS``.
"""
import argparse
import re
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
)

from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
from src.utils import (
    construct_template,
    filter_words,
    formalize_tA,
    post_process_template,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'

# NOTE(review): the placeholder literals in this file had been reduced to empty
# strings (e.g. ``text.split('')``, ``''.format(j)``), which either raises or
# silently deletes text.  They are restored here as the BART mask token and an
# indexed entity tag -- confirm both spellings against the upstream project.
MASK_TOKEN = '<mask>'
ENTITY_TEMPLATE = '<ent{}>'

RELATIONS = [
    "Causes",
    "HasProperty",
    "MadeUpOf",
    "isAfter",
    "isBefore",
    "xReact",
    "xWant",
    "xReason",
    "xAttr",
    "Desires",
]

# Words that must never be proposed as a mask filler.  Duplicates are kept as
# they were in the original list; they are harmless.
_STOP_WORDS = (
    'he', 'she', 'this', 'that', 'and', 'it', 'which', 'who', 'whose', 'there',
    'they', '.', 'its', 'one', 'i', ',', 'the', 'nobody', 'his', 'her', 'also',
    'only', 'currently', 'here', '()', 'what', 'where', 'why', 'a', 'some', '"',
    ')', '(', 'now', 'everyone', 'everybody', 'their', 'often', 'usually', 'you',
    '-', '?', ';', 'in', 'on', 'each', 'both', 'him', 'typically', 'mostly',
    'sometimes', 'normally', 'always', 'usually', 'still', 'today', 'was',
    'were', 'but', 'although', 'current', 'all', 'have', 'has', 'later', 'with',
    'most', 'nowadays', 'then', 'every', 'when', 'someone', 'anyone',
    'somebody', 'anybody', 'any', 'being', 'get', 'getting', 'thus', 'under',
    'even', 'for', 'can', 'rarely', 'never', 'may', 'generally', 'other',
    'another', 'too', 'first', 'second', 'third', 'mainly', 'primarily',
    'having', 'have', 'has',
)

# Person markers that CometInductor swaps back to the mask token.
_PERSON_PATTERN = re.compile(r"PersonX|X|PersonY|Y")


class BartInductor(object):
    """Generate hypothesis templates for a masked premise with two BART models."""

    def __init__(
        self,
        group_beam=True,
        continue_pretrain_instance_generator=True,
        continue_pretrain_hypo_generator=True,
        if_then=False,
    ):
        """Load both generators and the tokenizer, and build the stop-word penalty.

        Args:
            group_beam: load the hypothesis generator through
                ``BartForConditionalGeneration_GroupBeam`` instead of the stock
                ``BartForConditionalGeneration``.
            continue_pretrain_instance_generator: use the Orion instance
                checkpoint; otherwise fall back to ``facebook/bart-large``.
            continue_pretrain_hypo_generator: use the Orion hypothesis
                checkpoint; otherwise fall back to ``facebook/bart-large``.
            if_then: forwarded to ``construct_template`` and enables the
                "if ... then ..." post-processing in ``generate_rule``.
        """
        self.if_then = if_then
        self.orion_instance_generator_path = (
            ORION_INS_GENERATOR if continue_pretrain_instance_generator else 'facebook/bart-large'
        )
        self.orion_hypothesis_generator_path = (
            ORION_HYPO_GENERATOR if continue_pretrain_hypo_generator else 'facebook/bart-large'
        )

        hypothesis_cls = (
            BartForConditionalGeneration_GroupBeam if group_beam else BartForConditionalGeneration
        )
        self.orion_hypothesis_generator = (
            hypothesis_cls.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
        )
        self.orion_instance_generator = (
            BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path)
            .to(device)
            .eval()
        )
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

        # Number of mask tokens each original mask is expanded to in
        # extract_words_for_tA.
        self.word_length = 2

        # stop_size is the size of the lower-case list only; the capitalised
        # variants are appended after it, in the same order.
        self.stop_sub_list = list(_STOP_WORDS)
        self.stop_size = len(self.stop_sub_list)
        self.stop_sub_list.extend(
            word[0].upper() + word[1:] for word in _STOP_WORDS if word[0].isalpha()
        )

        # encode() output minus its first and last ids (presumably the BOS/EOS
        # special tokens -- TODO confirm).
        self.bad_words_ids = [
            self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']
        ]

        # The id at position 1 of every encoded stop word gets -100 added to
        # its softmax probability in explore_mask, so topk never prefers it.
        stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
        stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).to(device)
        stop_weight[0, stop_index] -= 100
        self.stop_weight = stop_weight[0, :]

    def clean(self, text):
        """Cut everything after the period that follows the second mask.

        Only applies when ``text`` holds exactly two mask tokens and the text
        after the second one starts with '.'; any other text is returned as is.
        """
        segments = text.split(MASK_TOKEN)
        if len(segments) == 3 and segments[2].startswith('.'):
            return MASK_TOKEN.join(segments[:2]) + MASK_TOKEN + '.'
        return text

    def generate(self, inputs, k=10, topk=10):
        """Return up to ``topk`` distinct ``(template, score)`` tuples for ``inputs``.

        Args:
            inputs: premise sentence passed to ``generate_rule``.
            k: beam / candidate budget forwarded to ``generate_rule``.
            topk: maximum number of results returned.
        """
        with torch.no_grad():
            tB_probs = self.generate_rule(inputs, k)

        entity_tags = (ENTITY_TEMPLATE.format(0), ENTITY_TEMPLATE.format(1))
        new_ret = []
        for template, prob in tB_probs:
            # Entity placeholders go back to the mask token for the caller.
            for tag in entity_tags:
                template = template.replace(tag, MASK_TOKEN)
            candidate = (self.clean(template.strip()), prob)
            if len(new_ret) < topk and candidate not in new_ret:
                new_ret.append(candidate)
        return new_ret

    def explore_mask(self, tA, k, tokens, prob, required_token, probs):
        """Fill the masks of ``tA`` left to right, one token per recursion level.

        Args:
            tA: text that still contains ``required_token`` mask tokens.
            k: number of top candidates tried for the current mask; capped at 2
                once ``required_token <= self.word_length``, and 1 for every
                deeper level.
            tokens: tokens chosen so far.
            prob: product of the probabilities of ``tokens``.
            required_token: number of masks still to fill.
            probs: per-token probabilities matching ``tokens``.

        Returns:
            A list of ``[tokens, prob, probs]`` entries, one per completed path.
        """
        if required_token == 0:
            return [[tokens, prob, probs]]
        if required_token <= self.word_length:
            k = min(k, 2)

        mask_token = self.tokenizer.mask_token
        encoded = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')
        encoded = {key: value.to(device) for key, value in encoded.items()}
        mask_index = torch.where(encoded["input_ids"][0] == self.tokenizer.mask_token_id)

        # Inference only; avoid building an autograd graph when the caller did
        # not already disable gradients.
        with torch.no_grad():
            logits = self.orion_instance_generator(**encoded)[0]
        softmax = F.softmax(logits, dim=-1)
        mask_word = softmax[0, mask_index[0][0], :] + self.stop_weight
        top_probs, top_ids = torch.topk(mask_word, k, dim=0)

        first_mask = tA.index(mask_token)
        ret = []
        for token_s, token_prob in zip(top_ids, top_probs):
            prob_s = token_prob.item()
            token_this = self.tokenizer.decode([token_s]).strip()
            # The length test runs first so an empty decode cannot hit
            # token_this[0].
            if len(token_this) <= 2 or not token_this[0].isalpha():
                continue
            tAs = tA[:first_mask] + token_this + tA[first_mask + len(mask_token):]
            ret.extend(
                self.explore_mask(
                    tAs,
                    1,
                    [*tokens, token_this],
                    prob_s * prob,
                    required_token - 1,
                    [*probs, prob_s],
                )
            )
        return ret

    def extract_words_for_tA_bart(self, tA, k=6, print_it=False):
        """Generate completions of ``tA`` and read off the text that replaced each mask.

        Args:
            tA: premise containing mask tokens; its last character is dropped
                before splitting (presumably a trailing period -- TODO confirm).
            k: number of results returned; beam width is ``max(120, k)``.
            print_it: unused, kept for interface compatibility.

        Returns:
            Up to ``k`` ``[words, prob]`` entries sorted by descending ``prob``.
        """
        spans = [t.lower().strip() for t in tA[:-1].split(MASK_TOKEN)]
        generated_ids = (
            self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids']
            .to(device)
            .to(torch.int64)
        )
        # NOTE(review): line breaks were lost in the source, so which keyword
        # arguments were commented out is ambiguous.  `bad_words_ids=bad_words_ids`
        # names an undefined variable, so it and `no_repeat_ngram_size=2` are
        # treated as disabled together with their neighbours -- confirm.
        generated_ret = self.orion_instance_generator.generate(
            generated_ids,
            num_beams=max(120, k),
            max_length=generated_ids.size(1) + 15,
            num_return_sequences=max(120, k),
            output_scores=True,
            return_dict_in_generate=True,
        )
        summary_ids = generated_ret['sequences']
        # sequences_scores is indexed per returned sequence below; dim=0 is the
        # dimension the implicit-dim softmax picked for a 1-D tensor.
        probs = F.softmax(generated_ret['sequences_scores'].to(torch.float32), dim=0)
        txts = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in summary_ids
        ]

        ret = []
        for i, txt in enumerate(txts):
            if tA.endswith('.'):
                # Normalise to exactly one trailing period.
                if txt.endswith('.'):
                    txt = txt[:-1].strip()
                txt += '.'
            txt_lower = txt.lower()

            word_imcomplete = False
            words_i = []
            start_index = 0
            for span1, span2 in zip(spans, spans[1:]):
                remainder = txt_lower[start_index:]
                if span1 not in remainder or span2 not in remainder:
                    word_imcomplete = True
                    continue
                index1 = txt_lower.index(span1, start_index) + len(span1)
                if span2 == '':
                    # The mask closes the sentence: stop before the period.
                    # endswith() is safe on an empty string, unlike txt[-1].
                    index2 = len(txt) - 1 if txt.endswith('.') else len(txt)
                else:
                    index2 = txt_lower.index(span2, start_index)
                words_i.append(txt[index1:index2].strip())
                start_index = index2

            if word_imcomplete:
                continue
            ret.append([words_i, probs[i].item()])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_words_for_tA(self, tA, k=6):
        """Fill each mask of ``tA`` with ``self.word_length`` tokens via ``explore_mask``.

        Returns:
            Up to ``k`` ``[words, prob, probs_words]`` entries sorted by
            descending ``prob``; ``probs_words[i]`` is the product of the token
            probabilities that make up ``words[i]``.
        """
        word_mask_str = ' '.join([self.tokenizer.mask_token] * self.word_length)
        tA = tA.replace(MASK_TOKEN, word_mask_str)
        mask_count = tA.count(self.tokenizer.mask_token)
        mask_probs = self.explore_mask(tA, k * 20, [], 1.0, mask_count, [])

        ret = []
        seen = set()
        for mask, prob, probs in mask_probs:
            mask_txt = ' '.join(mask).lower()
            if mask_txt in seen:
                continue
            seen.add(mask_txt)

            words = []
            probs_words = []
            for start in range(0, mask_count, self.word_length):
                stop = start + self.word_length
                words.append(' '.join(mask[start:stop]))
                prob_word = 1.0
                for j in range(start, stop):
                    prob_word *= probs[j]
                probs_words.append(prob_word)
            ret.append([words, prob, probs_words])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_templateBs_batch(self, words_prob, tA, k, print_it=False):
        """Generate hypothesis templates for every word set and accumulate scores.

        Args:
            words_prob: entries whose first two items are ``words`` and their
                probability ``probA``; extra items are ignored.
            tA: premise passed to ``construct_template``.
            k: used as ``num_beams``, ``num_beam_groups`` and
                ``num_return_sequences``.
            print_it: unused, kept for interface compatibility.

        Returns:
            Dict mapping each generated template (words replaced by entity
            placeholders) to the sum of ``beam probability * probA`` over all
            beams that produced it.
        """
        # Sort by token length of the first word so that a batch only holds
        # word sets of equal length (see the flush condition below).
        words_prob_sorted = []
        for (words, probA, *_) in words_prob:
            tokenized_word = self.tokenizer(words[0])
            words_prob_sorted.append([words, probA, len(tokenized_word['input_ids'])])
        words_prob_sorted.sort(key=lambda x: x[2])

        batch_size = 8
        num_beams = k
        templates = []
        # One (words, probA) entry per template in `templates`.  The original
        # multiplied every template by the probA of the last word set added to
        # the batch; each template now keeps the probability of its own words.
        batch_meta = []
        ret = {}
        last = len(words_prob_sorted) - 1
        for position, (words, probA, token_len) in enumerate(words_prob_sorted):
            new_templates = construct_template(words, tA, self.if_then)
            templates.extend(new_templates)
            batch_meta.extend((words, probA) for _ in new_templates)

            # `>=` rather than `==`: one word set can contribute several
            # templates, which may step past the exact batch size.
            flush = (
                len(templates) >= batch_size
                or position == last
                or words_prob_sorted[position + 1][2] != token_len
            )
            if not flush:
                continue

            generated_ids = (
                self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids']
                .to(device)
                .to(torch.int64)
            )
            # NOTE(review): `decoder_ori_input_ids` is not a stock generate()
            # argument; presumably it is consumed by the group-beam model --
            # confirm it is accepted when group_beam=False.
            generated_ret = self.orion_hypothesis_generator.generate(
                generated_ids,
                num_beams=num_beams,
                num_beam_groups=num_beams,
                max_length=28,
                num_return_sequences=num_beams,
                min_length=3,
                diversity_penalty=1.0,
                early_stopping=True,
                bad_words_ids=self.bad_words_ids,
                output_scores=True,
                return_dict_in_generate=True,
                decoder_ori_input_ids=generated_ids,
                top_p=0.95,
            )
            summary_ids = generated_ret['sequences'].reshape((len(templates), num_beams, -1))
            probs = F.softmax(
                generated_ret['sequences_scores'].reshape((len(templates), num_beams)),
                dim=1,
            ).to(torch.float32)

            for row, (row_words, row_prob) in enumerate(batch_meta):
                targets = [word.lower() for word in row_words]
                for beam, ids in enumerate(summary_ids[row]):
                    txt = self.tokenizer.decode(
                        ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                    )
                    txt = post_process_template(txt.lower())
                    prob = probs[row][beam].item() * row_prob
                    for j, word in enumerate(targets):
                        if word not in txt:
                            # A missing word zeroes the score; the template is
                            # still recorded.
                            prob = 0.0
                        else:
                            txt = txt.replace(word, ENTITY_TEMPLATE.format(j), 1)
                    # Skip templates of three words or fewer.
                    if txt.count(' ') + 1 <= 3:
                        continue
                    ret[txt] = ret.get(txt, 0.0) + prob

            templates.clear()
            batch_meta.clear()
        return ret

    def generate_rule(self, tA, k=10, print_it=False):
        """Return up to ``k`` ``[template, score]`` pairs for premise ``tA``.

        The premise is normalised with ``formalize_tA``, candidate words are
        extracted and filtered, and the resulting templates are ranked by
        accumulated score.  With ``if_then`` set, only the part after "then" is
        kept (or "if" is removed when no "then" is present).
        """
        tA = formalize_tA(tA)
        if 'bart' in str(self.orion_instance_generator.__class__).lower():
            words_prob = self.extract_words_for_tA_bart(tA, k, print_it=print_it)
        else:
            words_prob = self.extract_words_for_tA(tA, k)
        words_prob = filter_words(words_prob)[:k]

        tB_prob = self.extract_templateBs_batch(words_prob, tA, k, print_it=print_it)

        ret = [[template, score] for template, score in tB_prob.items()]
        ret = sorted(ret, key=lambda x: x[1], reverse=True)[:k]
        if self.if_then:
            for entry in ret:
                sentence = entry[0]
                if "then" in sentence:
                    sentence = sentence.split("then")[-1]
                else:
                    sentence = sentence.replace("if", "")
                entry[0] = sentence
        return ret


class CometInductor(object):
    """Generate one hypothesis per relation in ``RELATIONS`` with a COMET model."""

    def __init__(self):
        """Load the COMET model (float32, eval mode) and its tokenizer."""
        self.model = (
            AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART")
            .to(device)
            .eval()
            .float()
        )
        self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
        self.task = "summarization"
        self.use_task_specific_params()
        self.decoder_start_token_id = None

    def drop_repeat(self, old_list):
        """Return ``old_list`` without duplicates, keeping first occurrences in order."""
        new_list = []
        for item in old_list:
            if item not in new_list:
                new_list.append(item)
        return new_list

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from lst."""
        for i in range(0, len(lst), n):
            yield lst[i: i + n]

    def use_task_specific_params(self):
        """Update the model config with the parameters stored for ``self.task``."""
        task_specific_params = self.model.config.task_specific_params
        if task_specific_params is not None:
            pars = task_specific_params.get(self.task, {})
            self.model.config.update(pars)

    def trim_batch(
        self,
        input_ids,
        pad_token_id,
        attention_mask=None,
    ):
        """Remove columns that are populated exclusively by pad_token_id.

        Returns the trimmed ``input_ids`` alone, or an
        ``(input_ids, attention_mask)`` tuple when a mask is given.
        """
        keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)
        if attention_mask is None:
            return input_ids[:, keep_column_mask]
        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])

    def generate(self, inputs, k, topk):
        """Return one masked hypothesis sentence per relation in ``RELATIONS``.

        Args:
            inputs: premise with at most two mask tokens; they are replaced by
                'PersonX' and 'PersonY' in order (a third mask raises
                ``IndexError``).  The final character of the premise is dropped
                before the relation is appended (presumably a trailing period).
            k: unused, kept for interface compatibility.
            topk: unused, kept for interface compatibility.
        """
        words = ['PersonX', 'PersonY']
        for i in range(inputs.count(MASK_TOKEN)):
            inputs = inputs.replace(MASK_TOKEN, words[i], 1)

        # Computed once: the original rebound `inputs` inside the loop, so each
        # query also carried every previous relation and a truncated "[GEN".
        premise = inputs[:-1]

        outputs = []
        for relation in RELATIONS:
            query = "{} {} [GEN]".format(premise, relation)
            candidates = self.generate_(query, num_generate=10)[0]

            # Prefer the first candidate mentioning both persons; otherwise
            # fall back to the top candidate.
            chosen = candidates[0].strip()
            for candidate in candidates:
                candidate = candidate.strip()
                if re.search("PersonX|X", candidate) and re.search("PersonY|Y", candidate):
                    chosen = candidate
                    break

            temp = _PERSON_PATTERN.sub(MASK_TOKEN, chosen)
            outputs.append(temp if temp.endswith(".") else temp + ".")
        return outputs

    def generate_(
        self,
        queries,
        decode_method="beam",
        num_generate=5,
    ):
        """Run beam search on ``queries`` and return the decoded sequences.

        Args:
            queries: text passed to the tokenizer.
            decode_method: unused, kept for interface compatibility.
            num_generate: used as both ``num_beams`` and ``num_return_sequences``.

        Returns:
            A one-element list holding the list of decoded strings.
        """
        with torch.no_grad():
            batch = self.tokenizer(queries, return_tensors="pt", padding="longest")
            input_ids, attention_mask = self.trim_batch(
                **batch, pad_token_id=self.tokenizer.pad_token_id
            )
            summaries = self.model.generate(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                decoder_start_token_id=self.decoder_start_token_id,
                num_beams=num_generate,
                num_return_sequences=num_generate,
            )
            dec = self.tokenizer.batch_decode(
                summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            return [dec]