| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """ |
| | This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings, |
| | designed for processing image inputs in vision-language models. |
| | |
| | This module follows Mistral's vision encoder implementation (for their Pixtral-12B VLM): |
| | https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py |
| | """ |
| | from functools import partial |
| | from typing import Any, Callable, Mapping, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| | from .ar_module_normalization import create_norm |
| | from .ar_network_transformer import TransformerBlock |
| | from .log import log |
| |
|
| |
|
def get_vit_config(model_name: str) -> Mapping[str, Any]:
    """
    Look up the Vision Transformer hyper-parameters for a named model.

    Args:
        model_name (str): Identifier of the vision encoder configuration.
            Only "pixtral-12b-vit" is recognised.

    Returns:
        Mapping[str, Any]: A freshly built dictionary of encoder settings.

    Raises:
        ValueError: If ``model_name`` does not match a known configuration.
    """
    # Guard clause: reject anything that is not the single supported preset.
    if model_name != "pixtral-12b-vit":
        raise ValueError(f"Unknown model name: {model_name}")

    return {
        "dim": 1024,
        "num_channels": 3,
        "image_size": 1024,
        "patch_size": 16,
        "rope_theta": 10000,
        "ffn_hidden_size": 4096,
        "n_layers": 24,
        "n_heads": 16,
        "n_kv_heads": 16,
        "norm_type": "rmsnorm",
        "norm_eps": 1e-5,
        "image_token_id": 10,
    }
| |
|
| |
|
def precompute_freqs_cis_2d(
    dim: int,
    height: int,
    width: int,
    theta: float,
) -> torch.Tensor:
    """
    Build the table of unit-magnitude complex rotations for 2D rotary embeddings.

    The inverse frequencies are split in an interleaved fashion: even-indexed
    frequencies are multiplied by the row index and odd-indexed frequencies by
    the column index. The two halves are concatenated along the last axis.

    Args:
        dim (int): Per-head feature dimension the rotations are built for.
        height (int): Number of rows in the position grid.
        width (int): Number of columns in the position grid.
        theta (float): Base of the geometric frequency progression.

    Returns:
        torch.Tensor: Complex tensor of shape (height, width, dim // 2).
    """
    exponents = torch.arange(0, dim, 2).float() / dim
    inv_freq = 1.0 / (theta**exponents)

    row_idx = torch.arange(height, device=inv_freq.device)
    col_idx = torch.arange(width, device=inv_freq.device)

    # Rows consume the even-indexed frequencies, columns the odd-indexed ones.
    row_angles = torch.outer(row_idx, inv_freq[::2]).float()
    col_angles = torch.outer(col_idx, inv_freq[1::2]).float()

    # Tile each half over the other spatial axis so both cover the full grid.
    row_part = row_angles.unsqueeze(1).repeat(1, width, 1)
    col_part = col_angles.unsqueeze(0).repeat(height, 1, 1)
    angles = torch.cat([row_part, col_part], dim=-1)

    # Magnitude 1, phase = angle.
    return torch.polar(torch.ones_like(angles), angles)
| |
|
| |
|
| | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): |
| | """ |
| | Reshape frequency tensor for broadcasting with input tensor. |
| | |
| | This function ensures that the frequency tensor can be properly broadcast |
| | with the input tensor during the rotary embedding process. |
| | |
| | Args: |
| | freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d. |
| | x (torch.Tensor): Input tensor to be embedded. |
| | |
| | Returns: |
| | torch.Tensor: Reshaped frequency tensor ready for broadcasting. |
| | """ |
| | ndim = x.ndim |
| | assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}" |
| | assert freqs_cis.shape == ( |
| | x.shape[1], |
| | x.shape[-1], |
| | ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}" |
| | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] |
| | return freqs_cis.view(*shape) |
| |
|
| |
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    *args,
    freqs_cis: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors by the supplied complex frequencies.

    Consecutive pairs along the last axis of each input are interpreted as
    (real, imaginary) parts, multiplied by ``freqs_cis``, and unpacked back to
    the original layout and dtype.

    Args:
        xq (torch.Tensor): Query tensor.
        xk (torch.Tensor): Key tensor.
        *args: Ignored positional arguments.
        freqs_cis (torch.Tensor): Complex rotation table; it is reshaped via
            ``reshape_for_broadcast`` against the complex view of ``xq``.
        **kwargs: Ignored keyword arguments.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors, each
        cast back to the dtype of its corresponding input.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last axis as (real, imag) in float32.
        paired = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(paired)

    def _as_real(c: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Undo the pairing: merge everything from axis 3 onward.
        return torch.view_as_real(c).flatten(3).type_as(like)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)

    # The same broadcast view (derived from the query) is used for both inputs.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    q_rotated = _as_real(q_complex * rotation, xq)
    k_rotated = _as_real(k_complex * rotation, xk)
    return q_rotated, k_rotated
| |
|
| |
|
class VisionTransformer(nn.Module):
    """
    Vision Transformer model for image processing.

    Images are split into non-overlapping patches by a strided convolution,
    normalised, and passed through a stack of transformer blocks that use
    2D rotary position embeddings.

    Args:
        dim (int): Dimension of the model (hidden size).
        num_channels (int): Number of input image channels (e.g., 3 for RGB).
        patch_size (int): Side length of each square image patch, in pixels.
        n_layers (int): Number of transformer layers.
        n_heads (int): Number of attention heads.
        n_kv_heads (Optional[int]): Number of key/value heads. Defaults to ``n_heads`` when None.
        ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks.
        norm_type (str): Type of normalization to use (e.g., "rmsnorm").
        norm_eps (float): Epsilon value for normalization layers.
        image_size (int): Maximum side length of the input image, in pixels; together with
            ``patch_size`` it determines the size of the rotary embedding table.
        rope_theta (float): Base value for rotary position embedding calculation.
        image_token_id (Optional[int]): Token ID for the image token (if present).
    """

    def __init__(
        self,
        dim: int = 1024,
        num_channels: int = 3,
        patch_size: int = 16,
        n_layers: int = 24,
        n_heads: int = 16,
        n_kv_heads: Optional[int] = None,
        ffn_hidden_size: int = 4096,
        norm_type: str = "rmsnorm",
        norm_eps: float = 1e-5,
        image_size: int = 1024,
        rope_theta: float = 1000000.0,
        image_token_id: Optional[int] = None,
    ):
        super().__init__()
        # Patch embedding: kernel == stride, so every patch is projected exactly once.
        self.patch_conv = nn.Conv2d(
            in_channels=num_channels,
            out_channels=dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )
        self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps)
        if n_kv_heads is None:
            n_kv_heads = n_heads
        layer_args = dict(
            n_layers=n_layers,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            dim=dim,
            use_qk_normalization=False,
            max_seq_len=None,
            max_batch_size=None,
            ffn_hidden_size=ffn_hidden_size,
            norm_type=norm_type,
            norm_eps=norm_eps,
            causal_mask=False,
            head_dim=None,
            insert_cross_attn=False,
            attn_type="full",
        )

        self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args)

        head_dim = dim // n_heads
        assert head_dim % 2 == 0, "ROPE requires even head_dim"

        self.dim = dim
        self.n_heads = n_heads
        self.max_patches_per_side = image_size // patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.rope_theta = rope_theta
        # Rotary table is built lazily by the ``freqs_cis`` property.
        self._freqs_cis: Optional[torch.Tensor] = None
        self.image_token_id = image_token_id

        num_params = self.get_num_params()
        log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M")

    @classmethod
    def build(
        cls,
        config: Mapping[str, Any],
    ) -> "VisionTransformer":
        """
        Create a Vision Transformer from a configuration mapping.

        Every entry of ``config`` is forwarded to the constructor as a keyword
        argument, so it must only contain names the constructor accepts.

        Args:
            config (Mapping[str, Any]): Configuration for the Vision Transformer. Must contain
                at least "dim", "num_channels", "patch_size", "n_layers", "n_heads",
                "ffn_hidden_size" and "rope_theta".

        Returns:
            VisionTransformer: Vision Transformer model instance.

        Raises:
            AssertionError: If any of the required keys is missing.
        """
        necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"]
        missing_keys = [k for k in necessary_keys if k not in config]
        assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}"
        return cls(
            **config,
        )

    def expand_in_channels(self, new_in_channels: int):
        """
        Expand the input channels of the patch convolution layer.

        This is useful when the input is non-standard, e.g. a 4-channel image with the last
        channel as the alpha channel. The existing kernel weights are copied into the leading
        channels and the additional channels are zero-initialised, so the layer's output is
        unchanged for the original channels. The replacement layer is created on the same
        device and with the same dtype as the current weights.

        Note that you should only call this method after the weight is loaded.

        Args:
            new_in_channels (int): New number of input channels; must be strictly greater
                than the current number.

        Raises:
            AssertionError: If ``new_in_channels`` is not greater than the current channel count.
        """
        old_conv = self.patch_conv
        old_in_channels = old_conv.in_channels
        assert (
            new_in_channels > old_in_channels
        ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels."
        log.debug(
            f"Vision encoder in_channels is {old_in_channels}. But you have specified to be {new_in_channels}. We will change it to {new_in_channels} channels with {new_in_channels - old_in_channels} channels of 0s."
        )
        # Build the replacement where the loaded weights already live; otherwise the
        # patch embedding would silently fall back to the default device/dtype.
        new_conv = nn.Conv2d(
            in_channels=new_in_channels,
            out_channels=old_conv.out_channels,
            kernel_size=old_conv.kernel_size,
            stride=old_conv.stride,
            bias=False,
            device=old_conv.weight.device,
            dtype=old_conv.weight.dtype,
        )
        with torch.no_grad():
            # Zero everything, then restore the pretrained kernels in the leading channels.
            new_conv.weight.zero_()
            new_conv.weight[:, :old_in_channels].copy_(old_conv.weight)
        self.patch_conv = new_conv

    @property
    def device(self) -> torch.device:
        """Get the device of the model (taken from its first parameter)."""
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        """
        Get or compute the frequency tensor for rotary position embedding.

        The tensor is built on first access for a
        ``max_patches_per_side x max_patches_per_side`` grid, cached, and moved
        to the model's device whenever the two differ.

        Returns:
            torch.Tensor: The frequency tensor for rotary position embeddings.
        """
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.dim // self.n_heads,
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.rope_theta,
            )

        if self._freqs_cis.device != self.device:
            self._freqs_cis = self._freqs_cis.to(device=self.device)

        return self._freqs_cis

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Forward pass of the Vision Transformer.

        This method processes the input image through the Vision Transformer,
        including patch embedding, position embedding, and transformer layers.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size,
                C is number of channels, and H, W are height and width.

        Returns:
            torch.Tensor: Output of the transformer stack for the (B, N, D) patch sequence,
                where N is the number of patches and D is the embedding dimension.

        Raises:
            ValueError: If the input produces more patches per side than the rotary
                embedding table covers.
        """
        patch_embeds = self.patch_conv(x)
        _, _, Hp, Wp = patch_embeds.shape

        # Fail early with a readable message instead of an out-of-range index
        # into the rotary table further down.
        limit = self.max_patches_per_side
        if Hp > limit or Wp > limit:
            raise ValueError(
                f"Input produces a {Hp}x{Wp} patch grid, but rotary embeddings only cover {limit}x{limit} patches."
            )

        # (B, D, Hp, Wp) -> (B, D, N) -> (B, N, D)
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        patch_embeds = self.ln_pre(patch_embeds)

        # Row-major slice of the table: entry k corresponds to patch (k // Wp, k % Wp),
        # matching the order produced by flattening the patch grid above.
        freqs_cis = self.freqs_cis[:Hp, :Wp].flatten(0, 1)
        rope = partial(apply_rotary_emb, freqs_cis=freqs_cis)
        out = self.transformer(patch_embeds, rope=rope)

        return out

    def get_num_params(
        self,
    ) -> int:
        """
        Return the number of parameters in the model.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params
| |
|
| |
|
class VisionTransformerBlocks(nn.Module):
    """
    Sequential stack of Transformer blocks for the Vision Transformer.

    Args:
        n_layers (int): Number of transformer layers to create.
        args (Mapping[str, Any]): Arguments handed unchanged to every
            ``TransformerBlock`` in the stack.
    """

    def __init__(
        self,
        n_layers: int,
        args: Mapping[str, Any],
    ):
        super().__init__()
        # One block per layer index; all blocks share the same argument mapping.
        blocks = [TransformerBlock(layer_id=idx, args=args) for idx in range(n_layers)]
        self.layers = torch.nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        rope: Callable,
    ) -> torch.Tensor:
        """
        Run the input through every block in order.

        Args:
            x (torch.Tensor): Input tensor fed to the first block.
            rope (Callable): Rotary position embedding function passed to each layer.

        Returns:
            torch.Tensor: Output of the last block, or ``x`` itself when the stack is empty.
        """
        hidden = x
        for block in self.layers:
            # No positional indices or attention mask are supplied.
            hidden = block(hidden, input_pos=None, mask=None, rope=rope)
        return hidden
| |
|