# HMMC_t2v_search / modules / module_cross.py
# (Hugging Face file view: "first commit" by cheetah003, revision 29c5a57, 16.2 kB)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import sys
import torch
from torch import nn
import torch.nn.functional as F
from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
from collections import OrderedDict
from modules.module_clip import build_model, CLIP, convert_weights
from transformers import AutoConfig, AutoModel, RobertaModel, RobertaConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# No downloadable pretrained archives are registered for the cross model.
PRETRAINED_MODEL_ARCHIVE_MAP = dict()
# File names exposed as CrossConfig.config_name / CrossConfig.weights_name.
CONFIG_NAME = 'cross_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'
def gelu(x):
    """Exact (erf-based) GELU activation: ``x * Phi(x)``, Phi = standard normal CDF.

    OpenAI GPT uses a tanh approximation, which gives slightly different results:
    ``0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))``.
    See https://arxiv.org/abs/1606.08415.
    """
    half_x = x * 0.5
    cdf_scaled = 1.0 + torch.erf(x / math.sqrt(2.0))
    return half_x * cdf_scaled
def swish(x):
    """Swish (SiLU) activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate
# Activation lookup by name. This rebinds (shadows) the ACT2FN imported from
# .until_module, so code in this module resolves names through this table.
ACT2FN = dict(gelu=gelu, relu=F.relu, swish=swish)
class CrossConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `CrossModel`.
    """
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    config_name = CONFIG_NAME
    weights_name = WEIGHTS_NAME

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs CrossConfig.

        Args:
            vocab_size_or_config_json_file: Either the vocabulary size (int) of
                `inputs_ids` in `CrossModel`, or the path (str) to a JSON config
                file. When a path is given, every top-level key of the JSON object
                is copied onto the instance as-is and all other arguments are
                ignored. Extra fields read elsewhere in this module (e.g.
                `pretrained_clip_name`, `temporal_hidden_size`) can therefore only
                be supplied through the JSON file.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `CrossModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.

        Raises:
            ValueError: if the first argument is neither an int nor a str.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.load(reader)
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            # The two literals are implicitly concatenated; the trailing space
            # keeps "(int)" and "or" from running together.
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
class QuickGELU(nn.Module):
    """Sigmoid approximation of GELU: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm Transformer block: self-attention then an MLP, each with a residual add.

    Input and output are an ``(x, attn_mask)`` tuple so that several blocks can be
    chained inside ``nn.Sequential``; the mask is passed through unchanged.
    """

    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model)),
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.n_head = n_head

    def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Self-attention over ``x``.

        Args:
            x: (L, N, d_model) tensor — nn.MultiheadAttention's default
                (batch_first=False) layout.
            attn_mask: per-sample additive mask of shape (N, L, L), or None.

        Returns:
            The attention output, same shape as ``x``.
        """
        if attn_mask is not None:
            # MultiheadAttention reads a 3-D mask of shape (N * n_head, L, S) with the
            # batch index major (row b * n_head + h). repeat_interleave produces
            # [m0, m0, ..., m1, m1, ...]; .repeat() would produce [m0, m1, ..., m0, m1, ...]
            # and give samples the wrong masks whenever masks differ across the batch.
            attn_mask = attn_mask.repeat_interleave(self.n_head, dim=0)
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, para_tuple: tuple):
        """Apply the block to ``para_tuple = (x, attn_mask)``; returns ``(x_out, attn_mask)``."""
        x, attn_mask = para_tuple
        x = x + self.attention(self.ln_1(x), attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return (x, attn_mask)
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks of width ``width`` with ``heads`` heads.

    ``forward(x, attn_mask)`` feeds ``(x, attn_mask)`` through the stack (every
    block receives the same mask) and returns only the final hidden states.
    """

    def __init__(self, width: int, layers: int, heads: int):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads))
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        hidden, _ = self.resblocks((x, attn_mask))
        return hidden
class VisualEncoder(nn.Module):
    """Video encoder built on the visual tower of a CLIP model.

    Every frame is embedded independently by the CLIP visual tower. When
    ``task_config.use_temp`` is true, learned frame-position embeddings are
    added and a temporal ``Transformer`` runs across frames, with a residual
    connection back to the per-frame features. The resulting frame features are
    L2-normalised and averaged over time.
    """

    def __init__(self, task_config, cross_config):
        super().__init__()
        clip_name = cross_config.pretrained_clip_name
        if task_config.local_rank == 0:
            logger.info("pretrained_clip_name:{}".format(clip_name))
        clip_state_dict = CLIP.get_config(pretrained_clip_name=clip_name)
        clip_model = build_model(clip_state_dict, local_rank=task_config.local_rank)

        self.use_temp = task_config.use_temp
        # Only the visual tower (and the flag saying whether it is a ViT) is kept.
        self.is_vit = copy.deepcopy(clip_model.vit)
        self.visual = copy.deepcopy(clip_model.visual)

        if self.use_temp:
            # These temporal layers keep their constructors' default initialisation;
            # they are not copied from CLIP's text transformer / positional embedding.
            self.temporal_transformer = Transformer(
                width=cross_config.temporal_hidden_size,
                layers=cross_config.temporal_hidden_layers,
                heads=cross_config.temporal_attention_heads,
            )
            self.frame_position_embeddings = nn.Embedding(
                cross_config.max_position_embeddings,
                cross_config.temporal_hidden_size,
            )

    def forward(self, video, video_frames):
        """Encode a batch of videos.

        Args:
            video: 5-D tensor unpacked as (batch, frames, channels, height, width).
            video_frames: not used by this method.

        Returns:
            Tuple ``(visual_output, frame_output)``. ``visual_output`` is the mean
            over frames of the L2-normalised (temporally refined, if enabled)
            frame features, shape (batch, dim). ``frame_output`` is the raw
            per-frame output of ``encode_image``, shape (batch, frames, dim).
        """
        batch_size, num_frames, channels, height, width = video.shape
        # Fold frames into the batch dimension for the image encoder.
        frames_flat = video.view(batch_size * num_frames, channels, height, width)
        per_frame = self.encode_image(frames_flat, video_frame=num_frames)
        per_frame = per_frame.view(batch_size, num_frames, per_frame.size(-1))
        frame_output = per_frame

        features = per_frame
        if self.use_temp:
            device = per_frame.device
            position_ids = torch.arange(per_frame.size(1), dtype=torch.long, device=device)
            features = per_frame + self.frame_position_embeddings(position_ids)
            # Every frame is treated as valid, so this additive mask is all zeros.
            frame_mask = torch.ones([batch_size, num_frames], device=device)
            attn_mask = (1.0 - frame_mask.unsqueeze(1)) * -1000000.0
            attn_mask = attn_mask.expand(-1, frame_mask.size(1), -1)
            features = self.temporal_transformer(features.permute(1, 0, 2), attn_mask)  # NLD -> LND
            features = features.permute(1, 0, 2)  # LND -> NLD
            # Residual connection around the temporal block.
            features = features + per_frame

        normalized = features / features.norm(dim=-1, keepdim=True)
        visual_output = torch.mean(normalized, dim=1)
        return visual_output, frame_output

    @property
    def dtype(self):
        """Parameter dtype of the visual tower (taken from its ``conv1`` weight)."""
        return self.visual.conv1.weight.dtype

    def encode_image(self, image, return_hidden=False, video_frame=-1):
        """Embed a batch of images with the visual tower.

        For a ViT tower, the input is cast to ``self.dtype``, passed with
        ``video_frame``, normalised by ``ln_post`` and projected by ``proj``; the
        pooled feature is token 0. Otherwise the tower's output is used as-is.
        Results are returned as float32; with ``return_hidden`` the full hidden
        output is returned as a second element.
        """
        if self.is_vit:
            tokens = self.visual(image.type(self.dtype), video_frame=video_frame)
            hidden = self.visual.ln_post(tokens) @ self.visual.proj
            pooled = hidden[:, 0, :]
        else:
            hidden = self.visual(image)
            pooled = hidden
        if return_hidden:
            return pooled.float(), hidden.float()
        return pooled.float()
class TextEncoder(nn.Module):
    """Text encoder whose backbone is chosen by ``task_config.language``.

    * ``"english"``: reuses the text tower (token/positional embeddings,
      transformer, final LayerNorm, text projection) of the CLIP model named by
      ``cross_config.pretrained_clip_name``.
    * ``"chinese"``: loads a Hugging Face model from ``task_config.pretrained_text``
      and maps its outputs through a linear layer from
      ``cross_config.chinese_hidden_size`` to ``cross_config.temporal_hidden_size``.

    Any other language raises ``NotImplementedError``.
    """

    def __init__(self, task_config, cross_config):
        super().__init__()
        self.language = task_config.language
        clip_name = cross_config.pretrained_clip_name
        if task_config.local_rank == 0:
            logger.info("pretrained_clip_name:{}".format(clip_name))
        clip_state_dict = CLIP.get_config(pretrained_clip_name=clip_name)
        clip_model = build_model(clip_state_dict, local_rank=task_config.local_rank)
        self.logit_scale = copy.deepcopy(clip_state_dict["logit_scale"])

        if self.language == "english":
            # Copy CLIP's text-tower components onto this module, in this order.
            for attr_name in ("token_embedding", "positional_embedding", "transformer",
                              "ln_final", "text_projection"):
                setattr(self, attr_name, copy.deepcopy(getattr(clip_model, attr_name)))
            self.dtype = clip_model.visual.conv1.weight.dtype
        elif self.language == "chinese":
            text_model_name = task_config.pretrained_text
            text_model_config = AutoConfig.from_pretrained(text_model_name)
            if task_config.rank == 0:
                logger.info("name:{},chinesebert_config:{}".format(text_model_name, text_model_config))
            self.chinese_encoder = AutoModel.from_pretrained(text_model_name)
            self.text_proj = nn.Linear(cross_config.chinese_hidden_size, cross_config.temporal_hidden_size)
        else:
            raise NotImplementedError("wrong language")

    def forward(self, input_ids, attention_mask, return_hidden=False):
        """Encode ``input_ids``.

        ``attention_mask`` is only forwarded to the Chinese encoder; the English
        path ignores it. Returns the per-token hidden states reshaped to
        (batch, -1, dim) when ``return_hidden`` is true, otherwise the pooled
        text feature reshaped to (batch, dim).
        """
        batch_size = input_ids.size(0)
        if self.language == "english":
            pooled, hidden = self.encode_text(input_ids, return_hidden=True)
        else:
            encoder_out = self.chinese_encoder(input_ids, attention_mask=attention_mask)
            # Element 0 is projected as the token hidden states, element 1 as the
            # pooled feature (presumably the model's pooler output — depends on the model).
            hidden = self.text_proj(encoder_out[0])
            pooled = self.text_proj(encoder_out[1])
        pooled = pooled.view(batch_size, pooled.size(-1))
        hidden = hidden.view(batch_size, -1, hidden.size(-1))
        return hidden if return_hidden else pooled

    def encode_text(self, text, return_hidden=False):
        """Run the CLIP text tower on token ids ``text``.

        The pooled feature is taken at ``text.argmax(dim=-1)`` in each row, which
        assumes the end-of-text token has the largest id in its sequence.
        Outputs are returned as float32.
        """
        token_emb = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        pos_emb = self.positional_embedding[:token_emb.size(1), :].type(self.dtype)
        seq = (token_emb + pos_emb).permute(1, 0, 2)  # NLD -> LND
        seq = self.transformer(seq).permute(1, 0, 2)  # LND -> NLD
        hidden = self.ln_final(seq).type(self.dtype) @ self.text_projection
        eot_positions = text.argmax(dim=-1)
        pooled = hidden[torch.arange(hidden.shape[0]), eot_positions]
        if return_hidden:
            return pooled.float(), hidden.float()
        return pooled.float()
class BertLMPredictionHead(nn.Module):
    """Language-model output head: transform hidden states, then project to vocabulary logits.

    The decoder is a bias-free ``nn.Linear(hidden_size, vocab_size)`` whose ``bias``
    is then pointed at this module's own zero-initialised ``self.bias`` parameter.
    No weight tying with input embeddings happens in this class.
    """

    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        # Share the same bias Parameter between this module and the decoder.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        transformed = self.transform(hidden_states)
        return self.decoder(transformed)
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied before the LM decoder.

    ``config.hidden_act`` may be a key of ``ACT2FN`` ("gelu", "relu", "swish")
    or a callable used directly as the activation.
    """

    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # The former Python-2 ``unicode`` check referenced an undefined name and was
        # unreachable on Python 3 (which this file requires), so it is dropped.
        if isinstance(config.hidden_act, str):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
class BertLayerNorm(nn.Module):
    """LayerNorm over the last dimension, TF style (epsilon inside the square root).

    Learnable ``weight`` starts at ones and ``bias`` at zeros.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.variance_epsilon)
        return self.weight * normalized + self.bias