import torch.nn as nn

from .xattn import CrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive


class WrapperLayer(nn.Module):
    """
    Wraps a CrossAttentionBlock together with a decoder layer: the layer first
    cross-attends to the conditioned visual features (if it has a
    cross-attention block), then runs the original decoder layer.
    """

    def __init__(self, cross_attn_layer, decoder_layer, gradient_checkpointing=False):
        super().__init__()
        self.cross_attn_layer = cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        # Propagate the gradient-checkpointing flag to both sub-layers.
        if self.cross_attn_layer is not None:
            self.cross_attn_layer._use_gradient_checkpointing = gradient_checkpointing
        self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned on visual features."""
        return self.vis_x is not None

    # Conditioning idea borrowed from this implementation of Flamingo
    # (https://github.com/dhansmair/flamingo-mini/).
    def condition_vis_x(self, vis_x):
        self.vis_x = vis_x

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        # Cross-attend to the visual features before the decoder layer.
        if self.cross_attn_layer is not None:
            if self.vis_x is None:
                raise ValueError("vis_x must be conditioned before forward pass")
            lang_x = self.cross_attn_layer(lang_x, self.vis_x)
        # Then run the original decoder layer.
        lang_x = self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return lang_x
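
# Minimal sketch of the WrapperLayer contract (hypothetical tensors and
# modules; the shapes are assumptions, not part of this module). A layer that
# owns a cross-attention block must be conditioned before its forward pass,
# otherwise forward raises a ValueError:
#
#   layer = WrapperLayer(cross_attn_block, decoder_block)
#   layer.condition_vis_x(vis_x)              # e.g. (batch, n_vis_tokens, dim)
#   out = layer(lang_x, attention_mask=mask)
#   layer.condition_vis_x(None)               # clear between unrelated inputs
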
class phEYELMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        setattr_recursive(self, self.decoder_layers_attr_name, value)

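    # Example (assumed attribute path, not part of this module): with
    # decoder_layers_attr_name = "transformer.h" (a GPT-2-style model), the
    # two helpers above resolve and replace model.transformer.h recursively.
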
    def init_pheye(
        self,
        lang_hidden_size,
        vis_hidden_size,
        dtype,
        cross_attn_every_n_layers,
        gradient_checkpointing,
        reduce_factor=1,
        from_layer=0,
    ):
        """
        Initialize phEYE by adding new cross-attention blocks to the decoder:
        one before every `cross_attn_every_n_layers`-th decoder layer, skipping
        layers before `from_layer`; every other position gets None.
        """
        self.old_decoder_blocks = self._get_decoder_layers()
        self.cross_attn_layers = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim_text=lang_hidden_size,
                    dim_visual=vis_hidden_size,
                    reduce_factor=reduce_factor,
                    layer_idx=layer_idx,
                    n_decoder_layers=len(self.old_decoder_blocks),
                    dtype=dtype,
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                and layer_idx >= from_layer
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        self.init_pheye_layers(gradient_checkpointing)
        self.initialized_pheye = True
        self._use_cached_vision_x = False

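    # Illustrative example for `init_pheye` above (assumed sizes): with 12
    # decoder layers, cross_attn_every_n_layers=4 and from_layer=0,
    # cross-attention blocks are created at layer indices 3, 7 and 11 (those
    # where (layer_idx + 1) % 4 == 0); every other entry of
    # self.cross_attn_layers is None.
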
    def init_pheye_layers(self, gradient_checkpointing):
        """
        Re-initialize the WrapperLayers, propagating any changes made to
        self.cross_attn_layers or self.old_decoder_blocks.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    WrapperLayer(
                        cross_attn_layer, decoder_layer, gradient_checkpointing
                    )
                    for cross_attn_layer, decoder_layer in zip(
                        self.cross_attn_layers, self.old_decoder_blocks
                    )
                ]
            )
        )

    def forward(self, input_ids, attention_mask, **kwargs):
        # getattr avoids an AttributeError when `init_pheye` was never called.
        if not getattr(self, "initialized_pheye", False):
            raise ValueError(
                "phEYE layers are not initialized. Please call `init_pheye` first."
            )
        kwargs["input_ids"] = input_ids
        kwargs["attention_mask"] = attention_mask
        return super().forward(**kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Remove the cached visual features from every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
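
# Usage sketch (hypothetical; the model class, attribute path, and feature
# shapes below are assumptions, not part of this module). The mixin is meant
# to be combined with a Hugging Face-style causal LM, after which `init_pheye`
# rebuilds the decoder stack out of WrapperLayers:
#
#   import torch
#   from transformers import GPT2LMHeadModel
#
#   class PhEYEGPT2(phEYELMMixin, GPT2LMHeadModel):
#       pass
#
#   model = PhEYEGPT2.from_pretrained("gpt2")
#   model.set_decoder_layers_attr_name("transformer.h")  # GPT-2's decoder blocks
#   model.init_pheye(
#       lang_hidden_size=model.config.hidden_size,
#       vis_hidden_size=1024,  # assumed vision encoder width
#       dtype=torch.float32,
#       cross_attn_every_n_layers=4,
#       gradient_checkpointing=False,
#   )
#
#   # Condition every layer on visual features, run a forward pass, then clear.
#   for layer in model._get_decoder_layers():
#       layer.condition_vis_x(vis_x)  # e.g. (batch, n_vis_tokens, 1024)
#   out = model(input_ids, attention_mask)
#   model.clear_conditioned_layers()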