import os
import argparse
import json
import logging

import torch
import pytorch_lightning as pl
from transformers import BartTokenizer, BartConfig
from transformers import AdamW, get_linear_schedule_with_warmup

from .network import BartGen
from .constrained_gen import BartConstrainedGen

logger = logging.getLogger(__name__)


class GenIEModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.hparams = args

        self.config = BartConfig.from_pretrained('facebook/bart-large')
        self.tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
        # Placeholder tokens for argument and trigger spans used in the templates.
        self.tokenizer.add_tokens([' <arg>', ' <tgr>'])

        if self.hparams.model == 'gen':
            self.model = BartGen(self.config, self.tokenizer)
            self.model.resize_token_embeddings()
        elif self.hparams.model == 'constrained-gen':
            self.model = BartConstrainedGen(self.config, self.tokenizer)
            self.model.resize_token_embeddings()
        else:
            raise NotImplementedError

    def forward(self, inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        '''
        processed_ex = {
            'doc_key': ex['doc_key'],
            'input_token_ids': input_tokens['input_ids'],
            'input_attn_mask': input_tokens['attention_mask'],
            'tgt_token_ids': tgt_tokens['input_ids'],
            'tgt_attn_mask': tgt_tokens['attention_mask'],
        }
        '''
        inputs = {
            "input_ids": batch["input_token_ids"],
            "attention_mask": batch["input_attn_mask"],
            "decoder_input_ids": batch['tgt_token_ids'],
            "decoder_attention_mask": batch["tgt_attn_mask"],
            "task": 0,
        }

        outputs = self.model(**inputs)
        loss = outputs[0]
        loss = torch.mean(loss)

        log = {
            'train/loss': loss,
        }
        return {
            'loss': loss,
            'log': log,
        }

    def validation_step(self, batch, batch_idx):
        inputs = {
            "input_ids": batch["input_token_ids"],
            "attention_mask": batch["input_attn_mask"],
            "decoder_input_ids": batch['tgt_token_ids'],
            "decoder_attention_mask": batch["tgt_attn_mask"],
            "task": 0,
        }

        outputs = self.model(**inputs)
        loss = outputs[0]
        loss = torch.mean(loss)
        return loss

    def validation_epoch_end(self, outputs):
        avg_loss = torch.mean(torch.stack(outputs))
        log = {
            'val/loss': avg_loss,
        }
        return {
            'loss': avg_loss,
            'log': log,
        }

    def test_step(self, batch, batch_idx):
        if self.hparams.sample_gen:
            sample_output = self.model.generate(batch['input_token_ids'], do_sample=True,
                                                top_k=20, top_p=0.95, max_length=30,
                                                num_return_sequences=1, num_beams=1)
        else:
            sample_output = self.model.generate(batch['input_token_ids'], do_sample=False,
                                                max_length=30,
                                                num_return_sequences=1, num_beams=1)

        sample_output = sample_output.reshape(batch['input_token_ids'].size(0), 1, -1)
        doc_key = batch['doc_key']  # list
        tgt_token_ids = batch['tgt_token_ids']
        return (doc_key, sample_output, tgt_token_ids)

    def test_epoch_end(self, outputs):
        # Write predictions to disk; F1 is evaluated from this file.
        with open('checkpoints/{}/predictions.jsonl'.format(self.hparams.ckpt_name), 'w') as writer:
            for tup in outputs:
                for idx in range(len(tup[0])):
                    pred = {
                        'doc_key': tup[0][idx],
                        'predicted': self.tokenizer.decode(tup[1][idx].squeeze(0), skip_special_tokens=True),
                        'gold': self.tokenizer.decode(tup[2][idx].squeeze(0), skip_special_tokens=True),
                    }
                    writer.write(json.dumps(pred) + '\n')
        return {}

    def pred(self, batch):
        if self.hparams.sample_gen:
            sample_output = self.model.generate(batch, do_sample=True,
                                                top_k=20, top_p=0.95, max_length=30,
                                                num_return_sequences=1, num_beams=1)
        else:
            sample_output = self.model.generate(batch, do_sample=False, max_length=30,
                                                num_return_sequences=1, num_beams=1)

        sample_output = sample_output.reshape(batch.size(0), 1, -1)
        return [self.tokenizer.decode(sample.squeeze(0), skip_special_tokens=True)
                for sample in sample_output]
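    # Note on `pred`: it expects a LongTensor of tokenized input ids of shape
    # (batch_size, seq_len) -- e.g. the 'input_ids' produced by `self.tokenizer` --
    # and returns one decoded string per input sequence. Real inputs come from the
    # project's data pipeline, which is not shown in this file.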
    def configure_optimizers(self):
        self.train_len = len(self.train_dataloader())
        if self.hparams.max_steps > 0:
            t_total = self.hparams.max_steps
            self.hparams.num_train_epochs = (
                self.hparams.max_steps // self.train_len // self.hparams.accumulate_grad_batches + 1
            )
        else:
            t_total = self.train_len // self.hparams.accumulate_grad_batches * self.hparams.num_train_epochs
        logger.info('{} training steps in total.'.format(t_total))

        # Prepare optimizer and schedule (linear warmup and decay);
        # biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate,
                          eps=self.hparams.adam_epsilon)

        # Lightning steps schedulers once per epoch by default; interval='step'
        # makes the linear schedule advance after every optimizer step.
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=self.hparams.warmup_steps,
                                                    num_training_steps=t_total)
        scheduler_dict = {
            'scheduler': scheduler,
            'interval': 'step',
            'name': 'linear-schedule',
        }

        return [optimizer], [scheduler_dict]
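

# A minimal smoke-test sketch. The hyperparameter values below are assumptions
# chosen only to exercise the code paths above; real values come from the training
# script's argparse flags. Because this file uses relative imports, run it with
# `python -m <package>.model`, where <package> is this module's package, and with
# the pytorch-lightning version this project was written against (which still
# allows `self.hparams = args`).
if __name__ == "__main__":
    hparams = argparse.Namespace(
        model='constrained-gen',     # or 'gen'
        ckpt_name='debug',           # predictions go to checkpoints/<ckpt_name>/predictions.jsonl
        sample_gen=False,            # greedy decoding in test_step / pred
        learning_rate=5e-5,          # optimizer settings (assumed defaults)
        weight_decay=0.0,
        adam_epsilon=1e-8,
        warmup_steps=0,
        max_steps=-1,
        num_train_epochs=1,
        accumulate_grad_batches=1,
    )
    model = GenIEModel(hparams)

    # Exercise the generation path only; training and validation need the data module.
    enc = model.tokenizer(["Police arrested two suspects in the city on Friday."],
                          return_tensors='pt')
    print(model.pred(enc['input_ids']))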