andreslu committed
Commit 4f254e5
1 Parent(s): cfa4d37

Upload 2 files

Files changed (2)
  1. expbert.py +11 -9
  2. inductor.py +13 -11
expbert.py CHANGED
@@ -42,15 +42,17 @@ GENERATED_EXP = {
     "disease": "data/exp/orion_disease_explanation.txt",
 }
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def set_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-    torch.backends.cudnn.deterministic = True
-    torch.backends.cudnn.benchmark = False
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
 
 
 def print_config(config):
@@ -77,12 +79,12 @@ class ExpBERT(nn.Module):
 
     def forward(self, inputs):
        for k, v in inputs["encoding"].items():
-            inputs["encoding"][k] = v.cuda()
+            inputs["encoding"][k] = v.to(device)
        pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(1, self.exp_num * self.config.hidden_size)
        pooler_output = self.dropout(pooler_output)
        logits = self.linear(pooler_output)
 
-        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).cuda())
+        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).to(device))
        prediction = torch.argmax(logits)
 
        return {
@@ -187,7 +189,7 @@ class Trainer(object):
 
        self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
        self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)
-        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).cuda() if self.args.no_exp else ExpBERT(args, len(exp)).cuda()
+        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device) if self.args.no_exp else ExpBERT(args, len(exp)).to(device)
 
        self.train_loader = DataLoader(
            self.train_dataset,
@@ -220,7 +222,7 @@ class Trainer(object):
                self.model.zero_grad()
                if self.args.no_exp:
                    for k, v in examples.items():
-                        examples[k] = v.cuda()
+                        examples[k] = v.to(device)
                    outputs = self.model(**examples)
                    outputs.loss.backward()
 
@@ -245,7 +247,7 @@ class Trainer(object):
            for step, examples in enumerate(self.test_loader):
                if self.args.no_exp:
                    for k, v in examples.items():
-                        examples[k] = v.cuda()
+                        examples[k] = v.to(device)
                    outputs = self.model(**examples)
                    loss.append(outputs.loss.float())
                    labels.extend(examples["labels"].tolist())
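
The change applied throughout this file is the standard device-agnostic PyTorch idiom: resolve the target device once at module level, then move the model and every tensor with .to(device) instead of hard-coding .cuda(), so the same script runs on CPU-only machines. The CUDA seeding and cuDNN calls are guarded the same way; note that recent PyTorch documents torch.cuda.manual_seed as silently ignored when CUDA is unavailable, so the guard is defensive rather than strictly required. A minimal self-contained sketch of the idiom (the linear model and tensor shapes are illustrative, not from this repo):

    import torch
    import torch.nn as nn

    # Resolve the device once; all placement below is written against it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(768, 2).to(device)       # parameters follow the device
    inputs = torch.randn(1, 768).to(device)    # each batch is moved the same way
    labels = torch.LongTensor([1]).to(device)  # targets must share the device

    logits = model(inputs)
    loss = nn.CrossEntropyLoss()(logits, labels)  # no CPU/GPU mismatch possible

For the tokenizer encodings moved key by key in forward, transformers' BatchEncoding also exposes its own .to(device), so the loop could be replaced by inputs["encoding"] = inputs["encoding"].to(device); the per-key version in the diff is equivalent.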
inductor.py CHANGED
@@ -11,6 +11,8 @@ from src.bart_with_group_beam import BartForConditionalGeneration_GroupBeam
 from src.utils import (construct_template, filter_words,
                        formalize_tA, post_process_template)
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 ORION_HYPO_GENERATOR = 'chenxran/orion-hypothesis-generator'
 ORION_INS_GENERATOR = 'chenxran/orion-instance-generator'
 
@@ -41,11 +43,11 @@ class BartInductor(object):
        self.orion_hypothesis_generator_path = 'facebook/bart-large' if not continue_pretrain_hypo_generator else ORION_HYPO_GENERATOR
 
        if group_beam:
-            self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+            self.orion_hypothesis_generator = BartForConditionalGeneration_GroupBeam.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
        else:
-            self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).cuda().eval().half()
+            self.orion_hypothesis_generator = BartForConditionalGeneration.from_pretrained(self.orion_hypothesis_generator_path).to(device).eval().half()
 
-        self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).cuda().eval().half()
+        self.orion_instance_generator = BartForConditionalGeneration.from_pretrained(self.orion_instance_generator_path).to(device).eval().half()
 
        self.tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
        self.word_length = 2
@@ -69,7 +71,7 @@ class BartInductor(object):
        self.bad_words_ids = [self.tokenizer.encode(bad_word)[1:-1] for bad_word in ['also', ' also']]
        stop_index = self.tokenizer(self.stop_sub_list, max_length=4, padding=True)
        stop_index = torch.tensor(stop_index['input_ids'])[:, 1]
-        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).cuda()
+        stop_weight = torch.zeros(1, self.tokenizer.vocab_size).to(device)
        stop_weight[0, stop_index] -= 100
        self.stop_weight = stop_weight[0, :]
 
@@ -99,9 +101,9 @@ class BartInductor(object):
        if required_token <= self.word_length:
            k = min(k, 2)
        ret = []
-        generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')  # ["input_ids"].cuda()
+        generated_ids = self.tokenizer(tA, max_length=128, padding='longest', return_tensors='pt')  # ["input_ids"].to(device)
        for key in generated_ids.keys():
-            generated_ids[key] = generated_ids[key].cuda()
+            generated_ids[key] = generated_ids[key].to(device)
        mask_index = torch.where(generated_ids["input_ids"][0] == self.tokenizer.mask_token_id)
        generated_ret = self.orion_instance_generator(**generated_ids)
        #logits = generated_ret.logits
@@ -126,7 +128,7 @@ class BartInductor(object):
 
    def extract_words_for_tA_bart(self, tA, k=6, print_it = False):
        spans = [t.lower().strip() for t in tA[:-1].split('<mask>')]
-        generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].cuda()
+        generated_ids = self.tokenizer([tA], padding='longest', return_tensors='pt')['input_ids'].to(device)
        generated_ret = self.orion_instance_generator.generate(generated_ids, num_beams=max(120, k),
                                                               #num_beam_groups=max(120, k),
                                                               max_length=generated_ids.size(1) + 15,
@@ -222,7 +224,7 @@ class BartInductor(object):
                index_words[len(index_words)] = '\t'.join(words)
                # index_words[len(templates)-1] = '\t'.join(words)
                if (len(templates) == batch_size) or enum==len(words_prob_sorted)-1 or (words_prob_sorted[enum+1][2]!=words_prob_sorted[enum][2]):
-                    generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].cuda()
+                    generated_ids = self.tokenizer(templates, padding="longest", return_tensors='pt')['input_ids'].to(device)
                    generated_ret = self.orion_hypothesis_generator.generate(generated_ids, num_beams=num_beams,
                                                                             num_beam_groups=num_beams,
                                                                             max_length=28, #template_length+5,
@@ -304,7 +306,7 @@ class BartInductor(object):
 
 class CometInductor(object):
    def __init__(self):
-        self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").cuda().eval()  # .half()
+        self.model = AutoModelForSeq2SeqLM.from_pretrained("adamlin/comet-atomic_2020_BART").to(device).eval()  # .half()
        self.tokenizer = AutoTokenizer.from_pretrained("adamlin/comet-atomic_2020_BART")
        self.task = "summarization"
        self.use_task_specific_params()
@@ -387,8 +389,8 @@ class CometInductor(object):
        input_ids, attention_mask = self.trim_batch(**batch, pad_token_id=self.tokenizer.pad_token_id)
 
        summaries = self.model.generate(
-            input_ids=input_ids.cuda(),
-            attention_mask=attention_mask.cuda(),
+            input_ids=input_ids.to(device),
+            attention_mask=attention_mask.to(device),
            decoder_start_token_id=self.decoder_start_token_id,
            num_beams=num_generate,
            num_return_sequences=num_generate,
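
The same idiom applies here, with one caveat the diff leaves open: the BART generators are still cast with .half() unconditionally after .to(device). Half precision is only reliably supported on CUDA; on CPU many operators lack float16 kernels, so a CPU run may still fail or crawl inside generate. A hedged variant, where the load_generator helper is illustrative and not part of this repo, would make the cast conditional as well:

    import torch
    from transformers import BartForConditionalGeneration

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_generator(path):
        # Illustrative: use fp16 on GPU, where it is fast and well supported;
        # keep fp32 on CPU, where half-precision kernels are limited.
        model = BartForConditionalGeneration.from_pretrained(path).to(device).eval()
        if device.type == "cuda":
            model = model.half()
        return model

Usage would be unchanged at the call sites above, e.g. self.orion_instance_generator = load_generator(self.orion_instance_generator_path).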