from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import copy
import json
import math
import logging
import tarfile
import tempfile
import shutil
import sys

import torch
from torch import nn
import torch.nn.functional as F

from .file_utils import cached_path
from .until_config import PretrainedConfig
from .until_module import PreTrainedModel, LayerNorm, ACT2FN
from collections import OrderedDict
from modules.module_clip import build_model, CLIP, convert_weights
from transformers import AutoConfig, AutoModel, RobertaModel, RobertaConfig

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {}
CONFIG_NAME = 'cross_config.json'
WEIGHTS_NAME = 'cross_pytorch_model.bin'
def gelu(x):
    """Implementation of the gelu activation function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    """
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
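# Note: this dict shadows the ACT2FN imported from until_module above, so string
# activation names resolved below (e.g. in BertPredictionHeadTransform) map to the
# local gelu/relu/swish functions defined in this file.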
class CrossConfig(PretrainedConfig):
    """Configuration class to store the configuration of a `CrossModel`.
    """
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    config_name = CONFIG_NAME
    weights_name = WEIGHTS_NAME

    def __init__(self,
                 vocab_size_or_config_json_file,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="gelu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02):
        """Constructs CrossConfig.

        Args:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CrossModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `CrossModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
        """
        if isinstance(vocab_size_or_config_json_file, str):
            with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif isinstance(vocab_size_or_config_json_file, int):
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.hidden_act = hidden_act
            self.intermediate_size = intermediate_size
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range
        else:
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
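# Pre-norm Transformer block (as in CLIP): multi-head self-attention followed by a
# 4x-widened MLP with QuickGELU, each wrapped in a residual connection. forward()
# takes and returns an (x, attn_mask) tuple so that blocks can be chained inside
# nn.Sequential while threading the additive attention mask through every layer.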
class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.n_head = n_head

    def attention(self, x: torch.Tensor, attn_mask: torch.Tensor):
        attn_mask_ = attn_mask.repeat(self.n_head, 1, 1)
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask_)[0]

    def forward(self, para_tuple: tuple):
        # x: torch.Tensor, attn_mask: torch.Tensor
        # print(para_tuple)
        x, attn_mask = para_tuple
        x = x + self.attention(self.ln_1(x), attn_mask)
        x = x + self.mlp(self.ln_2(x))
        return (x, attn_mask)
class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)])

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        # logger.info("x.shape:{}, attn_mask:{}".format(x.shape, attn_mask.shape))
        return self.resblocks((x, attn_mask))[0]
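# Shape sketch (illustrative only): the temporal transformer below is run in LND
# layout, i.e. inputs of shape [seq_len, batch, width] with an additive attention
# mask of shape [batch, seq_len, seq_len] (0 for visible positions, a large negative
# value for masked ones).
#
#     temporal = Transformer(width=512, layers=4, heads=8)
#     x = torch.zeros(12, 2, 512)      # [frames, batch, width]
#     mask = torch.zeros(2, 12, 12)    # nothing masked
#     out = temporal(x, mask)          # -> [12, 2, 512]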
class VisualEncoder(nn.Module):
    def __init__(self, task_config, cross_config):
        super().__init__()
        pretrained_clip_name = cross_config.pretrained_clip_name
        if task_config.local_rank == 0:
            logger.info("pretrained_clip_name:{}".format(pretrained_clip_name))
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        clip = build_model(clip_state_dict, local_rank=task_config.local_rank)

        self.use_temp = task_config.use_temp
        self.is_vit = copy.deepcopy(clip.vit)
        self.visual = copy.deepcopy(clip.visual)

        if self.use_temp:
            self.temporal_transformer = Transformer(width=cross_config.temporal_hidden_size,
                                                    layers=cross_config.temporal_hidden_layers,
                                                    heads=cross_config.temporal_attention_heads)
            self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings,
                                                          cross_config.temporal_hidden_size)
            # use clip.transformer to initialize temporal_transformer
            # for param_1, param_2 in zip(self.temporal_transformer.parameters(), clip.transformer.parameters()):
            #     param_1.data.copy_(param_2.data)  # initialize
            # if task_config.local_rank == 0:
            #     logger.info("clip.positional_embedding:{}".format(clip.positional_embedding))
            # self.frame_position_embeddings.weight = copy.deepcopy(clip.positional_embedding)
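    # forward() encodes every frame independently with the CLIP visual tower, then
    # optionally adds learned frame-position embeddings and runs a small temporal
    # transformer over the frame sequence, with a residual connection back to the
    # per-frame features. It returns both the pooled video embedding and the raw
    # per-frame features.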
    def forward(self, video, video_frames):
        # encode frames
        bs, frames, channel, h, w = video.shape
        # [bs*frame, 3, 224, 224]
        video = video.view(bs * frames, channel, h, w)
        # logger.info("video_b.shape:{}, dtype:{}".format(video_b.shape, video_b.dtype))
        # logger.info("video_frame[{}]:{}".format(b, video_frame))
        visual_hidden = self.encode_image(video, video_frame=frames)
        # [bs, frame, hidden_size]
        # logger.info("visual_hidden.shape:{}".format(visual_hidden.shape))
        visual_hidden = visual_hidden.view(bs, frames, visual_hidden.size(-1))
        # logger.info("visual_hidden1.shape:{}".format(visual_hidden.shape))

        # get temporal information
        visual_hidden_original = visual_hidden
        frame_output = visual_hidden_original
        if self.use_temp:
            seq_length = visual_hidden.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long, device=visual_hidden.device)
            # logger.info("position_ids.shape:{}".format(position_ids.shape))
            frame_position_embeddings = self.frame_position_embeddings(position_ids)
            # logger.info("frame_position_embeddings.shape:{}".format(frame_position_embeddings.shape))
            visual_hidden = visual_hidden + frame_position_embeddings

            video_mask = torch.ones([bs, frames], device=visual_hidden.device)
            extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0
            extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1)

            visual_hidden = visual_hidden.permute(1, 0, 2)  # NLD -> LND
            visual_hidden = self.temporal_transformer(visual_hidden, extended_video_mask)
            visual_hidden = visual_hidden.permute(1, 0, 2)  # LND -> NLD
            visual_hidden = visual_hidden + visual_hidden_original
        # logger.info("visual_hidden.shape:{}".format(visual_hidden.shape))

        visual_output = visual_hidden / visual_hidden.norm(dim=-1, keepdim=True)
        # [bs, frames, 512] -> [bs, 512]
        visual_output = torch.mean(visual_output, dim=1)
        # logger.info("visual_hidden mean.shape:{}".format(visual_hidden.shape))
        # logger.info("visual encoder visual_output.shape:{}".format(visual_output.shape))
        return visual_output, frame_output
    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype
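    # encode_image() runs the CLIP visual backbone. For a ViT backbone the token
    # sequence is layer-normalized, projected with visual.proj, and the class token
    # (index 0) is taken as the frame embedding; for a non-ViT backbone the output
    # of self.visual is used directly.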
    def encode_image(self, image, return_hidden=False, video_frame=-1):
        if self.is_vit:
            # logger.info("image.shape:{}".format(image.shape))
            # hidden = self.visual(image, video_frame=video_frame)
            hidden = self.visual(image.type(self.dtype), video_frame=video_frame)
            # logger.info("hidden1.shape:{}".format(hidden.shape))
            hidden = self.visual.ln_post(hidden) @ self.visual.proj
            # logger.info("hidden2.shape:{}".format(hidden.shape))
            x = hidden[:, 0, :]
            # x = hidden
        else:
            hidden = self.visual(image)
            x = hidden

        if return_hidden:
            return x.float(), hidden.float()
        return x.float()
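# Shape sketch (illustrative only; task_config and cross_config come from the
# training scripts and are not constructed here):
#
#     encoder = VisualEncoder(task_config, cross_config)
#     video = torch.randn(bs, frames, 3, 224, 224)
#     visual_output, frame_output = encoder(video, video_frames=None)
#     # visual_output: [bs, hidden]          per-frame-normalized, mean-pooled video embedding
#     # frame_output:  [bs, frames, hidden]  per-frame CLIP features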
class TextEncoder(nn.Module):
    def __init__(self, task_config, cross_config):
        super().__init__()
        self.language = task_config.language
        pretrained_clip_name = cross_config.pretrained_clip_name
        if task_config.local_rank == 0:
            logger.info("pretrained_clip_name:{}".format(pretrained_clip_name))
        clip_state_dict = CLIP.get_config(pretrained_clip_name=pretrained_clip_name)
        clip = build_model(clip_state_dict, local_rank=task_config.local_rank)
        self.logit_scale = copy.deepcopy(clip_state_dict["logit_scale"])

        if self.language == "english":
            self.token_embedding = copy.deepcopy(clip.token_embedding)
            self.positional_embedding = copy.deepcopy(clip.positional_embedding)
            self.transformer = copy.deepcopy(clip.transformer)
            self.ln_final = copy.deepcopy(clip.ln_final)
            self.text_projection = copy.deepcopy(clip.text_projection)
            self.dtype = clip.visual.conv1.weight.dtype
        elif self.language == "chinese":
            pretrained = task_config.pretrained_text
            t_config = AutoConfig.from_pretrained(pretrained)
            if task_config.rank == 0:
                logger.info("name:{}, chinesebert_config:{}".format(pretrained, t_config))
            self.chinese_encoder = AutoModel.from_pretrained(pretrained)
            # logger.info("random Roberta")
            # self.chinese_encoder = RobertaModel(RobertaConfig())
            self.text_proj = nn.Linear(cross_config.chinese_hidden_size, cross_config.temporal_hidden_size)
        else:
            raise NotImplementedError("wrong language")
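    # forward() dispatches on the configured language: the English path reuses the
    # CLIP text tower via encode_text() (end-of-text token pooling), while the
    # Chinese path runs a Hugging Face encoder and maps both its sequence output and
    # pooled output into the shared embedding space with a single linear projection.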
    def forward(self, input_ids, attention_mask, return_hidden=False):
        bs_pair = input_ids.size(0)
        if self.language == "english":
            text_output, hidden = self.encode_text(input_ids, return_hidden=True)
        else:
            temp_output = self.chinese_encoder(input_ids, attention_mask=attention_mask)
            # logger.info("hidden:{}, text_output:{}".format(temp_output[0].shape, temp_output[1].shape))
            hidden = self.text_proj(temp_output[0])
            text_output = self.text_proj(temp_output[1])

        text_output = text_output.view(bs_pair, text_output.size(-1))
        hidden = hidden.view(bs_pair, -1, hidden.size(-1))

        if return_hidden:
            return hidden
        else:
            return text_output
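    # encode_text() follows CLIP's text encoder: token plus positional embeddings,
    # the transformer in LND layout, a final LayerNorm and projection, and pooling
    # at the end-of-text token, which is found via argmax because the EOT id is the
    # largest token id in each sequence.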
    def encode_text(self, text, return_hidden=False):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        pos_emd = self.positional_embedding[:x.size(1), :].type(self.dtype)
        x = x + pos_emd
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        hidden = self.ln_final(x).type(self.dtype) @ self.text_projection

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = hidden[torch.arange(hidden.shape[0]), text.argmax(dim=-1)]

        if return_hidden:
            return x.float(), hidden.float()
        return x.float()
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # On Python 3 the version check short-circuits, so `unicode` is never evaluated.
        if isinstance(config.hidden_act, str) or (
            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)
        ):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
class BertLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        """Construct a layernorm module in the TF style (epsilon inside the square root).
        """
        super(BertLayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.weight * x + self.bias
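# Note: BertLayerNorm computes weight * (x - mean) / sqrt(var + eps) + bias over the
# last dimension with a biased variance and epsilon inside the square root, which
# should match torch.nn.LayerNorm(hidden_size, eps=1e-12) up to numerical details.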