InternVL-14B-224px / modeling_internvl.py
czczup's picture
Update modeling_internvl.py
0bffb65 verified
raw
history blame
21.1 kB
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from peft import LoraConfig, get_peft_model
from timm.models.layers import DropPath
from torch import nn
from transformers import GenerationConfig
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_internvl import InternVLConfig
from .modeling_intern_vit import (InternVisionEmbeddings, InternVisionEncoder,
InternVisionModel)
from .modeling_qllama import LlamaForCausalLM, _expand_mask, _make_causal_mask
try:
from .flash_attention import FlashAttention # v1/v2
except:
print('FlashAttention is not installed.')
logger = logging.get_logger(__name__)
class InternVLPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = InternVLConfig
base_model_prefix = 'internvl'
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [
r'position_ids',
]
_no_split_modules = ['InternAttention', 'LlamaDecoderLayer', 'LlamaForCausalLM']
_skip_keys_device_placement = 'past_key_values'
_keep_in_fp32_modules = ['wo']
def _init_weights(self, module):
"""Initialize the weights"""
factor = self.config.initializer_range
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=factor)
if hasattr(module, 'bias') and module.bias is not None:
module.bias.data.zero_()
if isinstance(module, InternVisionEmbeddings):
if hasattr(self.config, 'vision_config'):
factor = self.config.vision_config.initializer_range
nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, InternVisionModel):
module.gradient_checkpointing = value
if isinstance(module, InternVisionEncoder):
module.gradient_checkpointing = value
class CrossAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None, out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
super().__init__()
self.norm1_q = norm_layer(dim)
self.norm1_k = norm_layer(dim)
self.norm1_v = norm_layer(dim)
self.cross_attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_attn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv, pos_q, pos_k = x, 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
x = x.squeeze(1)
return x
class InternVLModel(InternVLPreTrainedModel):
config_class = InternVLConfig
main_input_name = 'pixel_values'
def __init__(self, config: InternVLConfig):
super().__init__(config)
text_hidden_size = config.qllama_config.hidden_size
vision_hidden_size = config.vision_config.hidden_size
clip_embed_dim = config.clip_embed_dim
attn_pool_num_heads = config.attn_pool_num_heads
config.qllama_config.num_query_token = config.num_query_token
self.num_query_token = config.num_query_token
self.label_smoothing = config.label_smoothing
self.vision_model = InternVisionModel(config.vision_config) # frozen
self.qllama = LlamaForCausalLM(config.qllama_config) # frozen
self.query_tokens = nn.Parameter( # trainable
torch.zeros(1, config.num_query_token, text_hidden_size)
)
self.text_projection = nn.Parameter(torch.empty(text_hidden_size, clip_embed_dim)) # frozen
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # trainable
self.clip_projector = AttentionPoolingBlock( # frozen
dim=vision_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
self.clip_projector2 = AttentionPoolingBlock( # trainable
dim=text_hidden_size, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
self.itm_head = nn.Linear(text_hidden_size, 2) # trainable
self.gradient_checkpointing = True
# Initialize weights and apply final processing
# self.post_init()
if config.use_backbone_lora:
self.wrap_backbone_lora(r=config.use_backbone_lora)
if config.use_qllama_lora:
self.wrap_qllama_lora(r=config.use_qllama_lora)
if config.force_image_size:
self.vision_model.resize_pos_embeddings(
old_size=config.vision_config.image_size,
new_size=config.force_image_size,
patch_size=config.vision_config.patch_size
)
def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
self.vision_model.print_trainable_parameters()
def wrap_qllama_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
lora_config = LoraConfig(
r=r,
target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.qllama = get_peft_model(self.qllama, lora_config)
self.qllama.print_trainable_parameters()
def get_input_embeddings(self):
return self.qllama.get_input_embeddings()
def set_input_embeddings(self, value):
self.qllama.set_input_embeddings(value)
def set_output_embeddings(self, new_embeddings):
self.qllama.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Module:
return self.qllama.get_output_embeddings()
@torch.no_grad()
def generate(
self,
pixel_values: torch.FloatTensor,
input_ids: torch.FloatTensor,
attention_mask: torch.LongTensor,
generation_config: Optional[GenerationConfig] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**generate_kwargs,
) -> torch.LongTensor:
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
image_embeds = vision_outputs[0]
batch_size = image_embeds.shape[0]
input_embeds = self.get_input_embeddings()(input_ids)
query_tokens = self.query_tokens.repeat(batch_size, 1, 1)
input_embeds = torch.cat([query_tokens, input_embeds], dim=1)
image_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
outputs = self.qllama.generate(
inputs_embeds=input_embeds,
attention_mask=attention_mask,
vision_hidden_states=image_embeds,
generation_config=generation_config,
use_zero_attention_mask=True,
**generate_kwargs,
)
return outputs
def get_text_features(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
r"""
Returns:
text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
contains the language model logits, the past key values and the hidden states if
`output_hidden_states=True`.
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_embeds = self.get_input_embeddings()(input_ids)
attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask += _make_causal_mask(
(attention_mask.shape[0], attention_mask.shape[2]),
input_embeds.dtype,
device=input_embeds.device
)
if type(self.qllama.model) == LlamaForCausalLM:
outputs = self.qllama.model.model.forward_train(
inputs_embeds=input_embeds,
vision_hidden_states=None,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
).last_hidden_state
else:
outputs = self.qllama.model.forward_train(
inputs_embeds=input_embeds,
vision_hidden_states=None,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
).last_hidden_state
return outputs
def get_image_features(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
image_embeds = vision_outputs[0]
backbone_embeds = image_embeds
batch_size = image_embeds.shape[0]
input_embeds = self.query_tokens.repeat(batch_size, 1, 1)
attention_mask = torch.ones(input_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
attention_mask = _expand_mask(attention_mask, input_embeds.dtype).to(
input_embeds.device) # [bsz, 1, tgt_seq_len, src_seq_len]
if type(self.qllama.model) == LlamaForCausalLM:
outputs = self.qllama.model.model.forward_train(
inputs_embeds=input_embeds,
vision_hidden_states=image_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
).last_hidden_state
else:
outputs = self.qllama.model.forward_train(
inputs_embeds=input_embeds,
vision_hidden_states=image_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
).last_hidden_state
return backbone_embeds, outputs
def encode_image(self, image, mode):
if mode == 'InternVL-C':
vision_outputs = self.vision_model(
pixel_values=image,
output_hidden_states=False,
return_dict=True)
image_embeds = vision_outputs[0]
image_embeds = self.clip_projector(image_embeds)
elif mode == 'InternVL-G':
backbone_embeds, image_embeds = self.get_image_features(
pixel_values=image,
output_hidden_states=False,
return_dict=True,
)
backbone_embeds = self.clip_projector(backbone_embeds)
image_embeds = self.clip_projector2(image_embeds)
# ensemble
backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
image_embeds = image_embeds + backbone_embeds
else:
raise NotImplementedError
return image_embeds
def encode_text(self, text):
attention_mask = text > 0
text_embeds = self.get_text_features(
input_ids=text,
attention_mask=attention_mask,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
text_embeds = text_embeds @ self.text_projection
return text_embeds
def forward(self, image, text, mode='InternVL-C'):
assert mode in ['InternVL-C', 'InternVL-G'], 'mode must be InternVL-C or InternVL-G'
image_features = self.encode_image(image, mode)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
class InternVL_C(InternVLModel):
def encode_image(self, image):
vision_outputs = self.vision_model(
pixel_values=image,
output_hidden_states=False,
return_dict=True)
image_embeds = vision_outputs[0]
image_embeds = self.clip_projector(image_embeds)
return image_embeds
def encode_text(self, text):
attention_mask = text > 0
text_embeds = self.get_text_features(
input_ids=text,
attention_mask=attention_mask,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
text_embeds = text_embeds @ self.text_projection
return text_embeds
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
class InternVL_G(InternVLModel):
def encode_image(self, image):
backbone_embeds, image_embeds = self.get_image_features(
pixel_values=image,
output_hidden_states=False,
return_dict=True,
)
backbone_embeds = self.clip_projector(backbone_embeds)
image_embeds = self.clip_projector2(image_embeds)
# ensemble
backbone_embeds = backbone_embeds / backbone_embeds.norm(dim=1, keepdim=True)
image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
image_embeds = image_embeds + backbone_embeds
return image_embeds
def encode_text(self, text):
attention_mask = text > 0
text_embeds = self.get_text_features(
input_ids=text,
attention_mask=attention_mask,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
text_embeds = text_embeds[torch.arange(text_embeds.shape[0]), attention_mask.sum(1) - 1]
text_embeds = text_embeds @ self.text_projection
return text_embeds
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text