Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team. | |
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""PyTorch BERT model.""" | |
import logging | |
import numpy as np | |
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import math | |
from modules.until_config import PretrainedConfig | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def gelu(x):
    """Exact (erf-based) GELU activation.

    Computes ``0.5 * x * (1 + erf(x / sqrt(2)))``. For reference, OpenAI GPT
    uses a tanh approximation that gives slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))
def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
def get_dual_matrix(sim_matrix, temp=1):
    """Re-weight a similarity matrix by its softmax over dim 0 and over dim 1.

    Each entry ``s[i, j]`` is multiplied by ``softmax(s / temp, dim=0)[i, j]``
    and by ``softmax(s / temp, dim=1)[i, j]``.

    Args:
        sim_matrix: similarity scores with at least two dimensions (softmax is
            taken over dims 0 and 1). Non-tensor input (e.g. nested lists or a
            NumPy array) is converted with ``torch.tensor``.
        temp: softmax temperature; must be positive. Defaults to 1, the value
            that was previously hard-coded.

    Returns:
        A tensor with the same shape as ``sim_matrix``.

    Raises:
        ValueError: if ``temp`` is not positive.
    """
    if temp <= 0:
        raise ValueError("temp must be positive, got {}".format(temp))
    if not torch.is_tensor(sim_matrix):
        sim_matrix = torch.tensor(sim_matrix)
    scaled = sim_matrix / temp
    alpha = F.softmax(scaled, dim=0)
    beta = F.softmax(scaled, dim=1)
    return sim_matrix * alpha * beta
# Lookup table from activation name to the activation callable.
ACT2FN = dict(gelu=gelu, relu=F.relu, swish=swish)
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, TF style (epsilon inside the square root).

    Output is ``weight * (x - mean) / sqrt(var + eps) + bias``. ``mean`` and the
    biased variance ``var`` are taken over the last dimension.
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Create a learnable scale (initialised to 1) and shift (initialised to 0) of size ``hidden_size``."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class PreTrainedModel(nn.Module):
    """ An abstract class to handle weights initialization and
    a simple interface for loading pretrained weights from a state dict.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config``.

        Raises:
            ValueError: if ``config`` is not a ``PretrainedConfig``.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """ Initialize the weights of ``module`` in place.

        - Linear/Embedding weights: normal(0, ``self.config.initializer_range``).
        - LayerNorm: scale set to 1 and shift set to 0. TF-style ``gamma``/``beta``
          names are used when present, ``weight``/``bias`` otherwise.
        - Linear bias (if any): zero.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            if hasattr(module, 'beta') and hasattr(module, 'gamma'):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        """Load ``state_dict`` into ``model`` and return ``model``.

        Before loading, keys containing ``gamma``/``beta`` are renamed to
        ``weight``/``bias``. If ``prefix`` is not None, it is then prepended to
        every key. Each submodule is loaded through ``_load_from_state_dict``
        with ``strict=True``. Missing keys, unexpected keys and error messages
        are collected rather than raised. They are logged only when ``prefix``
        is None and ``task_config`` is None or has ``local_rank == 0``.

        The caller's ``state_dict`` is not modified; a shallow copy is renamed.
        """
        # Copy first so the renames below do not mutate the caller's dict.
        # torch attaches per-module version info as ``_metadata``; copy() drops it,
        # so it is re-attached after renaming.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()

        # Map TF-style LayerNorm parameter names onto weight/bias.
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            for old_key in list(state_dict.keys()):
                state_dict[prefix + old_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module, module_prefix=''):
            # Recursively load each submodule under its dotted key prefix.
            local_metadata = {} if metadata is None else metadata.get(module_prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, module_prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, module_prefix + name + '.')

        load(model, module_prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate ``cls(config, *inputs, **kwargs)``. If ``state_dict`` is given,
        load it with ``init_preweight``. Nothing is downloaded.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        model = cls.init_preweight(model, state_dict)
        return model
################################## | |
###### LOSS FUNCTIONS ############
################################## | |
class CrossEn(nn.Module):
    """Cross-entropy loss whose positives lie on the diagonal of a similarity matrix.

    Returns the mean over i of ``-log_softmax(sim_matrix, dim=-1)[i, i]``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        log_probs = F.log_softmax(sim_matrix, dim=-1)
        nce_loss = -torch.diag(log_probs)
        return nce_loss.mean()
class Dual_CrossEn(nn.Module):
    """Diagonal cross-entropy loss computed on a dual-softmax re-weighted similarity matrix.

    ``sim_matrix`` is first passed through ``get_dual_matrix``. The loss is then
    the mean over i of ``-log_softmax(weighted, dim=-1)[i, i]``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        weighted = get_dual_matrix(sim_matrix)
        log_probs = F.log_softmax(weighted, dim=-1)
        nce_loss = -torch.diag(log_probs)
        return nce_loss.mean()
class MILNCELoss(nn.Module):
    """Multiple-instance NCE loss over a block-structured similarity matrix.

    Rows and columns are grouped into ``batch_size`` blocks of ``n_pair``
    entries; entries in the same block are positives. Assumes ``sim_matrix`` is
    square with side ``batch_size * n_pair`` (TODO confirm against callers).

    For each row r, the candidates are:
    - column r of ``sim_matrix``;
    - row r of ``sim_matrix``, with same-block entries pushed to -1e12.

    The per-row loss is ``-log`` of the total softmax probability of the
    same-block entries in the first (column) half. Only row
    ``b * n_pair + n_pair // 2`` of each block contributes, and the result is
    their mean.
    """

    def __init__(self, batch_size=1, n_pair=1,):
        super(MILNCELoss, self).__init__()
        import re

        self.batch_size = batch_size
        self.n_pair = n_pair
        # Compare (major, minor) as integers. The previous float("1.10") == 1.1
        # parse treated torch 1.10-1.29 as older than 1.3 and picked uint8 masks.
        match = re.match(r"(\d+)\.(\d+)", str(torch.__version__))
        torch_v = (int(match.group(1)), int(match.group(2))) if match else (0, 0)
        self.bool_dtype = torch.bool if torch_v >= (1, 3) else torch.uint8

    def forward(self, sim_matrix):
        # Block-diagonal mask: 1 where row and column belong to the same block.
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        # Second half of the candidates: rows with same-block entries suppressed.
        from_text_matrix = sim_matrix + mm_mask * -1e12
        # First half: columns of sim_matrix (rows of its transpose), unmasked.
        from_video_matrix = sim_matrix.transpose(1, 0)
        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        # Keep only same-block entries of the first half; -log of their summed probability.
        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12
        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        # One representative row per block: index b * n_pair + n_pair // 2.
        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss
class MaxMarginRankingLoss(nn.Module):
    """Bidirectional hinge (max-margin) ranking loss on a square similarity matrix.

    With ``d = diag(x)``, each entry contributes
    ``relu(margin + x[i, j] - d[i]) + relu(margin + x[i, j] - d[j])``. The loss
    is the mean over all entries. Diagonal entries are included, each
    contributing ``2 * relu(margin)``.

    When ``negative_weighting`` is set and both ``batch_size`` and ``n_pair``
    exceed 1, entries are scaled by a fixed mask built in ``__init__``:
    - same-block entries are weighted by ``batch_size * hard_negative_rate``;
    - other entries are weighted by
      ``batch_size * (1 - hard_negative_rate) / (batch_size - 1)``.
    """

    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        super().__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        self.easy_negative_rate = 1 - hard_negative_rate
        self.negative_weighting = negative_weighting
        # The weighting mask only exists when there is more than one block and more than one pair.
        if n_pair <= 1 or batch_size <= 1:
            return
        easy = self.easy_negative_rate
        hard = 1 - easy
        alpha = easy / ((batch_size - 1) * hard)
        block_weights = (1 - alpha) * np.eye(batch_size) + alpha
        pair_weights = np.kron(block_weights, np.ones((n_pair, n_pair)))
        self.mm_mask = (torch.tensor(pair_weights) * (batch_size * hard)).float()

    def forward(self, x):
        diag = x.diag()
        hinge_rows = F.relu(self.margin + x - diag.view(-1, 1))
        hinge_cols = F.relu(self.margin + x - diag.view(1, -1))
        max_margin = hinge_rows + hinge_cols
        weighted = self.negative_weighting and self.n_pair > 1 and self.batch_size > 1
        if weighted:
            max_margin = max_margin * self.mm_mask.to(max_margin.device)
        return max_margin.mean()
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor.

    Forward concatenates every rank's ``tensor`` along dim 0. It requires an
    initialised ``torch.distributed`` process group, and ``args`` must provide
    ``world_size`` and ``rank``.

    Backward returns only the slice of the incoming gradient that corresponds
    to this rank's rows. This assumes every rank contributed the same number of
    rows; gradients for other ranks' rows are dropped here.
    """

    @staticmethod
    def forward(ctx, tensor, args):
        output = [torch.empty_like(tensor) for _ in range(args.world_size)]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        start = ctx.batch_size * ctx.rank
        # Second element: no gradient for the non-tensor ``args`` input.
        return (
            grad_output[start: start + ctx.batch_size],
            None,
        )