from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import copy import json import math import logging import tarfile import tempfile import shutil import sys import torch from torch import nn import torch.nn.functional as F from .file_utils import cached_path from .until_config import PretrainedConfig from .until_module import PreTrainedModel, LayerNorm, ACT2FN from collections import OrderedDict from modules.module_clip import build_model, CLIP, convert_weights from transformers import AutoConfig, AutoModel, RobertaModel, RobertaConfig logger = logging.getLogger(__name__) PRETRAINED_MODEL_ARCHIVE_MAP = {} CONFIG_NAME = 'cross_config.json' WEIGHTS_NAME = 'cross_pytorch_model.bin' def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) Also see https://arxiv.org/abs/1606.08415 """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) def swish(x): return x * torch.sigmoid(x) ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} class CrossConfig(PretrainedConfig): """Configuration class to store the configuration of a `CrossModel`. """ pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP config_name = CONFIG_NAME weights_name = WEIGHTS_NAME def __init__(self, vocab_size_or_config_json_file, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02): """Constructs CrossConfig. Args: vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`. hidden_size: Size of the encoder layers and the pooler layer. num_hidden_layers: Number of hidden layers in the Transformer encoder. num_attention_heads: Number of attention heads for each attention layer in the Transformer encoder. intermediate_size: The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. hidden_act: The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu" and "swish" are supported. hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. attention_probs_dropout_prob: The dropout ratio for the attention probabilities. max_position_embeddings: The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). type_vocab_size: The vocabulary size of the `token_type_ids` passed into `CrossModel`. initializer_range: The sttdev of the truncated_normal_initializer for initializing all weight matrices. """ if isinstance(vocab_size_or_config_json_file, str): with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: json_config = json.loads(reader.read()) for key, value in json_config.items(): self.__dict__[key] = value elif isinstance(vocab_size_or_config_json_file, int): self.vocab_size = vocab_size_or_config_json_file self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.hidden_act = hidden_act self.intermediate_size = intermediate_size self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range else: raise ValueError("First argument must be either a vocabulary size (int)" "or the path to a pretrained model config file (str)") class QuickGELU(nn.Module): def forward(self, x: torch.Tensor): return x * torch.sigmoid(1.702 * x) class ResidualAttentionBlock(nn.Module): def __init__(self, d_model: int, n_head: int): super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) self.ln_1 = LayerNorm(d_model) self.mlp = nn.Sequential(OrderedDict([ ("c_fc", nn.Linear(d_model, d_model * 4)), ("gelu", QuickGELU()), ("c_proj", nn.Linear(d_model * 4, d_model)) ])) self.ln_2 = LayerNorm(d_model) self.n_head = n_head def attention(self, x: torch.Tensor, attn_mask: torch.Tensor): attn_mask_ = attn_mask.repeat(self.n_head, 1, 1) return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0] def forward(self, para_tuple: tuple): # x: torch.Tensor, attn_mask: torch.Tensor # print(para_tuple) x, attn_mask = para_tuple x = x + self.attention(self.ln_1(x), attn_mask) x = x + self.mlp(self.ln_2(x)) return (x, attn_mask) class Transformer(nn.Module): def __init__(self, width: int, layers: int, heads: int): super().__init__() self.width = width self.layers = layers self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) def forward(self, x: torch.Tensor, attn_mask: torch.Tensor): # logger.info("x.shpae:{},attn_mask:{}".format(x.shape, attn_mask.shape)) return self.resblocks((x, attn_mask))[0] class VisualEncoder(nn.Module): def __init__(self, task_config, cross_config): super().__init__() pretrained_clip_name = cross_config.pretrained_clip_name if task_config.local_rank == 0: logger.info("pretrained_clip_name:{}".format(pretrained_clip_name)) clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name) clip = build_model(clip_state_dict, local_rank=task_config.local_rank) self.use_temp = task_config.use_temp self.is_vit = copy.deepcopy(clip.vit) self.visual = copy.deepcopy(clip.visual) if self.use_temp: self.temporal_transformer = Transformer(width=cross_config.temporal_hidden_size, layers=cross_config.temporal_hidden_layers, heads=cross_config.temporal_attention_heads) self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings, cross_config.temporal_hidden_size) # use clip.transformer to initial temporal_transformer # for param_1, param_2 in zip(self.temporal_transformer.parameters(), clip.transformer.parameters()): # param_1.data.copy_(param_2.data) # initialize # if task_config.local_rank == 0: # logger.info("clip.positional_embedding:{}".format(clip.positional_embedding)) # self.frame_position_embeddings.weight = copy.deepcopy(clip.positional_embedding) def forward(self, video, video_frames): # encode frames bs, frames, channel, h, w = video.shape # [bs*frame, 3, 224, 224] video = video.view(bs * frames, channel, h, w) # logger.info("video_b.shape:{}, dtype:{}".format(video_b.shape, video_b.dtype)) # logger.info("video_frame[{}]:{}".format(b, video_frame)) visual_hidden = self.encode_image(video, video_frame=frames) # [bs, frame, hidden_size] # logger.info("visual_hidden.shape:{}".format(visual_hidden.shape)) visual_hidden = visual_hidden.view(bs, frames, visual_hidden.size(-1)) # logger.info("visual_hidden1.shape:{}".format(visual_hidden.shape)) # get temporal information visual_hidden_original = visual_hidden frame_output = visual_hidden_original if self.use_temp: seq_length = visual_hidden.size(1) position_ids = torch.arange(seq_length, dtype=torch.long, device=visual_hidden.device) # logger.info("position_ids.shape:{}".format(position_ids.shape)) frame_position_embeddings = self.frame_position_embeddings(position_ids) # logger.info("frame_position_embeddings.shape:{}".format(frame_position_embeddings.shape)) visual_hidden = visual_hidden + frame_position_embeddings video_mask = torch.ones([bs, frames], device=visual_hidden.device) extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1) visual_hidden = visual_hidden.permute(1, 0, 2) # NLD -> LND visual_hidden = self.temporal_transformer(visual_hidden, extended_video_mask) visual_hidden = visual_hidden.permute(1, 0, 2) # LND -> NLD visual_hidden = visual_hidden + visual_hidden_original # logger.info("visual_hidden.shape:{}".format(visual_hidden.shape)) visual_output = visual_hidden / visual_hidden.norm(dim=-1, keepdim=True) # [bs, frames,512] -> [bs, 512] visual_output = torch.mean(visual_output, dim=1) # logger.info("visual_hidden mean.shape:{}".format(visual_hidden.shape)) # logger.info("visual encoder visual_output.shape:{}".format(visual_output.shape)) return visual_output, frame_output @property def dtype(self): return self.visual.conv1.weight.dtype def encode_image(self, image, return_hidden=False, video_frame=-1): if self.is_vit: # logger.info("image.shape:{}".format(image.shape)) # hidden = self.visual(image, video_frame=video_frame) hidden = self.visual(image.type(self.dtype), video_frame=video_frame) # logger.info("hidden1.shape:{}".format(hidden.shape)) hidden = self.visual.ln_post(hidden) @ self.visual.proj # logger.info("hidden2.shape:{}".format(hidden.shape)) x = hidden[:, 0, :] # x = hidden else: hidden = self.visual(image) x = hidden if return_hidden: return x.float(), hidden.float() return x.float() class TextEncoder(nn.Module): def __init__(self, task_config, cross_config): super().__init__() self.language = task_config.language pretrained_clip_name = cross_config.pretrained_clip_name if task_config.local_rank == 0: logger.info("pretrained_clip_name:{}".format(pretrained_clip_name)) clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name) clip = build_model(clip_state_dict, local_rank=task_config.local_rank) self.logit_scale = copy.deepcopy(clip_state_dict["logit_scale"]) if self.language == "english": self.token_embedding = copy.deepcopy(clip.token_embedding) self.positional_embedding = copy.deepcopy(clip.positional_embedding) self.transformer = copy.deepcopy(clip.transformer) self.ln_final = copy.deepcopy(clip.ln_final) self.text_projection = copy.deepcopy(clip.text_projection) self.dtype = clip.visual.conv1.weight.dtype elif self.language == "chinese": pretrained = task_config.pretrained_text t_config = AutoConfig.from_pretrained(pretrained) if task_config.rank == 0: logger.info("name:{},chinesebert_config:{}".format(pretrained, t_config)) self.chinese_encoder = AutoModel.from_pretrained(pretrained) # logger.info("random Roberta") # self.chinese_encoder = RobertaModel(RobertaConfig()) self.text_proj = nn.Linear(cross_config.chinese_hidden_size, cross_config.temporal_hidden_size) else: raise NotImplementedError("wrong language") def forward(self, input_ids, attention_mask, return_hidden=False): bs_pair = input_ids.size(0) if self.language == "english": text_output, hidden = self.encode_text(input_ids, return_hidden=True) else: temp_output = self.chinese_encoder(input_ids, attention_mask=attention_mask) # logger.info("hidden:{},text_output:{}".format(temp_output[0].shape, temp_output[1].shape)) hidden = self.text_proj(temp_output[0]) text_output = self.text_proj(temp_output[1]) text_output = text_output.view(bs_pair, text_output.size(-1)) hidden = hidden.view(bs_pair, -1, hidden.size(-1)) if return_hidden: return hidden else: return text_output def encode_text(self, text, return_hidden=False): x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype) x = x + pos_emd x = x.permute(1, 0, 2) # NLD -> LND x = self.transformer(x) x = x.permute(1, 0, 2) # LND -> NLD hidden = self.ln_final(x).type(self.dtype) @ self.text_projection # x.shape = [batch_size, n_ctx, transformer.width] # take features from the eot embedding (eot_token is the highest number in each sequence) x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)] if return_hidden: return x.float(), hidden.float() return x.float() class BertLMPredictionHead(nn.Module): def __init__(self, config): super(BertLMPredictionHead, self).__init__() self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(config.hidden_size, config.vocab_size,bias=False,) self.bias = nn.Parameter(torch.zeros(config.vocab_size)) self.decoder.bias = self.bias def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) if isinstance(config.hidden_act, str) or ( sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode) ): self.transform_act_fn = ACT2FN[config.hidden_act] else: self.transform_act_fn = config.hidden_act self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(BertLayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias