# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""

import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math
from modules.until_config import PretrainedConfig

logger = logging.getLogger(__name__)


def gelu(x):
    """Implementation of the gelu activation function.

    For information: OpenAI GPT's gelu is slightly different (and gives
    slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    """Return ``x * sigmoid(x)``."""
    return x * torch.sigmoid(x)


def get_dual_matrix(sim_matrix, temp=1):
    """Re-weight a similarity matrix by its softmax along both axes.

    Args:
        sim_matrix: A tensor, or anything ``torch.tensor`` accepts (it is
            converted when it is not already a tensor). The code takes a
            softmax over dim 0 and dim 1, so at least two dimensions are
            required.
        temp: Divisor applied to ``sim_matrix`` before each softmax.
            Defaults to 1, i.e. no scaling.

    Returns:
        ``sim_matrix * softmax(sim_matrix / temp, dim=0)
        * softmax(sim_matrix / temp, dim=1)``, computed element-wise.
    """
    if not torch.is_tensor(sim_matrix):
        sim_matrix = torch.tensor(sim_matrix)
    # Disabled single-axis variant, kept for reference:
    # sim_matrix = sim_matrix * F.softmax(sim_matrix / temp, dim=0) * len(sim_matrix)
    alpha = F.softmax(sim_matrix / temp, dim=0)
    beta = F.softmax(sim_matrix / temp, dim=1)
    return sim_matrix * alpha * beta


def _torch_version_at_least(major, minor):
    """Return True when the installed torch version is >= ``major.minor``.

    The two leading version components are compared as integers. Parsing
    "major.minor" as a float would read "1.10" as 1.1 and rank torch
    1.10-1.13 below 1.3.
    """
    found = []
    for part in torch.__version__.split(".")[:2]:
        digits = ""
        for char in part:
            if not char.isdigit():
                break
            digits += char
        # A component without leading digits is treated as 0.
        found.append(int(digits) if digits else 0)
    while len(found) < 2:
        found.append(0)
    return tuple(found) >= (major, minor)


# Maps an activation name to the callable implementing it.
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable scale and shift."""

    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).

        Args:
            hidden_size: Size of the ``weight`` and ``bias`` parameters.
            eps: Value added to the variance before taking the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply weight and bias."""
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias


class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config`` on the model.

        Raises:
            ValueError: If ``config`` is not a ``PretrainedConfig`` instance.
        """
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights.

        Linear and Embedding weights are drawn from a normal distribution with
        std ``config.initializer_range``; LayerNorm is reset to scale 1 and
        shift 0; Linear biases (when present) are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            # Support modules exposing TF-style names (beta/gamma) as well as
            # the weight/bias names used by LayerNorm in this file.
            if hasattr(module, 'beta') and hasattr(module, 'gamma'):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        """Not implemented in this base class."""
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        """Load ``state_dict`` into ``model`` and report mismatches.

        Keys containing 'gamma' / 'beta' are renamed to 'weight' / 'bias', and
        when ``prefix`` is given it is prepended to every key. Both renames
        are applied in place on the ``state_dict`` passed in, before a copy is
        taken for loading.

        Args:
            model: Module whose parameters are populated.
            state_dict: Mapping of parameter names to tensors.
            prefix: Optional string prepended to every key before loading.
            task_config: Optional object with a ``local_rank`` attribute; the
                summary is only logged when it is None or ``local_rank == 0``.

        Returns:
            The same ``model`` instance.
        """
        # Collect the renames first so the dict is not mutated while iterating.
        renames = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                renames.append((key, new_key))
        for old_key, new_key in renames:
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            # Iterate over a snapshot of the keys because entries are moved.
            for old_key in list(state_dict.keys()):
                state_dict[prefix + old_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, prefix=''):
            # Load this module's own entries, then recurse into its children
            # with the child name appended to the key prefix.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if missing_keys:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
            if unexpected_keys:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
            if error_msgs:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        When ``state_dict`` is None the freshly constructed model is returned
        without loading any weights.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)

        return model


##################################
###### LOSS FUNCTION #############
##################################
class CrossEn(nn.Module):
    """Cross-entropy over a similarity matrix with the diagonal as targets."""

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        """Return the mean of ``-log_softmax(sim_matrix, dim=-1)`` on the diagonal."""
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss


class Dual_CrossEn(nn.Module):
    """Same loss as ``CrossEn``, applied after ``get_dual_matrix`` re-weighting."""

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        """Re-weight ``sim_matrix`` with ``get_dual_matrix``, then return the
        mean of ``-log_softmax(..., dim=-1)`` on the diagonal."""
        sim_matrix = get_dual_matrix(sim_matrix)
        logpt = F.log_softmax(sim_matrix, dim=-1)
        logpt = torch.diag(logpt)
        nce_loss = -logpt
        sim_loss = nce_loss.mean()
        return sim_loss


class MILNCELoss(nn.Module):
    """MIL-NCE style loss over a similarity matrix grouped into ``n_pair`` blocks."""

    def __init__(self, batch_size=1, n_pair=1):
        """
        Args:
            batch_size: Number of diagonal blocks in the mask built in ``forward``.
            n_pair: Side length of each diagonal block.
        """
        super().__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        # Compare version components as integers: a float comparison would
        # read "1.10" as 1.1 and select uint8 on torch 1.10-1.13.
        self.bool_dtype = torch.bool if _torch_version_at_least(1, 3) else torch.uint8

    def forward(self, sim_matrix):
        """Compute the loss.

        Args:
            sim_matrix: Similarity scores. The mask built here is square with
                side ``batch_size * n_pair`` and is added to ``sim_matrix``,
                so the input is presumably of that same shape.

        Returns:
            Scalar tensor: mean of the selected per-row losses.
        """
        # Block-diagonal mask of ones: one (n_pair x n_pair) block per sample.
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        # Push the in-block entries of one direction to a very negative value
        # so they vanish in the softmax; the other direction is the transpose.
        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        # Keep only the in-block entries of the first (transposed) half when
        # summing log-probabilities per row.
        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        # Select one row per block: the row at offset n_pair // 2 in the block.
        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss


class MaxMarginRankingLoss(nn.Module):
    """Bidirectional max-margin ranking loss with optional negative weighting."""

    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        """
        Args:
            margin: Margin added inside each hinge term.
            negative_weighting: When True (and ``n_pair > 1`` and
                ``batch_size > 1``), ``forward`` multiplies the hinge terms by
                the precomputed weighting mask.
            batch_size: Number of blocks in the weighting mask.
            n_pair: Side length of each block in the weighting mask.
            hard_negative_rate: ``easy_negative_rate`` is ``1 - hard_negative_rate``.
                The mask computation divides by
                ``(batch_size - 1) * (1 - easy_negative_rate)``, so a value of
                0 raises ``ZeroDivisionError`` when the mask is built.
        """
        super().__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        easy_negative_rate = 1 - hard_negative_rate
        self.easy_negative_rate = easy_negative_rate
        self.negative_weighting = negative_weighting
        # The mask only exists for n_pair > 1 and batch_size > 1; forward()
        # guards its use with the same condition.
        if n_pair > 1 and batch_size > 1:
            alpha = easy_negative_rate / ((batch_size - 1) * (1 - easy_negative_rate))
            # 1 on the diagonal, alpha elsewhere, expanded to n_pair-sized blocks.
            mm_mask = (1 - alpha) * np.eye(self.batch_size) + alpha
            mm_mask = np.kron(mm_mask, np.ones((n_pair, n_pair)))
            mm_mask = torch.tensor(mm_mask) * (batch_size * (1 - easy_negative_rate))
            self.mm_mask = mm_mask.float()

    def forward(self, x):
        """Return the mean hinge loss of ``x`` against its diagonal.

        Each entry is compared with the diagonal element of its row and of its
        column: ``relu(margin + x - d_row) + relu(margin + x - d_col)``.
        """
        d = torch.diag(x)
        max_margin = F.relu(self.margin + x - d.view(-1, 1)) + \
                     F.relu(self.margin + x - d.view(1, -1))
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()


class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        """Gather ``tensor`` from all ranks and concatenate along dim 0.

        Args:
            ctx: Autograd context; stores ``rank`` and the local dim-0 size.
            tensor: Local tensor to gather.
            args: Object providing ``world_size`` and ``rank`` attributes.
        """
        # Receive buffers take this rank's shape, so every rank presumably
        # contributes a tensor of the same shape -- TODO confirm with callers.
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the slice of ``grad_output`` belonging to this rank.

        The second element is None because ``args`` receives no gradient.
        """
        return (
            grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
            None,
        )