import os
import time
from datetime import timedelta

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.config import Config
from mmengine.utils import ProgressBar
from transformers import AutoConfig, AutoModel


def _decode_flat_index(index, num_objects):
    """Split a flat index into a ``(relation, num_objects, num_objects)`` tensor.

    Returns a ``(relation, subject, object)`` tuple of Python ints.
    """
    relation, pair = divmod(int(index), num_objects**2)
    subject, obj = divmod(pair, num_objects)
    return relation, subject, obj


class RamDataset(torch.utils.data.Dataset):
    """Per-image object features plus relation annotations, filtered by split.

    The ``.npz`` file is read through its ``arr_0`` entry; each item is
    accessed with the keys ``"feat"``, ``"relations"`` and ``"is_train"``
    (schema inferred from the accesses below -- confirm against the producer).
    """

    def __init__(self, data_path, is_train=True, num_relation_classes=56):
        super().__init__()
        self.num_relation_classes = num_relation_classes
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # load trusted files. The context manager closes the underlying zip.
        with np.load(data_path, allow_pickle=True) as data:
            self.samples = data["arr_0"]
        # Indices into self.samples of the samples belonging to this split.
        self.sample_idx_list = [
            idx
            for idx in range(self.samples.size)
            if self.samples[idx]["is_train"] == is_train
        ]

    def __getitem__(self, idx):
        """Return ``(embedding, rel_target, gt_rels)`` for the idx-th split sample."""
        sample = self.samples[self.sample_idx_list[idx]]
        object_num = sample["feat"].shape[0]
        embedding = torch.from_numpy(sample["feat"])
        gt_rels = sample["relations"]
        rel_target = self._get_target(object_num, gt_rels)
        return embedding, rel_target, gt_rels

    def __len__(self):
        return len(self.sample_idx_list)

    def _get_target(self, object_num, gt_rels):
        """Build a dense ``(num_relation_classes, object_num, object_num)`` 0/1 target.

        ``gt_rels`` is iterated as ``(subject, object, relation_class)`` triplets.
        """
        rel_target = torch.zeros([self.num_relation_classes, object_num, object_num])
        for ii, jj, cls_relationship in gt_rels:
            rel_target[cls_relationship, ii, jj] = 1
        return rel_target


class RamModel(nn.Module):
    """Transformer over object embeddings that scores every (relation, subject, object)."""

    def __init__(
        self,
        pretrained_model_name_or_path,
        load_pretrained_weights=True,
        num_transformer_layer=2,
        input_feature_size=256,
        output_feature_size=768,
        cls_feature_size=512,
        num_relation_classes=56,
        pred_type="attention",
        loss_type="bce",
    ):
        super().__init__()
        # 0. config
        self.cls_feature_size = cls_feature_size
        self.num_relation_classes = num_relation_classes
        self.pred_type = pred_type
        self.loss_type = loss_type

        # 1. fc input and output
        self.fc_input = nn.Sequential(
            nn.Linear(input_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )
        self.fc_output = nn.Sequential(
            nn.Linear(output_feature_size, output_feature_size),
            nn.LayerNorm(output_feature_size),
        )

        # 2. transformer model
        if load_pretrained_weights:
            self.model = AutoModel.from_pretrained(pretrained_model_name_or_path)
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
            self.model = AutoModel.from_config(config)
        # An int keeps only the first N encoder layers; "all" keeps every layer.
        if num_transformer_layer != "all" and isinstance(num_transformer_layer, int):
            self.model.encoder.layer = self.model.encoder.layer[:num_transformer_layer]

        # 3. predict head: one subject and one object projection per relation class.
        self.cls_sub = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)
        self.cls_obj = nn.Linear(output_feature_size, cls_feature_size * num_relation_classes)

        # 4. loss
        if self.loss_type == "bce":
            self.bce_loss = nn.BCEWithLogitsLoss()
        elif self.loss_type == "multi_label_ce":
            print("Use Multi Label Cross Entropy Loss.")

    def forward(self, embeds, attention_mask=None):
        """Score all relation triplets.

        embeds: (batch_size, token_num, feature_size)
        attention_mask: (batch_size, token_num)

        Returns logits of shape (batch_size, num_relation_classes, token_num, token_num).

        Raises ValueError for an unknown ``pred_type``.
        """
        # 1. fc input
        embeds = self.fc_input(embeds)

        # 2. transformer model
        # Every token gets the same position id (1), presumably so the encoder
        # treats the objects as an unordered set.
        position_ids = torch.ones([1, embeds.shape[1]], dtype=torch.long, device=embeds.device)
        outputs = self.model(inputs_embeds=embeds, attention_mask=attention_mask, position_ids=position_ids)
        embeds = outputs["last_hidden_state"]

        # 3. fc output
        embeds = self.fc_output(embeds)

        # 4. predict head -> (batch, relation, token, cls_feature)
        batch_size, token_num, _ = embeds.shape
        head_shape = [batch_size, token_num, self.num_relation_classes, self.cls_feature_size]
        sub_embeds = self.cls_sub(embeds).reshape(head_shape).permute([0, 2, 1, 3])
        obj_embeds = self.cls_obj(embeds).reshape(head_shape).permute([0, 2, 1, 3])
        if self.pred_type == "attention":
            cls_pred = sub_embeds @ torch.transpose(obj_embeds, 2, 3) / self.cls_feature_size**0.5  # noqa
        elif self.pred_type == "einsum":
            cls_pred = torch.einsum("nrsc,nroc->nrso", sub_embeds, obj_embeds)
        else:
            raise ValueError(f"Unknown pred_type: {self.pred_type!r}")
        return cls_pred

    def loss(self, pred, target, attention_mask):
        """Compute the training loss and a running Recall@20.

        pred, target: (batch_size, relation_num, token_num, token_num)
        attention_mask: (batch_size, token_num); its per-row sum is used as the
            number of valid leading tokens.

        Returns a dict with ``"loss"`` (differentiable) and ``"recall@20"``.
        Raises ValueError for an unknown ``loss_type``.
        """
        loss_dict = dict()
        batch_size, relation_num, _, _ = pred.shape
        # Push logits of pairs involving padded tokens to a large negative value.
        mask = torch.zeros_like(pred)
        for idx in range(batch_size):
            n = int(torch.sum(attention_mask[idx]).item())
            mask[idx, :, :n, :n] = 1
        pred = pred * mask - 9999 * (1 - mask)

        if self.loss_type == "bce":
            loss = self.bce_loss(pred, target)
        elif self.loss_type == "multi_label_ce":
            # Move the relation axis first and flatten (batch, sub, obj), so each
            # relation class is one multi-label problem. Reshaping the original
            # (B, R, N, N) layout directly would mix batch and relation axes.
            input_tensor = torch.permute(pred, (1, 0, 2, 3)).reshape([relation_num, -1])
            target_tensor = torch.permute(target, (1, 0, 2, 3)).reshape([relation_num, -1])
            loss = self.multilabel_categorical_crossentropy(target_tensor, input_tensor)
            # Re-weight each class by its loss relative to the largest one.
            # NOTE(review): `weight` is not detached, so gradients also flow
            # through it -- confirm this is intended.
            weight = loss / loss.max()
            loss = loss * weight
            loss = loss.mean()
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type!r}")
        loss_dict["loss"] = loss

        # running metric (no autograd graph needed)
        with torch.no_grad():
            loss_dict["recall@20"] = get_recall_N(pred.detach(), target, object_num=20)
        return loss_dict

    def multilabel_categorical_crossentropy(self, y_true, y_pred):
        """Multi-label categorical cross-entropy over the last axis.

        https://kexue.fm/archives/7359

        y_true: 0/1 targets; y_pred: logits of the same shape.
        Returns the per-row loss (last axis reduced).
        """
        y_pred = (1 - 2 * y_true) * y_pred
        y_pred_neg = y_pred - y_true * 9999
        y_pred_pos = y_pred - (1 - y_true) * 9999
        # The appended zero logit acts as the threshold term of the loss.
        zeros = torch.zeros_like(y_pred[..., :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return neg_loss + pos_loss


def get_recall_N(y_pred, y_true, object_num=20):
    """Mean per-sample Recall@N, in percent.

    y_pred: [batch_size, 56, object_num, object_num]
    y_true: [batch_size, 56, object_num, object_num]

    ``object_num`` is N: the number of top-scoring triplets considered. The
    ground truth for a sample is the non-zero entries among the top-N of
    ``y_true``. The result is a tensor on ``y_pred``'s device.
    """
    device = y_pred.device
    recall_list = []
    for idx in range(len(y_true)):
        sample_true = y_true[idx].reshape(-1)
        sample_pred = y_pred[idx].reshape(-1)

        # Ground-truth triplets (at most N of them).
        k_true = min(object_num, sample_true.numel())
        _, topk_indices = torch.topk(sample_true, k=k_true)
        gt_triplets = set()
        for index in topk_indices.tolist():
            if sample_true[index].item() == 0:
                continue
            rel, sub, obj = _decode_flat_index(index, y_true.shape[2])
            gt_triplets.add((sub, obj, rel))

        # Top-N predicted triplets.
        k_pred = min(object_num, sample_pred.numel())
        _, topk_indices = torch.topk(sample_pred, k=k_pred)
        pred_triplets = set()
        for index in topk_indices.tolist():
            rel, sub, obj = _decode_flat_index(index, y_pred.shape[2])
            pred_triplets.add((sub, obj, rel))

        recall = len(pred_triplets & gt_triplets) / (len(gt_triplets) + 1e-8)
        recall_list.append(recall)
    recall = torch.tensor(recall_list).to(device).mean() * 100
    return recall


class RamTrainer(object):
    """Builds data, model, optimizer and LR schedule from an mmengine Config, then trains."""

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()
        self._build_optimizer()
        self._build_lr_scheduler()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.config.dataloader.batch_size,
            shuffle=bool(self.config.dataset.is_train),
        )

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
        self.model.train()

    def _build_optimizer(self):
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            eps=self.config.optim.eps,
            betas=self.config.optim.betas,
        )

    def _build_lr_scheduler(self):
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=self.config.optim.lr_scheduler.step,
            gamma=self.config.optim.lr_scheduler.gamma,
        )

    def train(self):
        """Train for ``config.num_epoch`` epochs, saving a checkpoint after each.

        Returns the path of the last saved checkpoint, or None if no epoch ran.
        """
        t_start = time.time()
        running_avg_loss = 0
        steps_per_epoch = len(self.dataloader)
        total_steps = self.config.num_epoch * steps_per_epoch
        save_path = None
        for epoch_idx in range(self.config.num_epoch):
            for batch_idx, batch_data in enumerate(self.dataloader):
                batch_embeds = batch_data[0].to(torch.float32).to(self.device)
                batch_target = batch_data[1].to(torch.float32).to(self.device)
                attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
                batch_pred = self.model(batch_embeds, attention_mask)
                loss_dict = self.model.loss(batch_pred, batch_target, attention_mask)
                loss = loss_dict["loss"]
                recall_20 = loss_dict["recall@20"]
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.optim.max_norm, self.config.optim.norm_type)
                self.optimizer.step()
                running_avg_loss += loss.item()

                if batch_idx % 100 == 0:
                    t_current = time.time()
                    # Global step count over all epochs (1-based).
                    num_finished_step = epoch_idx * steps_per_epoch + batch_idx + 1
                    num_to_do_step = total_steps - num_finished_step
                    avg_speed = num_finished_step / max(t_current - t_start, 1e-8)
                    eta = num_to_do_step / avg_speed
                    print(
                        "ETA={:0>8}, Epoch={}, Batch={}/{}, LR={}, Loss={:.4f}, RunningAvgLoss={:.4f}, Recall@20={:.2f}%".format(
                            str(timedelta(seconds=int(eta))),
                            epoch_idx + 1,
                            batch_idx,
                            steps_per_epoch,
                            self.lr_scheduler.get_last_lr()[0],
                            loss.item(),
                            running_avg_loss / num_finished_step,
                            recall_20.item(),
                        )
                    )
            self.lr_scheduler.step()
            os.makedirs(self.config.output_dir, exist_ok=True)
            save_path = os.path.join(self.config.output_dir, "epoch_{}.pth".format(epoch_idx + 1))
            print("Save epoch={} checkpoint to {}".format(epoch_idx + 1, save_path))
            torch.save(self.model.state_dict(), save_path)
        return save_path


class RamPredictor(object):
    """Loads a trained RamModel and evaluates Recall@{20,50,100} on a split."""

    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._build_dataset()
        self._build_dataloader()
        self._build_model()

    def _build_dataset(self):
        self.dataset = RamDataset(**self.config.dataset)

    def _build_dataloader(self):
        self.dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=self.config.dataloader.batch_size, shuffle=False)

    def _build_model(self):
        self.model = RamModel(**self.config.model).to(self.device)
        if self.config.load_from is not None:
            # map_location lets a GPU-saved checkpoint load on a CPU-only host.
            self.model.load_state_dict(torch.load(self.config.load_from, map_location=self.device))
        self.model.eval()

    @torch.no_grad()
    def predict(self, batch_embeds, pred_keep_num=100):
        """
        Parameters
        ----------
        batch_embeds: (batch_size=1, token_num, feature_size)
        pred_keep_num: int
            Maximum number of triplets returned (capped by the number of scores).

        Returns
        -------
        batch_pred: (batch_size, relation_num, object_num, object_num)
            Sigmoid scores; self-pairs (i, i) are suppressed.
        pred_rels: [[sub_id, obj_id, rel_id], ...]
            Ranked over the whole flattened batch, so indices are only
            meaningful for batch_size=1.
        """
        if not isinstance(batch_embeds, torch.Tensor):
            batch_embeds = torch.asarray(batch_embeds)
        batch_embeds = batch_embeds.to(torch.float32).to(self.device)
        attention_mask = batch_embeds.new_ones((batch_embeds.shape[0], batch_embeds.shape[1]))
        batch_pred = self.model(batch_embeds, attention_mask)
        # An object cannot relate to itself: suppress the diagonal before ranking.
        diag = torch.arange(batch_pred.shape[2], device=batch_pred.device)
        batch_pred[:, :, diag, diag] = -9999
        batch_pred = batch_pred.sigmoid()

        flat_pred = batch_pred.reshape(-1)
        _, topk_indices = torch.topk(flat_pred, k=min(pred_keep_num, flat_pred.numel()))
        # subject, object, relation
        pred_rels = []
        for index in topk_indices.tolist():
            relation, subject, obj = _decode_flat_index(index, batch_pred.shape[2])
            pred_rels.append([subject, obj, relation])
        return batch_pred, pred_rels

    def eval(self):
        """Return mean Recall@20/50/100 (percent) averaged over dataloader batches."""
        sum_recall_20 = 0.0
        sum_recall_50 = 0.0
        sum_recall_100 = 0.0
        prog_bar = ProgressBar(len(self.dataloader))
        for batch_data in self.dataloader:
            batch_embeds = batch_data[0]
            batch_target = batch_data[1]
            batch_pred, _ = self.predict(batch_embeds)
            sum_recall_20 += get_recall_N(batch_pred, batch_target, object_num=20).item()
            sum_recall_50 += get_recall_N(batch_pred, batch_target, object_num=50).item()
            sum_recall_100 += get_recall_N(batch_pred, batch_target, object_num=100).item()
            prog_bar.update()
        recall_20 = sum_recall_20 / len(self.dataloader)
        recall_50 = sum_recall_50 / len(self.dataloader)
        recall_100 = sum_recall_100 / len(self.dataloader)
        metric = {
            "recall_20": recall_20,
            "recall_50": recall_50,
            "recall_100": recall_100,
        }
        return metric


if __name__ == "__main__":
    # Config
    config = dict(
        dataset=dict(
            data_path="./data/feat_0420.npz",
            is_train=True,
            num_relation_classes=56,
        ),
        dataloader=dict(
            batch_size=4,
        ),
        model=dict(
            pretrained_model_name_or_path="bert-base-uncased",
            load_pretrained_weights=True,
            num_transformer_layer=2,
            input_feature_size=256,
            output_feature_size=768,
            cls_feature_size=512,
            num_relation_classes=56,
            pred_type="attention",
            loss_type="multi_label_ce",
        ),
        optim=dict(
            lr=1e-4,
            weight_decay=0.05,
            eps=1e-8,
            betas=(0.9, 0.999),
            max_norm=0.01,
            norm_type=2,
            lr_scheduler=dict(
                step=[6, 10],
                gamma=0.1,
            ),
        ),
        num_epoch=12,
        output_dir="./work_dirs",
        load_from=None,
    )

    # Train
    config = Config(config)
    trainer = RamTrainer(config)
    last_model_path = trainer.train()

    # Test/Eval
    config.dataset.is_train = False
    config.load_from = last_model_path
    predictor = RamPredictor(config)
    metric = predictor.eval()
    print(metric)