orion / inductor.py
andreslu's picture
Update inductor.py
a54b7f1
raw
history blame contribute delete
No virus
15.1 kB
import re
from copy import deepcopy
import argparse
import torch
import torch.nn.functional as F
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
BartForConditionalGeneration, BartTokenizer,)
from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
from src.utils import (construct_template, filter_words,
formalize_tA, post_process_template)
# Run on the GPU when torch reports one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Checkpoint identifiers handed to ``from_pretrained`` by BartInductor.
ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'

# Relation names; not referenced by the code visible in this part of the file.
RELATIONS = [
    "Causes", "HasProperty", "MadeUpOf",
    "isAfter", "isBefore",
    "xReact", "xWant", "xReason", "xAttr",
    "Desires",
]
class BartInductor(object):
def __init__(
    self,
    group_beam=True,
    continue_pretrain_instance_generator=True,
    continue_pretrain_hypo_generator=True,
    if_then=False
):
    """Load both generators, the tokenizer, and build the stop-word penalty vector.

    Args:
        group_beam: when true the hypothesis generator is loaded through
            ``BartForConditionalGeneration_GroupBeam``; otherwise through the
            plain ``BartForConditionalGeneration``.
        continue_pretrain_instance_generator: when true load the
            ``ORION_INS_GENERATOR`` checkpoint, otherwise ``facebook/bart-large``.
        continue_pretrain_hypo_generator: when true load the
            ``ORION_HYPO_GENERATOR`` checkpoint, otherwise ``facebook/bart-large``.
        if_then: stored on the instance and forwarded to ``construct_template``.
    """
    self.if_then = if_then

    # Resolve which checkpoint each generator is loaded from.
    if continue_pretrain_instance_generator:
        self.orion_instance_generator_path = ORION_INS_GENERATOR
    else:
        self.orion_instance_generator_path = 'facebook/bart-large'
    if continue_pretrain_hypo_generator:
        self.orion_hypothesis_generator_path = ORION_HYPO_GENERATOR
    else:
        self.orion_hypothesis_generator_path = 'facebook/bart-large'

    # Both models are moved to ``device`` and switched to eval mode.
    hypothesis_cls = BartForConditionalGeneration_GroupBeam if group_beam else BartForConditionalGeneration
    self.orion_hypothesis_generator = (
        hypothesis_cls.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval()
    )
    self.orion_instance_generator = (
        BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval()
    )
    self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large", use_fast=True)

    # Number of mask tokens that stand in for one <mask> slot in extract_words_for_tA.
    self.word_length = 2

    self.stop_sub_list = ['he', 'she', 'this', 'that', 'and', 'it', 'which', 'who', 'whose', 'there', 'they', '.', 'its', 'one',
                          'i', ',', 'the', 'nobody', 'his', 'her', 'also', 'only', 'currently', 'here', '()', 'what', 'where',
                          'why', 'a', 'some', '"', ')', '(', 'now', 'everyone', 'everybody', 'their', 'often', 'usually', 'you',
                          '-', '?', ';', 'in', 'on', 'each', 'both', 'him', 'typically', 'mostly', 'sometimes', 'normally',
                          'always', 'usually', 'still', 'today', 'was', 'were', 'but', 'although', 'current', 'all', 'have',
                          'has', 'later', 'with', 'most', 'nowadays', 'then', 'every', 'when', 'someone', 'anyone', 'somebody',
                          'anybody', 'any', 'being', 'get', 'getting', 'thus', 'under', 'even', 'for', 'can', 'rarely', 'never',
                          'may', 'generally', 'other', 'another', 'too', 'first', 'second', 'third', 'mainly', 'primarily',
                          'having', 'have', 'has']
    # stop_size records the length of the list BEFORE the capitalised variants are added.
    self.stop_size = len(self.stop_sub_list)
    capitalised = [
        word[0].upper() + word[1:]
        for word in self.stop_sub_list
        if word[0].isalpha()
    ]
    self.stop_sub_list.extend(capitalised)

    # Token ids (special tokens at both ends stripped) banned during hypothesis generation.
    self.bad_words_ids = [self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']]

    # Take the token at position 1 of every encoded stop word and give it a
    # -100 offset; explore_mask adds this vector to the mask-slot probabilities.
    encoded_stops = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
    stop_index = torch.tensor(encoded_stops['input_ids'])[:, 1]
    penalty = torch.zeros(1, self.tokenizer.vocab_size).to(device)
    penalty[0, stop_index] -= 100
    self.stop_weight = penalty[0, :]
def clean(self, text):
    """Cut off whatever follows the period that comes right after the last entity tag.

    ``text`` is split on ``<entN>`` placeholders (``N`` a single digit). If the
    piece after the final placeholder starts with ``'.'``, that piece is removed
    and a single ``'.'`` is appended in its place; in every other case ``text``
    is returned unchanged.
    """
    tail = re.split(r'<ent\d>', text)[-1]
    if not tail.startswith('.'):
        return text
    # ``tail`` is the suffix of ``text`` that follows the last placeholder, so
    # dropping it is plain length arithmetic.
    return text[:len(text) - len(tail)] + '.'
def generate(self, inputs, k=10, topk=10, return_scores=False):
    """Return up to ``topk`` distinct cleaned templates for ``inputs``.

    Args:
        inputs: premise template forwarded to ``generate_rule``.
        k: candidate budget forwarded to ``generate_rule``.
        topk: maximum number of results returned.
        return_scores: when true each result is a ``(template, score)`` tuple,
            otherwise just the template string.

    Returns:
        A list in the order produced by ``generate_rule`` (which sorts by
        descending score). Templates are stripped and passed through
        ``clean``; when several candidates collapse to the same cleaned text
        only the first (highest-ranked) one is kept, in both output modes.
    """
    with torch.no_grad():
        tB_probs = self.generate_rule(inputs, k)

    results = []
    seen = set()  # cleaned templates already emitted
    for candidate in tB_probs:
        if len(results) >= topk:
            break
        template = self.clean(candidate[0].strip())
        # Deduplicate on the text alone; comparing (text, score) pairs would
        # let the same template through once per distinct score.
        if template in seen:
            continue
        seen.add(template)
        results.append((template, candidate[1]) if return_scores else template)
    return results
def explore_mask(self, tA, k, tokens, prob, required_token, probs):
    """Recursively fill the mask tokens of ``tA`` left to right.

    Args:
        tA: text that still contains ``required_token`` mask tokens.
        k: number of top candidates considered for the first remaining mask;
            capped at 2 once ``required_token <= self.word_length``. Recursive
            calls always pass 1.
        tokens: tokens chosen so far (not mutated).
        prob: product of the probabilities of the tokens chosen so far.
        required_token: number of masks still to fill; 0 ends the recursion.
        probs: per-token probabilities chosen so far (not mutated).

    Returns:
        A list of ``[tokens, prob, probs]`` entries, one per completed filling.
    """
    if required_token == 0:
        return [[tokens, prob, probs]]
    if required_token <= self.word_length:
        k = min(k, 2)

    encoded = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')
    for key in encoded.keys():
        encoded[key] = encoded[key].to(device)
    mask_index = torch.where(encoded["input_ids"][0] == self.tokenizer.mask_token_id)

    logits = self.orion_instance_generator(**encoded)[0]
    softmax = F.softmax(logits, dim=-1)
    # Distribution at the first mask position, with the stop-word offsets added.
    mask_word = softmax[0, mask_index[0][0], :] + self.stop_weight
    top_k = torch.topk(mask_word, k, dim=0)

    # The text position of the first mask is the same for every candidate.
    mask_token = self.tokenizer.mask_token
    index_s = tA.index(mask_token)
    prefix = tA[:index_s]
    suffix = tA[index_s + len(mask_token):]

    ret = []
    for i in range(top_k[1].size(0)):
        token_this = self.tokenizer.decode([top_k[1][i]]).strip()
        # Length is tested first so a token that decodes to an empty string is
        # skipped instead of raising IndexError on token_this[0].
        if len(token_this) <= 2 or not token_this[0].isalpha():
            continue
        prob_s = top_k[0][i].item()
        ret.extend(
            self.explore_mask(
                prefix + token_this + suffix,
                1,
                tokens + [token_this],
                prob_s * prob,
                required_token - 1,
                probs + [prob_s],
            )
        )
    return ret
def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
    """Generate fillings of ``tA`` with the instance generator and recover the words put in each slot.

    ``tA`` (minus its last character) is split on ``<...>`` tags into the
    literal spans that surround the slots. Each generated sentence is scanned
    for those spans in order; the text found between consecutive spans is the
    word for that slot. Sentences in which a span cannot be located are dropped.

    Args:
        tA: premise template containing ``<...>`` slot tags.
        k: number of results returned; also a lower bound on the beam count.
        print_it: unused; kept for interface compatibility.

    Returns:
        Up to ``k`` entries ``[words, prob]`` sorted by descending ``prob``,
        where ``prob`` is the softmax of the sequence scores over all returned
        beams.
    """
    spans = [t.lower().strip() for t in re.split(r'<.*?>', tA[:-1])]
    generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device).to(torch.int64)
    num_candidates = max(120, k)
    generated_ret = self.orion_instance_generator.generate(
        generated_ids,
        num_beams=num_candidates,
        max_length=generated_ids.size(1) + 15,
        num_return_sequences=num_candidates,
        output_scores=True,
        return_dict_in_generate=True)
    summary_ids = generated_ret['sequences']
    # sequences_scores is one-dimensional here; naming dim=0 keeps the result
    # of the implicit-dim call while avoiding its deprecation warning.
    probs = F.softmax(generated_ret['sequences_scores'].to(torch.float32), dim=0)
    txts = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summary_ids]

    ret = []
    for i, txt in enumerate(txts):
        if tA.endswith('.'):
            # Normalise to exactly one trailing period.
            if txt.endswith('.'):
                txt = txt[:-1].strip()
            txt += '.'
        lowered = txt.lower()

        words_i = []
        start_index = 0
        complete = True
        for span1, span2 in zip(spans, spans[1:]):
            pos1 = lowered.find(span1, start_index)
            if pos1 < 0:
                complete = False
                break
            index1 = pos1 + len(span1)
            if span2 == '':
                # Nothing follows the slot: it runs to the end, excluding a final period.
                index2 = len(txt) - 1 if txt.endswith('.') else len(txt)
            else:
                # Search only after span1 so span2 cannot match inside or before it.
                index2 = lowered.find(span2, index1)
                if index2 < 0:
                    complete = False
                    break
            words_i.append(txt[index1:index2].strip())
            start_index = index2
        if not complete:
            continue
        ret.append([words_i, probs[i].item()])
    return sorted(ret, key=lambda x: x[1], reverse=True)[:k]
def extract_words_for_tA(self, tA, k=6):
word_mask_str = ' '.join([self.tokenizer.mask_token] * self.word_length)
tA = tA.replace('<mask>', word_mask_str)
mask_count = tA.count(self.tokenizer.mask_token)
mask_probs = self.explore_mask(tA, k*20, [], 1.0, mask_count, [])
ret = []
visited_mask_txt = {}
for mask, prob, probs in mask_probs:
mask_txt = ' '.join(mask).lower()
if mask_txt in visited_mask_txt:
continue
visited_mask_txt[mask_txt] = 1
words = []
probs_words = []
for i in range(0,mask_count, self.word_length):
words.append(' '.join(mask[i: i + self.word_length]))
prob_word = 1.0
for j in range(i, i + self.word_length):
prob_word *= probs[j]
probs_words.append(prob_word)
ret.append([words, prob, probs_words])
return sorted(ret, key=lambda x: x[1], reverse=True)[:k]
def extract_templateBs_batch(self, words_prob, tA, k, print_it = False):
    """Generate hypothesis templates for every candidate instance and sum their scores.

    For each ``(words, probA, ...)`` entry, ``construct_template`` builds one
    or more input texts. The hypothesis generator produces ``k`` sequences per
    input; each sequence is lower-cased, passed through
    ``post_process_template``, and has the first occurrence of every instance
    word replaced by ``<entJ>``. Sequences of three or fewer space-separated
    pieces are discarded.

    Args:
        words_prob: iterable of entries whose first two items are the list of
            instance words and the probability of that instance.
        tA: premise template forwarded to ``construct_template``.
        k: beam count, beam-group count and sequences returned per input.
        print_it: unused; kept for interface compatibility.

    Returns:
        Dict mapping each resulting template to the sum, over all sequences
        that produced it, of ``sequence probability * instance probability``.
        A sequence missing any instance word contributes 0.0.
    """
    # Sort by the token length of the first word; a batch is closed whenever
    # that length changes, so each batch holds inputs of one length class.
    words_prob_sorted = []
    for (words, probA, *_) in words_prob:
        tokenized_word = self.tokenizer(words[0])
        words_prob_sorted.append([words, probA, len(tokenized_word['input_ids'])])
    words_prob_sorted.sort(key=lambda x: x[2])

    batch_size = 8
    num_beams = k
    templates = []
    # Parallel to ``templates``: the words and probability each input came
    # from, so every output is weighted by its OWN instance probability
    # rather than by whichever instance happened to close the batch.
    sources = []
    ret = {}
    last = len(words_prob_sorted) - 1
    for enum, (words, probA, length) in enumerate(words_prob_sorted):
        template = construct_template(words, tA, self.if_then)
        templates.extend(template)
        sources.extend((words, probA) for _ in template)

        # ``>=`` rather than ``==``: one instance may add several inputs and
        # step over the exact batch size.
        flush = (
            len(templates) >= batch_size
            or enum == last
            or words_prob_sorted[enum + 1][2] != length
        )
        if not flush or not templates:
            continue

        generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device).to(torch.int64)
        generated_ret = self.orion_hypothesis_generator.generate(
            generated_ids,
            num_beams=num_beams,
            num_beam_groups=num_beams,
            max_length=28,
            num_return_sequences=num_beams,
            min_length=3,
            diversity_penalty=1.0,
            early_stopping=True,
            bad_words_ids=self.bad_words_ids,
            output_scores=True,
            return_dict_in_generate=True,
            decoder_ori_input_ids=generated_ids,
            top_p=0.95,
        )
        summary_ids = generated_ret['sequences'].reshape((len(templates), num_beams, -1))
        # Normalise the scores among the sequences of each input.
        probs = F.softmax(generated_ret['sequences_scores'].reshape((len(templates), num_beams)), dim=1).to(torch.float32)

        for ii in range(summary_ids.size(0)):
            words_ii, prob_ii = sources[ii]
            targets = [word.lower() for word in words_ii]
            for i, g in enumerate(summary_ids[ii]):
                txt = self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                txt = post_process_template(txt.lower())
                prob = probs[ii][i].item() * prob_ii
                for j, word in enumerate(targets):
                    if word not in txt:
                        # The hypothesis lost one of the instance words.
                        prob = 0.0
                    else:
                        txt = txt.replace(word, '<ent{}>'.format(j), 1)
                if txt.count(' ') + 1 <= 3:
                    continue
                ret[txt] = ret.get(txt, 0.0) + prob

        templates.clear()
        sources.clear()
    return ret
def generate_rule(self, tA, k=10, print_it = False):
    """Run the full pipeline: premise template -> instances -> ranked hypothesis templates.

    Args:
        tA: premise template; normalised with ``formalize_tA`` first.
        k: number of instances kept after ``filter_words`` and number of
            templates returned.
        print_it: forwarded to the helper methods, which do not use it.

    Returns:
        Up to ``k`` entries ``[template, score]`` sorted by descending score.
        When ``self.if_then`` is set, each template is reduced to the text
        after its last standalone word "then", or, if it has none, has every
        standalone word "if" removed.
    """
    tA = formalize_tA(tA)
    if 'bart' in str(self.orion_instance_generator.__class__).lower():
        words_prob = self.extract_words_for_tA_bart(tA, k, print_it=print_it)
    else:
        words_prob = self.extract_words_for_tA(tA, k)
    words_prob = filter_words(words_prob)[:k]

    tB_prob = self.extract_templateBs_batch(words_prob, tA, k, print_it=print_it)
    ret = sorted(
        ([template, score] for template, score in tB_prob.items()),
        key=lambda x: x[1],
        reverse=True,
    )[:k]

    if self.if_then:
        for entry in ret:
            sentence = entry[0]
            # Match whole words only, so that e.g. "authentic" or "different"
            # are not cut apart by a bare substring match on "then" / "if".
            if re.search(r'\bthen\b', sentence):
                sentence = re.split(r'\bthen\b', sentence)[-1]
            else:
                sentence = re.sub(r'\bif\b', '', sentence)
            entry[0] = sentence
    return ret