# infimm-hd/flamingo.py
import inspect
import torch
import numpy as np
from einops import rearrange
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
try:
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
except ImportError:
from torch.utils.checkpoint import checkpoint
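# Both the DeepSpeed and the PyTorch `checkpoint` functions are invoked below as
# `checkpoint(module, *inputs)`, so either import works for the single call site in
# `_encode_vision_x`.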
def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None):
""" Sinusoid position encoding table """
def cal_angle(position, hid_idx):
return position / np.power(10000, 2 * (hid_idx // 2) / d_hid)
def get_posi_angle_vec(position):
return [cal_angle(position, hid_j) for hid_j in range(d_hid)]
sinusoid_table = np.array(
[get_posi_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
if padding_idx is not None:
# zero vector for padding dimension
sinusoid_table[padding_idx] = 0.0
return torch.FloatTensor(sinusoid_table)
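# Illustrative example (not part of the original code): a hypothetical call
# get_sinusoid_encoding_table(3, 4) yields a (3, 4) table whose row p is
#   [sin(p / 10000**0), cos(p / 10000**0), sin(p / 10000**0.5), cos(p / 10000**0.5)],
# i.e. even columns hold sine terms and odd columns the matching cosine terms,
# so row 0 is [0, 1, 0, 1].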
def construct_position_encoding(vis_dim, max_pos, rows, cols):
    """Build a 2D (column, row) sinusoidal position encoding for a rows x cols grid."""
    seq = get_sinusoid_encoding_table(max_pos, vis_dim // 2)
y_coords, x_coords = torch.meshgrid(torch.arange(rows), torch.arange(cols), indexing='ij')
row_positions = seq[y_coords.flatten(), :]
col_positions = seq[x_coords.flatten(), :]
position_encoding = torch.cat((col_positions, row_positions), dim=-1)
return position_encoding
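# Illustrative example (assumed values, not in the original code):
# construct_position_encoding(vis_dim=8, max_pos=20, rows=2, cols=3) returns a tensor of
# shape (rows * cols, vis_dim) = (6, 8) in which each row concatenates a 4-dim column
# encoding with a 4-dim row encoding, so grid cells sharing a row or column share half
# of their encoding. vis_dim is assumed to be even.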
def unwrap_fsdp(m):
    """Recursively unwrap FSDP wrappers and return the underlying module."""
if isinstance(m, FSDP):
return unwrap_fsdp(m.module)
return m
def accepts_parameter(func, parameter_name):
    """Return True if `func` declares a parameter named `parameter_name` in its signature."""
signature = inspect.signature(func)
return parameter_name in signature.parameters
class Flamingo(nn.Module):
def __init__(
self,
vision_encoder: nn.Module,
lang_encoder: nn.Module,
eoc_token_id: int,
media_token_id: int,
vis_dim: int,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
use_ft_layernorm: bool = False,
use_ft_flash_attention: bool = False,
enable_init_network_params: bool = False,
initializer_range: float = 0.02,
):
"""
Args:
vision_encoder (nn.Module): HF CLIPModel
lang_encoder (nn.Module): HF causal language model
eoc_token_id (int): Token id for <|endofchunk|>
media_token_id (int): Token id for <image>
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
            cross_attn_every_n_layers (int, optional): Apply cross-attention after every
                n-th transformer layer. Defaults to 1.
"""
super().__init__()
self.vit_use_grad = False
self.eoc_token_id = eoc_token_id
self.media_token_id = media_token_id
self.vis_dim = vis_dim
if hasattr(lang_encoder.config, "d_model"):
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
else:
self.lang_dim = lang_encoder.config.hidden_size
self.vision_encoder = (
vision_encoder.visual
if hasattr(vision_encoder, "visual")
else vision_encoder
)
self.lang_encoder = lang_encoder
self.lang_encoder.init_flamingo(
media_token_id=media_token_id,
lang_hidden_size=self.lang_dim,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
use_ft_layernorm=use_ft_layernorm,
use_ft_flash_attention=use_ft_flash_attention,
enable_init_network_params=enable_init_network_params,
initializer_range=initializer_range,
)
self._use_gradient_checkpointing = gradient_checkpointing
def forward(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
image_mask: torch.Tensor = None,
subimage_shape: torch.Tensor = None,
clear_conditioned_layers: bool = True,
past_key_values=None,
use_cache: bool = False,
):
"""
Forward pass of Flamingo.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W) with F=1
lang_x (torch.Tensor): Language input ids
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            image_mask (torch.Tensor, optional): Mask over image positions, passed to the
                cross-attention layers together with the visual features. Defaults to None.
            subimage_shape (torch.Tensor, optional): Per-sample (cols, rows) of the
                sub-image grid, used to build the 2D positional encoding. Defaults to None.
clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
same set of images will be reused in another subsequent
forward pass.
past_key_values: pre-computed values to pass to language model.
See past_key_values documentation in Hugging Face
CausalLM models.
use_cache: whether to use cached key values. See use_cache
documentation in Hugging Face CausalLM models.
"""
assert (
self.lang_encoder.initialized_flamingo
), "Flamingo layers are not initialized. Please call `init_flamingo` first."
assert (
self.lang_encoder._use_cached_vision_x or vision_x is not None
), "Must provide either vision_x or have precached media using cache_media()."
if self.lang_encoder._use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert vision_x is None, (
"Expect vision_x to be None when media has been cached using"
" cache_media(). Try uncache_media() first."
)
assert self.lang_encoder.is_conditioned()
else:
# Case: do not use caching (i.e. this is a standard forward pass);
self._encode_vision_x(vision_x=vision_x, image_mask=image_mask, subimage_shape=subimage_shape)
self._condition_media_locations(input_ids=lang_x)
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
)
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
return output
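    # Hedged usage sketch (illustrative only; tensor names are assumptions, shapes follow
    # the docstring above):
    #   outputs = model(
    #       vision_x=images,                # (B, T_img, 1, 3, H, W)
    #       lang_x=input_ids,               # (B, T_txt) with media_token_id at <image> positions
    #       attention_mask=attn_mask,
    #       labels=labels,
    #       image_mask=image_mask,
    #       subimage_shape=subimage_shape,  # per-sample (cols, rows) of the sub-image grid
    #   )
    # The return value is whatever the wrapped Hugging Face causal LM returns (e.g. loss
    # and logits when labels are provided).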
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
images in the same chunk are collated along T_img, and frames are collated along F
currently only F=1 is supported (single-frame videos)
lang_x (torch.Tensor): Language input
shape (B, T_txt)
**kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
max_length (int, optional): Maximum length of the output. Defaults to None.
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
num_beams (int, optional): Number of beams. Defaults to 1.
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
temperature (float, optional): Temperature. Defaults to 1.0.
top_k (int, optional): Top k. Defaults to 50.
top_p (float, optional): Top p. Defaults to 1.0.
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
length_penalty (float, optional): Length penalty. Defaults to 1.0.
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
do_sample (bool, optional): Do sample. Defaults to False.
early_stopping (bool, optional): Early stopping. Defaults to False.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
subimage_shape = kwargs.pop("subimage_shape", None)
image_mask = kwargs.pop("image_mask", None)
num_beams = kwargs.pop("num_beams", 1)
if num_beams > 1:
vision_x = vision_x.repeat_interleave(num_beams, dim=0)
if image_mask is not None:
image_mask = image_mask.repeat_interleave(num_beams, dim=0)
if subimage_shape is not None:
subimage_shape = subimage_shape.repeat_interleave(num_beams, dim=0)
self.lang_encoder._use_cached_vision_x = True
self._encode_vision_x(vision_x=vision_x, image_mask=image_mask, subimage_shape=subimage_shape)
eos_token_id = kwargs.pop("eos_token_id", self.eoc_token_id)
output = self.lang_encoder.generate(
input_ids=lang_x,
attention_mask=attention_mask,
eos_token_id=eos_token_id,
num_beams=num_beams,
**kwargs,
)
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_vision_x = False
return output
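    # Hedged usage sketch (illustrative only; argument values are assumptions):
    #   generated = model.generate(
    #       vision_x=images,
    #       lang_x=prompt_ids,
    #       attention_mask=prompt_mask,
    #       image_mask=image_mask,
    #       subimage_shape=subimage_shape,
    #       num_beams=3,
    #       max_new_tokens=64,
    #   )
    # generate() encodes and caches the visual features itself, so callers pass raw
    # vision_x rather than calling cache_media() first.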
def _encode_vision_x(self, vision_x: torch.Tensor, image_mask: torch.Tensor=None, subimage_shape: torch.Tensor=None):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]
assert F == 1, "Only single frame supported"
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
if not self.vit_use_grad:
with torch.no_grad():
module_to_inspect = unwrap_fsdp(self.vision_encoder)
if accepts_parameter(module_to_inspect.forward, "return_all_features"):
vision_x = self.vision_encoder(vision_x, return_all_features=True)
else:
vision_x = self.vision_encoder(vision_x)[1]
else:
module_to_inspect = unwrap_fsdp(self.vision_encoder)
if accepts_parameter(module_to_inspect.forward, "return_all_features"):
if self.training:
                    vision_x = checkpoint(self.vision_encoder, vision_x, True)  # True -> return_all_features
else:
vision_x = self.vision_encoder(vision_x, return_all_features=True)
else:
vision_x = self.vision_encoder(vision_x)[1]
vision_x = rearrange(vision_x, "(b T F) v d -> b (T F) v d", b=b, T=T, F=F)
        # pos_emb has no batch dimension, so this assumes every sample in the batch
        # shares the same sub-image layout; index 0 (the full image) stays zero.
        pos_emb = torch.zeros((T, self.vis_dim), dtype=vision_x.dtype, device=vision_x.device)
        if subimage_shape is not None:
            for i in range(subimage_shape.shape[0]):
                cols, rows = int(subimage_shape[i, 0]), int(subimage_shape[i, 1])
                tmp_pos_emb = construct_position_encoding(
                    vision_x.shape[-1], 20, rows, cols
                ).to(dtype=vision_x.dtype, device=vision_x.device)
                pos_emb[1 : int(cols * rows) + 1, :] = tmp_pos_emb
        vision_x = vision_x + pos_emb.unsqueeze(1).unsqueeze(0).detach()
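        # Condition every decoder layer on the (visual features, image mask) pair.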
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x((vision_x, image_mask))
def _condition_media_locations(self, input_ids: torch.Tensor):
"""
Compute the media token locations from lang_x and condition the language model on these.
Args:
input_ids (torch.Tensor): Language input
shape (B, T_txt)
"""
media_locations = input_ids == self.media_token_id
        # Optionally make the whole sequence attend to the first (dummy) image to avoid NaNs:
        # media_locations = torch.where(tmp_mask==False, tmp_mask, media_locations)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_media_locations(media_locations)
def cache_media(self, input_ids: torch.Tensor, vision_x: torch.Tensor):
"""
Pre-cache a prompt/sequence of images / text for log-likelihood evaluations.
All subsequent calls to forward() will generate attending to the LAST
image in vision_x.
This is not meant to be used to cache things for generate().
Args:
input_ids (torch.Tensor): Language input
shape (B, T_txt)
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
"""
self._encode_vision_x(vision_x=vision_x)
self._condition_media_locations(input_ids=input_ids)
self.lang_encoder._use_cached_vision_x = True
def uncache_media(self):
"""
Clear all conditioning.
"""
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_vision_x = False
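# Hedged end-to-end sketch for log-likelihood style evaluation (illustrative only;
# vision_encoder, lang_encoder, the token ids, and the input tensors are assumptions
# supplied by the caller):
#   model = Flamingo(vision_encoder, lang_encoder, eoc_token_id, media_token_id, vis_dim=1024)
#   model.cache_media(input_ids=prompt_ids, vision_x=images)
#   out = model(vision_x=None, lang_x=prompt_ids, clear_conditioned_layers=False)
#   model.uncache_media()
# cache_media() encodes and stores the visual conditioning, so subsequent forward passes
# may pass vision_x=None; uncache_media() clears it again.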