import argparse
import re
from copy import deepcopy

import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BartForConditionalGeneration,
    BartTokenizer,
)

from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
from src.utils import (
    construct_template,
    filter_words,
    formalize_tA,
    post_process_template,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'

RELATIONS = [
    "Causes",
    "HasProperty",
    "MadeUpOf",
    "isAfter",
    "isBefore",
    "xReact",
    "xWant",
    "xReason",
    "xAttr",
    "Desires",
]

# Matches any angle-bracket placeholder in a template (same pattern the
# instance extractor uses to cut a template into its literal spans).
_PLACEHOLDER_RE = re.compile(r'<.*?>')

# Tag written in place of the j-th instance word in a generated hypothesis.
# NOTE(review): the literal was empty in the reviewed text, which made
# ``.format(j)`` a no-op and deleted the word instead of tagging it. Restored
# as ``<entN>`` -- confirm against whatever consumes the generated rules.
_ENTITY_PLACEHOLDER = '<ent{}>'


class BartInductor(object):
    """Two-stage rule inductor built on two BART models.

    Stage one fills the placeholders of an input template with concrete
    words (``orion_instance_generator``); stage two generates alternative
    templates connecting those words (``orion_hypothesis_generator``) and
    accumulates a score per generated template.
    """

    def __init__(
        self,
        group_beam=True,
        continue_pretrain_instance_generator=True,
        continue_pretrain_hypo_generator=True,
        if_then=False,
    ):
        """Load both generators and the tokenizer and build the stop-word mask.

        Args:
            group_beam: use ``BartForConditionalGeneration_GroupBeam`` for the
                hypothesis generator instead of the stock BART class.
            continue_pretrain_instance_generator: load the Orion instance
                generator checkpoint; otherwise plain ``facebook/bart-large``.
            continue_pretrain_hypo_generator: same switch for the hypothesis
                generator.
            if_then: forwarded to ``construct_template`` and enables the
                "if ... then ..." post-processing in ``generate_rule``.
        """
        self.if_then = if_then
        self.orion_instance_generator_path = (
            ORION_INS_GENERATOR if continue_pretrain_instance_generator else 'facebook/bart-large'
        )
        self.orion_hypothesis_generator_path = (
            ORION_HYPO_GENERATOR if continue_pretrain_hypo_generator else 'facebook/bart-large'
        )

        hypothesis_cls = (
            BartForConditionalGeneration_GroupBeam if group_beam else BartForConditionalGeneration
        )
        self.orion_hypothesis_generator = (
            hypothesis_cls.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
        )
        self.orion_instance_generator = (
            BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path)
            .to(device)
            .eval()
        )

        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large", use_fast=True)
        # Number of mask tokens that stand for one word slot in explore_mask.
        self.word_length = 2

        self.stop_sub_list = [
            'he', 'she', 'this', 'that', 'and', 'it', 'which', 'who', 'whose', 'there',
            'they', '.', 'its', 'one', 'i', ',', 'the', 'nobody', 'his', 'her', 'also',
            'only', 'currently', 'here', '()', 'what', 'where', 'why', 'a', 'some', '"',
            ')', '(', 'now', 'everyone', 'everybody', 'their', 'often', 'usually',
            'you', '-', '?', ';', 'in', 'on', 'each', 'both', 'him', 'typically',
            'mostly', 'sometimes', 'normally', 'always', 'usually', 'still', 'today',
            'was', 'were', 'but', 'although', 'current', 'all', 'have', 'has', 'later',
            'with', 'most', 'nowadays', 'then', 'every', 'when', 'someone', 'anyone',
            'somebody', 'anybody', 'any', 'being', 'get', 'getting', 'thus', 'under',
            'even', 'for', 'can', 'rarely', 'never', 'may', 'generally', 'other',
            'another', 'too', 'first', 'second', 'third', 'mainly', 'primarily',
            'having', 'have', 'has',
        ]
        self.stop_size = len(self.stop_sub_list)
        # Also block the capitalised form of every alphabetic stop word.
        self.stop_sub_list.extend(
            [word[0].upper() + word[1:] for word in self.stop_sub_list if word[0].isalpha()]
        )

        # Token-id sequences (special tokens stripped) that generate() must not emit.
        self.bad_words_ids = [
            self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']
        ]

        # Column 1 is the first token after the leading special token of each
        # encoded stop word; those ids get a -100 offset so they never reach
        # the top-k when the offset is added to softmax probabilities.
        stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
        stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).to(device)
        stop_weight[0, stop_index] -= 100
        self.stop_weight = stop_weight[0, :]

    def clean(self, text):
        """Drop anything that follows the sentence-final '.' after the last placeholder.

        The text is split on angle-bracket placeholders. If the piece after
        the last placeholder starts with '.', everything from that piece on is
        replaced with a single '.'; otherwise ``text`` is returned unchanged.

        The original pattern was the empty string, which splits between every
        character and therefore never trimmed anything.
        """
        last_segment = _PLACEHOLDER_RE.split(text)[-1]
        if last_segment.startswith('.'):
            return text[:text.rfind(last_segment)] + '.'
        return text

    def generate(self, inputs, k=10, topk=10, return_scores=False):
        """Return up to ``topk`` distinct cleaned rules for ``inputs``.

        Args:
            inputs: template passed to ``generate_rule``.
            k: beam / candidate budget forwarded to ``generate_rule``.
            topk: maximum number of rules returned.
            return_scores: when true, return ``(rule, score)`` tuples instead
                of bare rule strings.

        Returns:
            A list in descending score order, as produced by ``generate_rule``.
            Rules are de-duplicated on their cleaned text in both modes, so
            the first (highest-scored) occurrence wins.
        """
        with torch.no_grad():
            tB_probs = self.generate_rule(inputs, k)

        new_ret = []
        seen = set()
        for entry in tB_probs:
            if len(new_ret) >= topk:
                break
            rule = self.clean(entry[0].strip())
            if rule in seen:
                continue
            seen.add(rule)
            new_ret.append((rule, entry[1]) if return_scores else rule)
        return new_ret

    def explore_mask(self, tA, k, tokens, prob, required_token, probs):
        """Recursively fill the mask tokens of ``tA`` left to right.

        Args:
            tA: text that still contains ``required_token`` mask tokens.
            k: number of candidates to try for the current mask (capped at 2
                once at most ``word_length`` masks remain; recursion uses 1).
            tokens: tokens chosen so far.
            prob: product of the probabilities of ``tokens``.
            required_token: number of masks still to fill.
            probs: per-token probabilities matching ``tokens``.

        Returns:
            A list of ``[tokens, prob, probs]`` entries, one per completed path.
        """
        if required_token == 0:
            return [[tokens, prob, probs]]
        if required_token <= self.word_length:
            k = min(k, 2)

        mask_token = self.tokenizer.mask_token
        encoded = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')
        encoded = {key: value.to(device) for key, value in encoded.items()}
        mask_positions = torch.where(encoded["input_ids"][0] == self.tokenizer.mask_token_id)[0]
        logits = self.orion_instance_generator(**encoded)[0]

        # Distribution at the first mask position, with stop words pushed out
        # of the top-k by the -100 offset.
        mask_word = F.softmax(logits[0, mask_positions[0], :], dim=-1) + self.stop_weight
        top_probs, top_ids = torch.topk(mask_word, k, dim=0)

        ret = []
        for token_id, prob_s in zip(top_ids.tolist(), top_probs.tolist()):
            token_this = self.tokenizer.decode([token_id]).strip()
            # Length is checked first so an empty decoded token cannot raise
            # on the [0] lookup.
            if len(token_this) <= 2 or not token_this[0].isalpha():
                continue
            filled = tA.replace(mask_token, token_this, 1)
            ret.extend(
                self.explore_mask(
                    filled,
                    1,
                    tokens + [token_this],
                    prob_s * prob,
                    required_token - 1,
                    probs + [prob_s],
                )
            )
        return ret

    @staticmethod
    def _match_spans(txt, spans):
        """Extract the text between consecutive literal spans of a template.

        Returns the list of extracted words, or ``None`` when a span cannot be
        located in order. Each following span is searched for only after the
        end of the preceding one, so it cannot match inside or before it.
        """
        txt_lower = txt.lower()
        words = []
        start_index = 0
        for span1, span2 in zip(spans, spans[1:]):
            found = txt_lower.find(span1, start_index)
            if found < 0:
                return None
            index1 = found + len(span1)
            if span2 == '':
                # Trailing placeholder: take everything up to the final period.
                index2 = len(txt) - 1 if txt.endswith('.') else len(txt)
            else:
                index2 = txt_lower.find(span2, index1)
                if index2 < 0:
                    return None
            words.append(txt[index1:index2].strip())
            start_index = index2
        return words

    def extract_words_for_tA_bart(self, tA, k=6, print_it=False):
        """Generate completions of ``tA`` and read the placeholder fillers back out.

        Args:
            tA: template containing angle-bracket placeholders; its last
                character is dropped before splitting (presumably the final
                period -- verify against ``formalize_tA``).
            k: number of ``[words, prob]`` entries to return.
            print_it: unused; kept for interface compatibility.

        Returns:
            Up to ``k`` ``[words, prob]`` entries sorted by descending
            probability, where ``prob`` is the softmax over the beam scores
            of all returned sequences.
        """
        spans = [t.lower().strip() for t in _PLACEHOLDER_RE.split(tA[:-1])]
        generated_ids = (
            self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids']
            .to(device)
            .to(torch.int64)
        )
        n_sequences = max(120, k)
        generated_ret = self.orion_instance_generator.generate(
            generated_ids,
            num_beams=n_sequences,
            max_length=generated_ids.size(1) + 15,
            num_return_sequences=n_sequences,
            bad_words_ids=self.bad_words_ids,
            no_repeat_ngram_size=2,
            output_scores=True,
            return_dict_in_generate=True,
        )
        summary_ids = generated_ret['sequences']
        # sequences_scores is 1-D, so dim=0 is what the implicit choice was.
        probs = F.softmax(generated_ret['sequences_scores'].to(torch.float32), dim=0)
        txts = [
            self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            for g in summary_ids
        ]

        template_ends_with_period = tA.endswith('.')
        ret = []
        for txt, prob in zip(txts, probs.tolist()):
            if template_ends_with_period:
                # Normalise to exactly one trailing period.
                if txt.endswith('.'):
                    txt = txt[:-1].strip()
                txt += '.'
            words_i = self._match_spans(txt, spans)
            if words_i is None:
                continue
            ret.append([words_i, prob])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_words_for_tA(self, tA, k=6):
        """Fill each placeholder of ``tA`` with ``word_length`` masked tokens.

        Returns:
            Up to ``k`` ``[words, prob, probs_words]`` entries sorted by
            descending joint probability; ``probs_words`` holds the
            per-word product of token probabilities.
        """
        mask_token = self.tokenizer.mask_token
        word_mask_str = ' '.join([mask_token] * self.word_length)
        # Expand every mask into `word_length` masks. The original replaced
        # the empty string, which inserts the mask string between every
        # character of the template.
        tA = tA.replace(mask_token, word_mask_str)
        mask_count = tA.count(mask_token)
        mask_probs = self.explore_mask(tA, k * 20, [], 1.0, mask_count, [])

        ret = []
        visited_mask_txt = set()
        for mask, prob, probs in mask_probs:
            mask_txt = ' '.join(mask).lower()
            if mask_txt in visited_mask_txt:
                continue
            visited_mask_txt.add(mask_txt)

            words = []
            probs_words = []
            for i in range(0, mask_count, self.word_length):
                words.append(' '.join(mask[i: i + self.word_length]))
                prob_word = 1.0
                for token_prob in probs[i: i + self.word_length]:
                    prob_word *= token_prob
                probs_words.append(prob_word)
            ret.append([words, prob, probs_words])
        return sorted(ret, key=lambda x: x[1], reverse=True)[:k]

    def extract_templateBs_batch(self, words_prob, tA, k, print_it=False):
        """Generate hypothesis templates for every word set and score them.

        Args:
            words_prob: iterable of ``(words, probA, ...)`` entries.
            tA: the input template, forwarded to ``construct_template``.
            k: number of beams, beam groups and returned sequences.
            print_it: unused; kept for interface compatibility.

        Returns:
            Dict mapping each generated template (lower-cased, with instance
            words replaced by entity placeholders) to the sum of
            ``P(template | words) * probA`` over all word sets that produced it.
        """
        # Sort by token length of the first word so each batch holds inputs
        # of similar length.
        words_prob_sorted = []
        for (words, probA, *_) in words_prob:
            tokenized_word = self.tokenizer(words[0])
            words_prob_sorted.append([words, probA, len(tokenized_word['input_ids'])])
        words_prob_sorted.sort(key=lambda x: x[2])

        batch_size = 8
        num_beams = k
        templates = []
        # One (words, probA) pair per queued template, so each template is
        # weighted by the probability of the word set that produced it rather
        # than by whichever word set happened to close the batch.
        template_sources = []
        ret = {}
        last = len(words_prob_sorted) - 1

        for enum, (words, probA, token_len) in enumerate(words_prob_sorted):
            new_templates = list(construct_template(words, tA, self.if_then))
            templates.extend(new_templates)
            template_sources.extend((words, probA) for _ in new_templates)

            # Flush when the batch is full (>= because one word set may add
            # several templates), at the end, or before the length changes.
            flush = (
                len(templates) >= batch_size
                or enum == last
                or words_prob_sorted[enum + 1][2] != token_len
            )
            if not flush or not templates:
                continue

            generated_ids = (
                self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids']
                .to(device)
                .to(torch.int64)
            )
            generated_ret = self.orion_hypothesis_generator.generate(
                generated_ids,
                num_beams=num_beams,
                num_beam_groups=num_beams,
                max_length=28,
                num_return_sequences=num_beams,
                min_length=3,
                diversity_penalty=1.0,
                early_stopping=True,
                bad_words_ids=self.bad_words_ids,
                output_scores=True,
                return_dict_in_generate=True,
                # NOTE(review): non-standard keyword, presumably consumed by
                # BartForConditionalGeneration_GroupBeam -- confirm it is
                # accepted when group_beam=False.
                decoder_ori_input_ids=generated_ids,
                top_p=0.95,
            )
            summary_ids = generated_ret['sequences'].reshape((len(templates), num_beams, -1))
            probs = F.softmax(
                generated_ret['sequences_scores'].reshape((len(templates), num_beams)), dim=1
            ).to(torch.float32)

            for ii in range(summary_ids.size(0)):
                words_ii, prob_ii = template_sources[ii]
                words_ii_matched = [word.lower() for word in words_ii]
                for sequence, score in zip(summary_ids[ii], probs[ii].tolist()):
                    txt = self.tokenizer.decode(
                        sequence, skip_special_tokens=True, clean_up_tokenization_spaces=True
                    )
                    txt = post_process_template(txt.lower())
                    prob = score * prob_ii
                    for j, word in enumerate(words_ii_matched):
                        if word not in txt:
                            # A hypothesis that lost an instance word keeps
                            # its entry but contributes no probability mass.
                            prob = 0.0
                        else:
                            txt = txt.replace(word, _ENTITY_PLACEHOLDER.format(j), 1)
                    # Skip templates of three words or fewer.
                    if txt.count(' ') + 1 <= 3:
                        continue
                    ret[txt] = ret.get(txt, 0.0) + prob

            templates.clear()
            template_sources.clear()
        return ret

    def generate_rule(self, tA, k=10, print_it=False):
        """Run both stages for template ``tA`` and return the top ``k`` rules.

        Returns:
            A list of ``[template, score]`` pairs in descending score order.
            When ``if_then`` is set, each template is reduced to the text
            after its last "then", or has every "if" removed when it contains
            no "then".
        """
        tA = formalize_tA(tA)
        if 'bart' in str(self.orion_instance_generator.__class__).lower():
            words_prob = self.extract_words_for_tA_bart(tA, k, print_it=print_it)
        else:
            words_prob = self.extract_words_for_tA(tA, k)
        words_prob = filter_words(words_prob)[:k]

        tB_prob = self.extract_templateBs_batch(words_prob, tA, k, print_it=print_it)
        ret = sorted(
            ([template, score] for template, score in tB_prob.items()),
            key=lambda x: x[1],
            reverse=True,
        )[:k]

        if self.if_then:
            for entry in ret:
                sentence = entry[0]
                if "then" in sentence:
                    entry[0] = sentence.split("then")[-1]
                else:
                    entry[0] = sentence.replace("if", "")
        return ret