Pheye / pheye_builder /wrapper_lm.py
miguelcarv's picture
first commit
34f251f
raw
history blame
No virus
4.49 kB
import torch.nn as nn
from .xattn import CrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
class WrapperLayer(nn.Module):
    """
    Pairs an optional cross-attention block with a decoder layer.

    On every forward pass the cross-attention block (when one is present)
    runs first on the language features together with the stored visual
    features; the wrapped decoder layer then runs on the result.
    """

    def __init__(
        self, cross_attn_layer, decoder_layer, gradient_checkpointing=False
    ):
        """
        Args:
            cross_attn_layer: cross-attention module, or None when this
                position in the stack has no cross attention.
            decoder_layer: the decoder layer being wrapped.
            gradient_checkpointing: value stored on the wrapped layers as
                ``_use_gradient_checkpointing``.
        """
        super().__init__()
        self.cross_attn_layer = cross_attn_layer
        self.decoder_layer = decoder_layer
        # Visual features; stays None until condition_vis_x() is called.
        self.vis_x = None

        # Hand the checkpointing flag down to the wrapped layers.
        has_cross_attn = cross_attn_layer is not None
        if has_cross_attn:
            cross_attn_layer._use_gradient_checkpointing = gradient_checkpointing
        decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Return True once visual features have been stored on this layer."""
        conditioned = self.vis_x is not None
        return conditioned

    # Conditioning idea borrowed from the flamingo-mini implementation
    # (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        """Store the visual features to attend to (pass None to clear them)."""
        self.vis_x = vis_x

    def forward(self, lang_x, attention_mask=None, **decoder_layer_kwargs):
        """
        Apply cross attention (if present), then the decoder layer.

        Args:
            lang_x: language features given to the cross-attention block
                and then to the decoder layer.
            attention_mask: forwarded to the decoder layer by keyword.
            **decoder_layer_kwargs: extra keyword arguments forwarded
                unchanged to the decoder layer.

        Returns:
            Whatever the wrapped decoder layer returns.

        Raises:
            ValueError: a cross-attention block exists but no visual
                features have been stored yet.
        """
        xattn = self.cross_attn_layer
        if xattn is not None:
            visual = self.vis_x
            if visual is None:
                raise ValueError("vis_x must be conditioned before forward pass")
            lang_x = xattn(lang_x, visual)

        # Regular decoder layer on the (possibly cross-attended) features.
        return self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
class phEYELMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    Meant to be combined with a language-model class through multiple
    inheritance: ``forward`` delegates to the next class in the MRO via
    ``super()``. Call ``set_decoder_layers_attr_name`` and then
    ``init_pheye`` before the first forward pass.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Record the attribute path under which the decoder layers live."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the object stored at ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the object stored at ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_pheye(
        self,
        lang_hidden_size,
        vis_hidden_size,
        dtype,
        cross_attn_every_n_layers,
        gradient_checkpointing,
        reduce_factor=1,
        from_layer=0
    ):
        """
        Initialize phEYE by adding a new cross attn to the decoder.

        Args:
            lang_hidden_size: passed to each CrossAttentionBlock as ``dim_text``.
            vis_hidden_size: passed to each CrossAttentionBlock as ``dim_visual``.
            dtype: passed through to each CrossAttentionBlock.
            cross_attn_every_n_layers: a cross-attention block is created for
                decoder layer ``i`` (0-based) when ``(i + 1)`` is a multiple
                of this value.
            gradient_checkpointing: forwarded to every WrapperLayer.
            reduce_factor: passed through to each CrossAttentionBlock.
            from_layer: decoder layers with an index below this value never
                receive a cross-attention block.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        n_decoder_layers = len(self.old_decoder_blocks)
        # One slot per decoder layer; None marks layers without cross attention.
        self.cross_attn_layers = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim_text=lang_hidden_size,
                    dim_visual=vis_hidden_size,
                    reduce_factor=reduce_factor,
                    layer_idx=layer_idx,
                    n_decoder_layers=n_decoder_layers,
                    dtype=dtype,
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                and layer_idx >= from_layer
                else None
                for layer_idx in range(n_decoder_layers)
            ]
        )
        self.init_pheye_layers(gradient_checkpointing)
        self.initialized_pheye = True
        self._use_cached_vision_x = False

    def init_pheye_layers(self, gradient_checkpointing):
        """
        Re initializes the WrapperLayers.
        Propagates any changes made to self.cross_attn_layers or self.old_decoder_blocks
        """
        wrapped_layers = [
            WrapperLayer(cross_attn_layer, decoder_layer, gradient_checkpointing)
            for cross_attn_layer, decoder_layer in zip(
                self.cross_attn_layers, self.old_decoder_blocks
            )
        ]
        self._set_decoder_layers(nn.ModuleList(wrapped_layers))

    def forward(self, input_ids, attention_mask, **kwargs):
        """
        Run the underlying language model's forward pass.

        ``input_ids`` and ``attention_mask`` are merged into ``kwargs`` and
        everything is passed by keyword to the next ``forward`` in the MRO.

        Raises:
            ValueError: ``init_pheye`` has not been called yet.
        """
        # The flag only exists after init_pheye(); a plain attribute read would
        # raise AttributeError from nn.Module.__getattr__ instead of reaching
        # the explanatory error below.
        if not getattr(self, "initialized_pheye", False):
            raise ValueError(
                "phEYE layers are not initialized. Please call `init_pheye` first."
            )
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset the stored visual features on every decoder layer to None."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)