"""ExpBERT-style relation-extraction trainer.

Trains either a plain sequence classifier (``--no_exp``) or an ExpBERT model
that encodes each sentence paired with every natural-language explanation and
classifies from the concatenated [CLS] vectors. Tasks: ``spouse``/``disease``
with annotated or generated explanation files.
"""

import argparse
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModel,
                          AutoModelForSequenceClassification, AutoTokenizer,
                          BertForSequenceClassification, BertModel)

# exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('logs/', exist_ok=True)

logging.basicConfig(
    filename='logs/expbert-{}.log'.format(str(datetime.now())),
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO)
logger = logging.getLogger(__name__)

# Tab-separated data files: sentence, entity1, entity2, id, label (+1/-1).
TASK2PATH = {
    "disease-train": "data/disease/train.txt",
    "disease-test": "data/disease/test.txt",
    "spouse-train": "data/spouse/train.txt",
    "spouse-test": "data/spouse/test.txt",
}

# Human-written explanations (one per line).
ANNOTATED_EXP = {
    "spouse": "data/exp/expbert_spouse_explanation.txt",
    "disease": "data/exp/expbert_disease_explanation.txt",
}

# Automatically generated explanations (one per line).
GENERATED_EXP = {
    "spouse": "data/exp/orion_spouse_explanation.txt",
    "disease": "data/exp/orion_disease_explanation.txt",
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def set_random_seed(seed):
    """Seed python/numpy/torch (and CUDA) RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic kernels at the cost of cudnn autotuning speed.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def print_config(config):
    """Log every argparse option (sorted, aligned) for the experiment record."""
    config = vars(config)
    logger.info("**************** MODEL CONFIGURATION ****************")
    for key in sorted(config.keys()):
        val = config[key]
        keystr = "{}".format(key) + (" " * (25 - len(key)))
        logger.info("{} --> {}".format(keystr, val))
    logger.info("**************** MODEL CONFIGURATION ****************")


class ExpBERT(nn.Module):
    """Encode (sentence, explanation_i) pairs; classify from concatenated [CLS].

    The linear head sees ``hidden_size * exp_num`` features, so the forward
    pass processes exactly one example (its batch is the exp_num pairs).
    """

    def __init__(self, args, exp_num):
        super(ExpBERT, self).__init__()
        self.args = args
        self.exp_num = exp_num  # number of explanations per example
        self.config = AutoConfig.from_pretrained(args.model)
        self.model = AutoModel.from_pretrained(args.model, config=self.config)
        self.dropout = nn.Dropout(p=0.1)
        # One [CLS] vector per explanation, concatenated, -> binary logits.
        self.linear = nn.Linear(self.config.hidden_size * exp_num, 2)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, inputs):
        """Run one example.

        ``inputs`` is {"encoding": tokenizer output for exp_num sentence/exp
        pairs, "label": int}. Returns {"loss", "prediction"}.
        """
        for k, v in inputs["encoding"].items():
            inputs["encoding"][k] = v.to(device)
        # [CLS] token of each pair, flattened into a single feature vector.
        pooler_output = self.model(**inputs["encoding"]).last_hidden_state[:, 0, :].reshape(
            1, self.exp_num * self.config.hidden_size)
        pooler_output = self.dropout(pooler_output)
        logits = self.linear(pooler_output)
        loss = self.criterion(logits, torch.LongTensor([inputs["label"]]).to(device))
        prediction = torch.argmax(logits)
        return {
            "loss": loss,
            "prediction": prediction,
        }


class REDataset(Dataset):
    """Relation-extraction dataset over tab-separated lines.

    Each line: sentence \\t entity1 \\t entity2 \\t id \\t label, where label
    is +1 (positive) or -1 (negative).
    """

    def __init__(self, path, exp, tokenizer):
        super(REDataset, self).__init__()
        self.tokenizer = tokenizer
        self.exp = exp  # list of explanation template strings
        self.sentences = []
        self.labels = []
        self.entities = []
        with open(path, "r", encoding="utf-8") as file:
            data = file.readlines()
        for example in data:
            sentence, entity1, entity2, ex_id, label = example.strip().split("\t")
            # int(), not eval(): labels come from a data file and must not be
            # executed as code. Parse the label FIRST so the three parallel
            # lists can never desynchronize on a malformed line.
            label = int(label)
            if label == 1:
                self.labels.append(1)
            elif label == -1:
                self.labels.append(0)
            else:
                # Unknown label: skip the whole example rather than keeping a
                # sentence/entity pair with no label.
                logger.warning("Skipping example %s with label %s", ex_id, label)
                continue
            self.sentences.append(sentence)
            self.entities.append([entity1, entity2])
        logger.info("Number of Example in {}: {}".format(path, str(len(self.labels))))

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        return {
            "sentence": self.sentences[index],
            "entity": self.entities[index],
            "label": self.labels[index],
        }

    def collate_fn(self, batch):
        """ExpBERT collation: one tokenized (sentence x explanations) batch per example."""
        outputs = []
        for ex in batch:
            temp = []
            for exp in self.exp:
                if "{e1}" in exp or "{e2}" in exp:
                    # Named placeholders: substitute both entities.
                    exp = exp.replace("{e1}", ex["entity"][0]).replace("{e2}", ex["entity"][1])
                else:
                    # NOTE(review): searching for '' always matches at index 0,
                    # so this prepends each entity to the explanation. The
                    # placeholder token was probably lost from this source —
                    # confirm the intended marker against the original data.
                    for entity in ex["entity"]:
                        index = exp.index('')
                        exp = exp[:index] + entity + exp[index + len(''):]
                temp.append(exp)
            outputs.append(
                {
                    "encoding": self.tokenizer(
                        [ex["sentence"]] * len(temp),  # sentence paired with every explanation
                        temp,
                        add_special_tokens=True,
                        padding="longest",
                        truncation=True,
                        max_length=156,
                        return_tensors="pt",
                        return_attention_mask=True,
                        return_token_type_ids=True,
                    ),
                    "label": ex["label"],
                }
            )
        return outputs

    def collate_fn_(self, batch):
        """Plain-classifier collation (--no_exp): standard batched encoding."""
        texts = []
        labels = []
        for ex in batch:
            texts.append(ex["sentence"])
            labels.append(ex["label"])
        outputs = self.tokenizer(
            texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=156,
            return_tensors="pt",
            return_attention_mask=True,
            return_token_type_ids=True,
        )
        outputs["labels"] = torch.LongTensor(labels)
        return outputs


class Trainer(object):
    """Owns model, data loaders, and the train/eval loops for one run."""

    def __init__(self, args):
        self.args = args
        print_config(args)
        self.tokenizer = AutoTokenizer.from_pretrained(self.args.model)

        TASK2EXP = GENERATED_EXP if args.generated_rules else ANNOTATED_EXP
        with open(TASK2EXP[args.task], "r", encoding="utf-8") as file:
            exp = file.readlines()

        self.train_dataset = REDataset(TASK2PATH['{}-train'.format(args.task)], exp, self.tokenizer)
        self.test_dataset = REDataset(TASK2PATH['{}-test'.format(args.task)], exp, self.tokenizer)

        self.model = AutoModelForSequenceClassification.from_pretrained(args.model).to(device) \
            if self.args.no_exp else ExpBERT(args, len(exp)).to(device)

        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.train_dataset.collate_fn_ if self.args.no_exp else self.train_dataset.collate_fn,
        )
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=args.batch_size,
            shuffle=args.shuffle,
            collate_fn=self.test_dataset.collate_fn_ if self.args.no_exp else self.test_dataset.collate_fn,
        )
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.args.learning_rate)

    def compute_metrics(self, labels, predictions):
        """Return (accuracy, binary F1) for gold ``labels`` vs ``predictions``."""
        accuracy = accuracy_score(y_pred=predictions, y_true=labels)
        f1 = f1_score(y_pred=predictions, y_true=labels)
        return accuracy, f1

    def train(self):
        """Train for ``args.epochs`` epochs, evaluating before and after each."""
        self.test(-1)  # baseline evaluation before any training
        for e in range(self.args.epochs):
            # test() leaves the model in eval mode (dropout off); restore
            # train mode every epoch or we would never train with dropout.
            self.model.train()
            with tqdm(total=len(self.train_loader)) as pbar:
                for step, examples in enumerate(self.train_loader):
                    self.model.zero_grad()
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        outputs.loss.backward()
                    else:
                        # ExpBERT processes one example at a time; average the
                        # per-example losses so the gradient matches the batch mean.
                        for ex in examples:
                            outputs = self.model(ex)
                            (outputs["loss"] / len(examples)).backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                    pbar.update(1)
            self.test(e)

    def test(self, epoch):
        """Evaluate on the test split and log accuracy/F1 (epoch -1 = pre-training)."""
        self.model.eval()
        with torch.no_grad():
            with tqdm(total=len(self.test_loader)) as pbar:
                loss = []
                labels = []
                predictions = []
                for step, examples in enumerate(self.test_loader):
                    if self.args.no_exp:
                        for k, v in examples.items():
                            examples[k] = v.to(device)
                        outputs = self.model(**examples)
                        loss.append(outputs.loss.item())
                        labels.extend(examples["labels"].tolist())
                        predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())
                    else:
                        for ex in examples:
                            labels.append(ex['label'])
                            outputs = self.model(ex)
                            loss.append(outputs["loss"].item())
                            predictions.append(outputs['prediction'].tolist())
                    pbar.update(1)
                # Signature is (labels, predictions); passing them swapped
                # silently exchanged precision and recall in the F1 score.
                accuracy, f1 = self.compute_metrics(labels, predictions)
                logger.info("[EPOCH {}] Accuracy: {} | F1-Score: {}. (Number of Data {})".format(
                    epoch, accuracy, f1, len(predictions)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="spouse")
    parser.add_argument("--model", type=str, default="bert-base-uncased")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    # store_true, not type=bool: bool("False") is True, so the old flags were
    # enabled by ANY value passed on the command line. Defaults are unchanged.
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--no_exp", action="store_true")
    parser.add_argument("--generated_rules", action="store_true")
    args = parser.parse_args()

    # Five seeded runs for mean/variance reporting.
    for seed in range(42, 47):
        set_random_seed(seed)
        trainer = Trainer(args)
        trainer.train()