Upload 2 files

- expbert.py  +11 -9
- inductor.py +13 -11
expbert.py
CHANGED
@@ -42,15 +42,17 @@ GENERATED_EXP = {
     "disease": "data/exp/orion_disease_explanation.txt",
 }
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
 
 
 def print_config(config):
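Note: this hunk is the core of the fix. A module-level `device` that falls
back to CPU makes every later `.to(device)` a no-op on CPU-only hosts, and
the CUDA seeding is now guarded so the function runs anywhere. A minimal
sketch of the pattern, with hypothetical names:

    import torch

    # Resolve the target device once, at import time; CPU is the fallback.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Everything created later follows it explicitly:
    model = torch.nn.Linear(4, 2).to(device)   # modules are moved in place
    batch = torch.randn(8, 4, device=device)   # tensors can be born on it
    out = model(batch)                         # no device mismatch either way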
@@ -77,12 +79,12 @@ class ExpBERT(nn.Module):
 
     def forward(self, inputs):
         for k, v in inputs["encoding"].items():
-            inputs["encoding"][k] = v.cuda()
+            inputs["encoding"][k] = v.to(device)
         pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
         pooler_output = self.dropout(pooler_output)
         logits = self.linear(pooler_output)
 
-        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).cuda())
+        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).to(device))
         prediction = torch.argmax(logits)
 
         return {
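Note: the per-key loop and the label move are correct as written. Equivalent,
slightly tighter forms, assuming the same `inputs` structure and the
module-level `device`, would be:

    # Move the whole encoding dict in one expression.
    inputs["encoding"] = {k: v.to(device) for k, v in inputs["encoding"].items()}

    # Create the label tensor on the target device directly,
    # instead of building it on CPU and copying it over.
    label = torch.tensor([inputs["label"]], dtype=torch.long, device=device)
    loss = self.criterion(logits, label)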
@@ -187,7 +189,7 @@ class Trainer(object):
 
         self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
         self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
-        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).cuda() if self.args.no_exp else ExpBERT(args, len(exp)).cuda()
+        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device) if self.args.no_exp else ExpBERT(args, len(exp)).to(device)
 
         self.train_loader = DataLoader(
             self.train_dataset,
@@ -220,7 +222,7 @@ class Trainer(object):
             self.model.zero_grad()
             if self.args.no_exp:
                 for k, v in examples.items():
-                    examples[k] = v.cuda()
+                    examples[k] = v.to(device)
                 outputs = self.model(**examples)
                 outputs.loss.backward()
 
@@ -245,7 +247,7 @@ class Trainer(object):
         for step, examples in enumerate(self.test_loader):
             if self.args.no_exp:
                 for k, v in examples.items():
-                    examples[k] = v.cuda()
+                    examples[k] = v.to(device)
                 outputs = self.model(**examples)
                 loss.append(outputs.loss.float())
                 labels.extend(examples["labels"].tolist())
inductor.py
CHANGED
@@ -11,6 +11,8 @@ from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
 from src.utils import (construct_template, filter_words,
                        formalize_tA, post_process_template)
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
 ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'
 
@@ -41,11 +43,11 @@ class BartInductor(object):
         self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
 
         if group_beam:
-            self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+            self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
         else:
-            self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+            self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
 
-        self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).cuda().eval().half()
+        self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval().half()
 
         self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
         self.word_length = 2
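Note: these lines still call .half() unconditionally, so on a CPU-only Space
the models run in fp16 on CPU, which PyTorch executes slowly and which some
CPU ops do not support at all. A hedged variant that only downcasts when a
GPU is actually present:

    model = BartForConditionalGeneration.from_pretrained(path).to(device).eval()
    if device.type == "cuda":
        # fp16 halves memory and speeds up inference, but only on GPU.
        model = model.half()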
@@ -69,7 +71,7 @@ class BartInductor(object):
         self.bad_words_ids = [self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']]
         stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
         stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
-        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).cuda()
+        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).to(device)
         stop_weight[0, stop_index] -= 100
         self.stop_weight = stop_weight[0, :]
 
@@ -99,9 +101,9 @@ class BartInductor(object):
         if required_token <= self.word_length:
             k = min(k, 2)
         ret = []
-        generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')  # ["input_ids"].cuda()
+        generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')  # ["input_ids"].to(device)
         for key in generated_ids.keys():
-            generated_ids[key] = generated_ids[key].cuda()
+            generated_ids[key] = generated_ids[key].to(device)
         mask_index = torch.where(generated_ids["input_ids"][0] == self.tokenizer.mask_token_id)
         generated_ret = self.orion_instance_generator(**generated_ids)
         #logits = generated_ret.logits
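Note: when the tokenizer is called with return_tensors='pt' it returns a
transformers BatchEncoding, whose own .to() moves every tensor it holds, so
the per-key loop can be collapsed into the tokenizer call:

    generated_ids = self.tokenizer(tA, max_length=128, padding='longest',
                                   return_tensors='pt').to(device)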
@@ -126,7 +128,7 @@ class BartInductor(object):
 
     def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
         spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
-        generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].cuda()
+        generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
         generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
                                                                #num_beam_groups=max(120, k),
                                                                max_length=generated_ids.size(1) + 15,
@@ -222,7 +224,7 @@ class BartInductor(object):
                 index_words[len(index_words)] = '\t'.join(words)
                 # index_words[len(templates)-1] = '\t'.join(words)
             if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
-                generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].cuda()
+                generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device)
                 generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
                                                                          num_beam_groups=num_beams,
                                                                          max_length=28, #template_length+5,
@@ -304,7 +306,7 @@ class BartInductor(object):
 
 class CometInductor(object):
     def __init__(self):
-        self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").cuda().eval()  # .half()
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval()  # .half()
         self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
         self.task = "summarization"
         self.use_task_specific_params()
@@ -387,8 +389,8 @@ class CometInductor(object):
         input_ids, attention_mask = self.trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)
 
         summaries = self.model.generate(
-            input_ids=input_ids.cuda(),
-            attention_mask=attention_mask.cuda(),
+            input_ids=input_ids.to(device),
+            attention_mask=attention_mask.to(device),
             decoder_start_token_id=self.decoder_start_token_id,
             num_beams=num_generate,
             num_return_sequences=num_generate,
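Note: moving the two tensors inline inside the call is fine; if the batch
were reused, moving it once right after trim_batch would avoid repeating the
transfer. A functionally identical sketch:

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    summaries = self.model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_start_token_id=self.decoder_start_token_id,
        num_beams=num_generate,
        num_return_sequences=num_generate,
    )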