import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from .gemma import GemmaConfig, Gemma, KVCache
from .siglip import SigLIPConfig, SigLIPVisionTower
from typing import Optional
import os
import json
from pathlib import Path
from safetensors import safe_open
@dataclass
class PaliGemmaConfig:
    bos_token_id: int = 2
    eos_token_id: int = 1
    hidden_size: int = 2048
    ignore_index: int = -100
    image_token_index: int = 257152
    pad_token_id: int = 0
    projection_dim: int = 2048
    text_config: Optional[GemmaConfig] = None
    vision_config: Optional[SigLIPConfig] = None
    vocab_size: int = 257216

    @classmethod
    def from_dict(cls, data):
        # Fall back to the dataclass defaults for keys a checkpoint's
        # config.json may omit.
        return cls(
            bos_token_id=data.get('bos_token_id', 2),
            eos_token_id=data.get('eos_token_id', 1),
            hidden_size=data.get('hidden_size', 2048),
            ignore_index=data.get('ignore_index', -100),
            image_token_index=data.get('image_token_index', 257152),
            pad_token_id=data.get('pad_token_id', 0),
            projection_dim=data.get('projection_dim', 2048),
            vocab_size=data.get('vocab_size', 257216),
            text_config=GemmaConfig.from_dict(data['text_config']),
            vision_config=SigLIPConfig.from_dict(data['vision_config']),
        )
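
# Example (hypothetical snippet, not part of the model code): rebuilding the
# config from a checkpoint's config.json, assuming GemmaConfig and
# SigLIPConfig expose from_dict as used above.
#
#   with open('paligemma-weights/config.json') as f:
#       cfg = PaliGemmaConfig.from_dict(json.load(f))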
class PaliGemmaMultimodalProjector(nn.Module):
    """Projects vision-tower features into the language model's embedding space."""

    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.linear = nn.Linear(cfg.vision_config.hidden_size, cfg.vision_config.projection_dim)

    def forward(self, x: torch.Tensor):
        # (n, num_patches, vision_hidden_size) -> (n, num_patches, projection_dim)
        return self.linear(x)
class PaliGemma(nn.Module):
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.cfg = cfg
        self.language_model = Gemma(cfg.text_config)
        self.vision_tower = SigLIPVisionTower(cfg.vision_config)
        self.multi_modal_projector = PaliGemmaMultimodalProjector(cfg)

    def tie_weights(self):
        # Share the token embedding matrix with the LM output head.
        self.language_model.tie_weights()
    def _merge_img_embeds_and_input_embeds(self, img_embeds: torch.Tensor,
                                           input_embeds: torch.Tensor,
                                           input_tokens: torch.Tensor):
        batch_size, seq_len, embed_dim = input_embeds.shape
        # Scale image embeddings the same way Gemma scales its token embeddings.
        scaled_img = img_embeds / (self.cfg.hidden_size ** 0.5)
        final_embeddings = torch.zeros((batch_size, seq_len, embed_dim),
                                       dtype=img_embeds.dtype, device=img_embeds.device)
        # Boolean masks over the token grid: (n, seq_len)
        text_mask = (input_tokens != self.cfg.pad_token_id) & (input_tokens != self.cfg.image_token_index)
        img_mask = input_tokens == self.cfg.image_token_index
        pad_mask = input_tokens == self.cfg.pad_token_id
        # Broadcast each mask to (n, seq_len, embed_dim)
        text_mask = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        img_mask = img_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        # Copy text embeddings where text tokens sit, scatter the image
        # embeddings into the <image> placeholder slots, and zero out padding.
        final_embeddings = torch.where(text_mask, input_embeds, final_embeddings)
        final_embeddings = final_embeddings.masked_scatter(img_mask, scaled_img)
        final_embeddings = torch.where(pad_mask, torch.zeros_like(final_embeddings), final_embeddings)
        return final_embeddings
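
    # Toy illustration (hypothetical numbers, not part of the model): with
    # image_token_index = 9 and input_tokens = [[9, 9, 5, 6, PAD]], the two
    # scaled image embeddings fill slots 0-1 via masked_scatter (in
    # placeholder order), the embeddings of text tokens 5 and 6 fill
    # slots 2-3, and the PAD slot stays all zeros.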
    def _create_position_ids_and_attention_mask(self,
                                                device: str = 'cpu',
                                                dtype: torch.dtype = torch.float32,
                                                batch_size: int = 32,
                                                seq_len: int = 1,
                                                attention_mask: Optional[torch.Tensor] = None,
                                                kv_cache: Optional[KVCache] = None):
        if kv_cache is None or kv_cache.num_items() == 0:
            # Prefill: PaliGemma attends fully across the whole prefix
            # (image + prompt tokens), so the additive mask is all zeros.
            causal_mask = torch.full((batch_size, seq_len, seq_len), 0, dtype=dtype, device=device)
            # 1-based positions over non-pad tokens; pad positions are set to 1.
            position_ids = attention_mask.cumsum(dim=-1).masked_fill_((attention_mask == 0), 1).to(device)
        else:
            # Decode: a single new query attends to every cached key/value.
            assert seq_len == 1
            kv_len = kv_cache.num_items() + 1
            causal_mask = torch.full((batch_size, 1, kv_len), 0, dtype=dtype, device=device)
            position_ids = attention_mask.cumsum(dim=-1)[:, -1].to(device)
        # (n, seq_len, kv_len) -> (n, 1, seq_len, kv_len): add the head axis.
        causal_mask = causal_mask.unsqueeze(1)
        return position_ids, causal_mask
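
    # Worked example (hypothetical batch): attention_mask = [[1, 1, 1, 0]]
    # gives cumsum [[1, 2, 3, 3]]; the pad slot (mask == 0) is overwritten
    # with 1, so prefill position_ids = [[1, 2, 3, 1]]. If a generation loop
    # then extends the mask to [[1, 1, 1, 0, 1]], the decode branch takes the
    # last cumsum value, so the new token gets position 4.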
    @staticmethod
    def from_pretrained(model_dir):
        with open(os.path.join(model_dir, 'config.json'), 'r') as f:
            config = PaliGemmaConfig.from_dict(json.load(f))
        # Gather weights from every safetensors shard in the directory.
        weights = {}
        for file in Path(model_dir).glob('*.safetensors'):
            with safe_open(file, framework='pt', device='cpu') as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)
        model = PaliGemma(config)
        # strict=False: tied weights (e.g. the LM head) are absent from the checkpoint.
        model.load_state_dict(weights, strict=False)
        model.tie_weights()
        return model
    def forward(self, *args, **kwargs):
        kv_cache = kwargs.get('kv_cache')
        input_tokens = kwargs['input_ids']          # (n, seq_len)
        pixel_values = kwargs.get('pixel_values')   # (n, 3, h, w) or None
        attention_mask = kwargs['attention_mask']   # (n, seq_len)
        # (n, seq_len) -> (n, seq_len, embed_dim)
        input_embeds = self.language_model.model.embed_tokens(input_tokens)
        if pixel_values is not None:
            # Encode the image, project it into the LM embedding space, then
            # splice it into the <image> placeholder positions.
            img_embeds = self.vision_tower(pixel_values.to(input_embeds.dtype))
            img_embeds = self.multi_modal_projector(img_embeds)
            final_embeddings = self._merge_img_embeds_and_input_embeds(img_embeds=img_embeds,
                                                                       input_embeds=input_embeds,
                                                                       input_tokens=input_tokens)
        else:
            final_embeddings = input_embeds
        position_ids, causal_mask = self._create_position_ids_and_attention_mask(device=final_embeddings.device.type,
                                                                                 dtype=final_embeddings.dtype,
                                                                                 batch_size=final_embeddings.shape[0],
                                                                                 seq_len=final_embeddings.shape[1],
                                                                                 attention_mask=attention_mask,
                                                                                 kv_cache=kv_cache)
        outputs, kv_cache = self.language_model(
            input_embeds=final_embeddings,
            position_ids=position_ids,
            attention_mask=causal_mask,
            kv_cache=kv_cache
        )
        return outputs, kv_cache
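
# Minimal usage sketch (assumptions flagged inline): 'paligemma-weights' is a
# hypothetical local checkpoint directory holding config.json plus
# *.safetensors shards; the dummy ids and pixels below stand in for a real
# tokenizer and image processor, which this module does not define; the
# vision config is assumed to expose image_size and patch_size.
if __name__ == '__main__':
    model = PaliGemma.from_pretrained('paligemma-weights')  # hypothetical path
    model.eval()
    cfg = model.cfg
    num_image_tokens = (cfg.vision_config.image_size // cfg.vision_config.patch_size) ** 2
    # Prompt layout: <image> placeholders, then BOS, then text token ids
    # (108 is an arbitrary stand-in for a tokenized prompt).
    input_ids = torch.tensor([[cfg.image_token_index] * num_image_tokens
                              + [cfg.bos_token_id, 108]])
    attention_mask = torch.ones_like(input_ids)
    pixel_values = torch.zeros(1, 3, cfg.vision_config.image_size, cfg.vision_config.image_size)
    with torch.no_grad():
        outputs, kv_cache = model(input_ids=input_ids,
                                  pixel_values=pixel_values,
                                  attention_mask=attention_mask,
                                  kv_cache=None)
    # Assuming the language model returns vocab logits, greedily pick the
    # next token from the last position.
    print('next token id:', outputs[:, -1, :].argmax(dim=-1).item())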