# HMMC_t2v_search / modules / until_module.py
# cheetah003's picture
# first commit
# 29c5a57
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BERT model."""
import logging
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import math
from modules.until_config import PretrainedConfig
logger = logging.getLogger(__name__)
def gelu(x):
    """Apply the exact, erf-based GELU activation element-wise.

    For information: OpenAI GPT's gelu is a tanh approximation that gives
    slightly different results:
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    """
    half_x = x * 0.5
    # 1 + erf(x / sqrt(2)) is twice the standard normal CDF evaluated at x.
    erf_term = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * erf_term
def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
def get_dual_matrix(sim_matrix, temp=1):
    """Re-weight a similarity matrix by its softmax along both axes.

    Each entry is multiplied by its softmax weight taken over dim 0 and by
    its softmax weight taken over dim 1, both computed on ``sim_matrix / temp``.

    Args:
        sim_matrix: a tensor, or anything ``torch.tensor`` can convert
            (e.g. a nested list or numpy array). It must have at least two
            dimensions because the softmax is taken over dims 0 and 1.
        temp: softmax temperature. Defaults to 1, which reproduces the
            previously hard-coded behaviour.

    Returns:
        A new tensor ``sim_matrix * softmax(sim_matrix / temp, dim=0)
        * softmax(sim_matrix / temp, dim=1)``.
    """
    if not torch.is_tensor(sim_matrix):
        sim_matrix = torch.tensor(sim_matrix)
    scaled = sim_matrix / temp
    alpha = F.softmax(scaled, dim=0)
    beta = F.softmax(scaled, dim=1)
    return sim_matrix * alpha * beta
# Lookup table from an activation name to the callable that implements it.
ACT2FN = {
    "gelu": gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, in the TF style
    (epsilon inside the square root).
    """

    def __init__(self, hidden_size, eps=1e-12):
        """Create a scale initialised to ones and a shift initialised to zeros.

        Args:
            hidden_size: size passed to ``torch.ones`` / ``torch.zeros`` for
                the ``weight`` and ``bias`` parameters.
            eps: value added to the variance before taking the square root.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then scale and shift."""
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        # Biased variance: the mean of squared deviations.
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias
class PreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for loading pretrained weights from a state dict.
    """

    def __init__(self, config, *inputs, **kwargs):
        """Store ``config`` on the module.

        Raises:
            ValueError: if ``config`` is not a ``PretrainedConfig`` instance.
        """
        super(PreTrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    def init_weights(self, module):
        """Initialize the weights of a single ``module``.

        Linear and embedding weights are drawn from a normal distribution
        with std ``config.initializer_range``; LayerNorm is reset to
        scale 1 / shift 0; linear biases are zeroed.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, LayerNorm):
            # Support both the TF-style (beta/gamma) and the PyTorch-style
            # (bias/weight) attribute names.
            if hasattr(module, 'beta') and hasattr(module, 'gamma'):
                module.beta.data.zero_()
                module.gamma.data.fill_(1.0)
            else:
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def resize_token_embeddings(self, new_num_tokens=None):
        """Not implemented in this base class; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @classmethod
    def init_preweight(cls, model, state_dict, prefix=None, task_config=None):
        """Load ``state_dict`` into ``model`` and return ``model``.

        Keys containing ``gamma`` / ``beta`` are renamed to ``weight`` /
        ``bias``; when ``prefix`` is given it is prepended to every key.
        All renaming happens on a shallow copy, so the caller's
        ``state_dict`` is left untouched and can be reused.

        Args:
            model: the module to load the weights into.
            state_dict: mapping from parameter names to tensors.
            prefix: optional string prepended to every key before loading.
            task_config: optional object with a ``local_rank`` attribute.
                The load report is logged only when ``prefix`` is None and
                ``task_config`` is None or its ``local_rank`` is 0.

        Returns:
            The same ``model`` object. Missing keys, unexpected keys and
            load errors are logged, not raised.
        """
        # Copy first so that neither the renames below nor
        # _load_from_state_dict modify the caller's dict.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        # Collect the renames before applying them: the dict must not change
        # size while it is being iterated.
        renames = []
        for key in state_dict:
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                renames.append((key, new_key))
        for old_key, new_key in renames:
            state_dict[new_key] = state_dict.pop(old_key)

        if prefix is not None:
            for old_key in list(state_dict.keys()):
                state_dict[prefix + old_key] = state_dict.pop(old_key)

        missing_keys = []
        unexpected_keys = []
        error_msgs = []

        def load(module, prefix=''):
            # Recursively load each sub-module with its dotted name prefix.
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    load(child, prefix + name + '.')

        load(model, prefix='')

        if prefix is None and (task_config is None or task_config.local_rank == 0):
            logger.info("-" * 20)
            if len(missing_keys) > 0:
                logger.info("Weights of {} not initialized from pretrained model: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
            if len(unexpected_keys) > 0:
                logger.info("Weights from pretrained model not used in {}: {}"
                            .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
            if len(error_msgs) > 0:
                logger.error("Weights from pretrained model cause errors in {}: {}"
                             .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))

        return model

    @property
    def dtype(self):
        """
        :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:
            # For nn.DataParallel compatibility in PyTorch 1.5
            def find_tensor_attributes(module: nn.Module):
                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
                return tuples

            # Fall back to the first plain tensor attribute found on any
            # sub-module when there are no parameters.
            gen = self._named_members(get_members_fn=find_tensor_attributes)
            first_tuple = next(gen)
            return first_tuple[1].dtype

    @classmethod
    def from_pretrained(cls, config, state_dict=None, *inputs, **kwargs):
        """
        Instantiate the model from ``config`` and, when ``state_dict`` is
        given, load those weights into it via :meth:`init_preweight`.

        Returns:
            The new model instance.
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        if state_dict is None:
            return model
        return cls.init_preweight(model, state_dict)
##################################
###### LOSS FUNCTION #############
##################################
class CrossEn(nn.Module):
    """Cross-entropy over the rows of a similarity matrix, using the
    diagonal entry of each row as the target.
    """

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        """Return the mean negative log-softmax of the diagonal entries."""
        row_log_probs = F.log_softmax(sim_matrix, dim=-1)
        diagonal_log_probs = torch.diag(row_log_probs)
        return (-diagonal_log_probs).mean()
class Dual_CrossEn(nn.Module):
    """Diagonal-target cross-entropy applied to a similarity matrix that is
    first re-weighted by ``get_dual_matrix``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, sim_matrix):
        """Return the mean negative log-softmax of the re-weighted diagonal."""
        reweighted = get_dual_matrix(sim_matrix)
        row_log_probs = F.log_softmax(reweighted, dim=-1)
        diagonal_log_probs = torch.diag(row_log_probs)
        return (-diagonal_log_probs).mean()
class MILNCELoss(nn.Module):
    """Contrastive loss over a similarity matrix organised in
    ``batch_size`` groups of ``n_pair`` rows/columns each.
    """

    def __init__(self, batch_size=1, n_pair=1,):
        """
        Args:
            batch_size: number of groups along each axis of the similarity matrix.
            n_pair: number of consecutive rows/columns belonging to one group.
        """
        super(MILNCELoss, self).__init__()
        self.batch_size = batch_size
        self.n_pair = n_pair
        # Compare versions as integer tuples. Parsing "1.10" as a float gives
        # 1.1, which would wrongly be treated as older than 1.3.
        supports_bool = self._parse_major_minor(torch.__version__) >= (1, 3)
        self.bool_dtype = torch.bool if supports_bool else torch.uint8

    @staticmethod
    def _parse_major_minor(version):
        """Return ``(major, minor)`` as ints from a string like ``"1.10.0+cu113"``."""
        numbers = []
        for piece in str(version).split(".")[:2]:
            digits = ""
            for char in piece:
                if not char.isdigit():
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        while len(numbers) < 2:
            numbers.append(0)
        return tuple(numbers)

    def forward(self, sim_matrix):
        """Compute the loss for ``sim_matrix``.

        NOTE(review): assumes ``sim_matrix`` is square with side
        ``batch_size * n_pair`` so that it lines up with the mask built
        below — confirm against callers.
        """
        # Block-diagonal 0/1 mask: 1 where row and column are in the same group.
        mm_mask = np.eye(self.batch_size)
        mm_mask = np.kron(mm_mask, np.ones((self.n_pair, self.n_pair)))
        mm_mask = torch.tensor(mm_mask).float().to(sim_matrix.device)

        # Push same-group entries of the untransposed matrix to a very
        # negative value so they contribute nothing to the softmax denominator.
        from_text_matrix = sim_matrix + mm_mask * -1e12
        from_video_matrix = sim_matrix.transpose(1, 0)

        new_sim_matrix = torch.cat([from_video_matrix, from_text_matrix], dim=-1)
        logpt = F.log_softmax(new_sim_matrix, dim=-1)

        # Keep only the same-group entries of the first (transposed) half;
        # everything else is suppressed before the logsumexp.
        mm_mask_logpt = torch.cat([mm_mask, torch.zeros_like(mm_mask)], dim=-1)
        masked_logpt = logpt + (torch.ones_like(mm_mask_logpt) - mm_mask_logpt) * -1e12

        new_logpt = -torch.logsumexp(masked_logpt, dim=-1)

        # Average over one row per group: the row at offset n_pair // 2.
        logpt_choice = torch.zeros_like(new_logpt)
        mark_ind = torch.arange(self.batch_size).to(sim_matrix.device) * self.n_pair + (self.n_pair // 2)
        logpt_choice[mark_ind] = 1
        sim_loss = new_logpt.masked_select(logpt_choice.to(dtype=self.bool_dtype)).mean()
        return sim_loss
class MaxMarginRankingLoss(nn.Module):
    """Bidirectional hinge (max-margin) ranking loss over a similarity
    matrix, with optional per-entry weighting.
    """

    def __init__(self,
                 margin=1.0,
                 negative_weighting=False,
                 batch_size=1,
                 n_pair=1,
                 hard_negative_rate=0.5,
                 ):
        """
        Args:
            margin: margin added inside each hinge term.
            negative_weighting: whether ``forward`` multiplies the hinge
                terms by the precomputed weight mask.
            batch_size: number of groups used to build the weight mask.
            n_pair: number of consecutive rows/columns per group.
            hard_negative_rate: ``1 - hard_negative_rate`` is stored as
                ``easy_negative_rate`` and drives the mask weights.
        """
        super().__init__()
        self.margin = margin
        self.n_pair = n_pair
        self.batch_size = batch_size
        self.easy_negative_rate = 1 - hard_negative_rate
        self.negative_weighting = negative_weighting

        # The weight mask only exists when there is more than one group and
        # more than one item per group.
        if not (n_pair > 1 and batch_size > 1):
            return

        easy_rate = self.easy_negative_rate
        hard_rate = 1 - easy_rate
        alpha = easy_rate / ((batch_size - 1) * hard_rate)
        # Weight 1 on the diagonal and alpha elsewhere, at group granularity...
        group_weights = (1 - alpha) * np.eye(self.batch_size) + alpha
        # ...then expanded so each group covers an n_pair x n_pair block.
        entry_weights = np.kron(group_weights, np.ones((n_pair, n_pair)))
        scaled_weights = torch.tensor(entry_weights) * (batch_size * hard_rate)
        self.mm_mask = scaled_weights.float()

    def forward(self, x):
        """Return the mean of the (optionally weighted) hinge terms of ``x``."""
        diagonal = torch.diag(x)
        # Each entry compared against the diagonal of its row, then of its column.
        against_row_diag = F.relu(self.margin + x - diagonal.view(-1, 1))
        against_col_diag = F.relu(self.margin + x - diagonal.view(1, -1))
        hinge = against_row_diag + against_col_diag
        if self.negative_weighting and self.n_pair > 1 and self.batch_size > 1:
            hinge = hinge * self.mm_mask.to(hinge.device)
        return hinge.mean()
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor, args):
        """Gather ``tensor`` from ``args.world_size`` ranks and concatenate along dim 0.

        ``args.rank`` and the local first-dimension size are saved on ``ctx``
        for use in ``backward``.
        """
        gathered = []
        for _ in range(args.world_size):
            gathered.append(torch.empty_like(tensor))
        torch.distributed.all_gather(gathered, tensor)
        ctx.rank = args.rank
        ctx.batch_size = tensor.shape[0]
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the slice of ``grad_output`` for this rank's rows; ``args`` gets no gradient."""
        start = ctx.batch_size * ctx.rank
        stop = ctx.batch_size * (ctx.rank + 1)
        return grad_output[start:stop], None