Pheye / pheye_builder /phEYE.py
miguelcarv's picture
first commit
34f251f
raw
history blame
No virus
8.58 kB
import torch
from peft import LoraConfig, get_peft_model
from torch import nn
import os
class phEYE(nn.Module):
def __init__(
self,
vision_encoder: nn.Module,
lang_encoder: nn.Module,
vis_dim: int,
dtype: torch.dtype,
cross_attn_every_n_layers: int = 1,
gradient_checkpointing: bool = False,
reduce_factor = 1,
from_layer = 0
):
"""
Args:
vision_encoder (nn.Module): module with OpenCLIP model
lang_encoder (nn.Module): HF causal language model
vis_dim (int): Dimension of the visual features.
Visual features are projected to match this shape along the last dimension.
cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1.
"""
super().__init__()
self.vis_dim = vis_dim
if hasattr(lang_encoder.config, "d_model"):
self.lang_dim = lang_encoder.config.d_model # mpt uses d_model
else:
self.lang_dim = lang_encoder.config.hidden_size
self.vision_encoder = vision_encoder
self.lang_encoder = lang_encoder
self.lang_encoder.init_pheye(
lang_hidden_size=self.lang_dim,
vis_hidden_size=self.vis_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
reduce_factor=reduce_factor,
from_layer=from_layer,
dtype=dtype
)
self._use_gradient_checkpointing = gradient_checkpointing
def forward(
self,
vision_x: list,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
labels: torch.Tensor = None,
clear_conditioned_layers: bool = True,
past_key_values = None,
use_cache: bool = False,
device="cpu",
is_textcaps = False
):
"""
Forward pass of phEYE.
Args:
vision_x (list): Vision input
shape (B, C, H, W)
lang_x (torch.Tensor): Language input ids
shape (B, txt_seq)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
labels (torch.Tensor, optional): Labels. Defaults to None.
clear_conditioned_layers: if True, clear the conditioned layers
once the foward pass is completed. Set this to false if the
same set of images will be reused in another subsequent
forward pass.
past_key_values: pre-computed values to pass to language model.
See past_key_values documentation in Hugging Face
CausalLM models.
use_cache: whether to use cached key values. See use_cache
documentation in Hugging Face CausalLM models.
"""
assert (
self.lang_encoder.initialized_pheye
), "Wrapper layers are not initialized. Please call `initialized_pheye` first."
assert (
self.lang_encoder._use_cached_vision_x or vision_x is not None
), "Must provide either vision_x or have precached media using cache_media()."
if self.lang_encoder._use_cached_vision_x:
# Case: use cached; vision_x should be cached and other
# vision-related inputs should not be provided.
assert (
vision_x is None
), "Expect vision_x to be None when media has been cached using cache_media(). Try uncache_media() first."
assert self.lang_encoder.is_conditioned()
else:
# Case: do not use caching (i.e. this is a standard forward pass);
self._encode_vision_x(vision_x=vision_x, device=device, is_textcaps=is_textcaps)
#print(f"Text features shape: {lang_x.shape}")
output = self.lang_encoder(
input_ids=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
use_cache=use_cache,
)
if clear_conditioned_layers:
self.lang_encoder.clear_conditioned_layers()
return output
def generate(
self,
vision_x: list,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
device = "cpu",
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (list): Vision input
shape (B, C, H, W)
images in the same chunk are collated along T_img, and frames are collated along F
currently only F=1 is supported (single-frame videos)
lang_x (torch.Tensor): Language input
shape (B, T_txt)
**kwargs: see generate documentation in Hugging Face CausalLM models. Some notable kwargs:
max_length (int, optional): Maximum length of the output. Defaults to None.
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
num_beams (int, optional): Number of beams. Defaults to 1.
max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
temperature (float, optional): Temperature. Defaults to 1.0.
top_k (int, optional): Top k. Defaults to 50.
top_p (float, optional): Top p. Defaults to 1.0.
no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
length_penalty (float, optional): Length penalty. Defaults to 1.0.
num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
do_sample (bool, optional): Do sample. Defaults to False.
early_stopping (bool, optional): Early stopping. Defaults to False.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
num_beams = kwargs.pop("num_beams", 1)
self.lang_encoder._use_cached_vision_x = True
self._encode_vision_x(vision_x=vision_x, device=device, repeat=num_beams)
output = self.lang_encoder.generate(
input_ids=lang_x,
attention_mask=attention_mask,
num_beams=num_beams,
**kwargs,
)
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_vision_x = False
return output
def _encode_vision_x(self, vision_x: list, device="cpu", repeat = 1, is_textcaps = False):
"""
Compute vision features by passing images through vision encoder and conditioning language model.
Args:
vision_x (list): Vision input
shape (B, C, H, W)
"""
if is_textcaps:
vision_x = vision_x[::5]
repeat = 5
vision_x = self.vision_encoder(vision_x, device=device)
if repeat > 1:
vision_x = vision_x.repeat_interleave(repeat, dim=0)
for layer in self.lang_encoder._get_decoder_layers():
layer.condition_vis_x(vision_x)
def cache_media(self, vision_x: list, device="cpu"):
"""
Cache vision_x features from list of images for log-likelihood evaluation
This is not meant to be used to cache things for generate().
Args:
vision_x (torch.Tensor): Vision input
shape (B, F, C, H, W)
"""
self._encode_vision_x(vision_x=vision_x, device=device)
self.lang_encoder._use_cached_vision_x = True
def uncache_media(self):
"""
Clear all conditioning.
"""
self.lang_encoder.clear_conditioned_layers()
self.lang_encoder._use_cached_vision_x = False
def save_model(self, _path):
os.mkdir(_path)
torch.save(self.vision_encoder.state_dict(), _path+"vision_encoder.pt")
torch.save(self.lang_encoder.state_dict(), _path+"lang_encoder.pt")
def add_lora_decoder(self):
config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],
lora_dropout=0.05,
bias="none"
)
self.lang_encoder.old_decoder_blocks = get_peft_model(self.lang_encoder.old_decoder_blocks, config)
def merge_and_unload(self):
self.lang_encoder.old_decoder_blocks = self.lang_encoder.old_decoder_blocks.merge_and_unload()