import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from .gemma import GemmaConfig, Gemma, KVCache
from .siglip import SigLIPConfig, SigLIPVisionTower
from typing import Optional
import os
import json
from pathlib import Path
from safetensors import safe_open


@dataclass
class PaliGemmaConfig:
    bos_token_id: int = 2
    eos_token_id: int = 1
    hidden_size: int = 2048
    ignore_index: int = -100
    image_token_index: int = 257152
    pad_token_id: int = 0
    projection_dim: int = 2048
    text_config: Optional[GemmaConfig] = None
    vision_config: Optional[SigLIPConfig] = None
    vocab_size: int = 257216

    @classmethod
    def from_dict(cls, data):
        # Mirrors the layout of the checkpoint's config.json and delegates the
        # nested sub-configs to GemmaConfig / SigLIPConfig.
        return cls(
            bos_token_id=data['bos_token_id'],
            eos_token_id=data['eos_token_id'],
            hidden_size=data['hidden_size'],
            ignore_index=data['ignore_index'],
            image_token_index=data['image_token_index'],
            pad_token_id=data['pad_token_id'],
            projection_dim=data['projection_dim'],
            text_config=GemmaConfig.from_dict(data['text_config']),
            vision_config=SigLIPConfig.from_dict(data['vision_config'])
        )
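
# Example (sketch): the config is normally built from a Hugging Face-style
# config.json, exactly as PaliGemma.from_pretrained does below; the path here
# is hypothetical.
#
#   with open('paligemma-3b-pt-224/config.json') as f:
#       cfg = PaliGemmaConfig.from_dict(json.loads(f.read()))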


class PaliGemmaMultimodalProjector(nn.Module):
    # Projects SigLIP patch embeddings into the language model's embedding space.
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.linear = nn.Linear(cfg.vision_config.hidden_size, cfg.vision_config.projection_dim)

    def forward(self, x: torch.Tensor):
        # (n, num_patches, vision_hidden_size) -> (n, num_patches, projection_dim)
        x = self.linear(x)
        return x


class PaliGemma(nn.Module):
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.cfg = cfg
        self.language_model = Gemma(cfg.text_config)
        self.vision_tower = SigLIPVisionTower(cfg.vision_config)
        self.multi_modal_projector = PaliGemmaMultimodalProjector(cfg)

    def tie_weights(self):
        # Weight tying between the input embeddings and the output head is
        # handled inside the language model.
        self.language_model.tie_weights()

    def _merge_img_embeds_and_input_embeds(self, img_embeds: torch.Tensor,
                                           input_embeds: torch.Tensor,
                                           input_tokens: torch.Tensor):
        batch_size, seq_len, embed_dim = input_embeds.shape
        # Scale the image embeddings by 1/sqrt(hidden_size) before merging them in.
        scaled_img = img_embeds / (self.cfg.hidden_size ** 0.5)
        final_embeddings = torch.zeros((batch_size, seq_len, embed_dim), dtype=img_embeds.dtype, device=img_embeds.device)
        # (n, seq_len)
        text_mask = (input_tokens != self.cfg.pad_token_id) & (input_tokens != self.cfg.image_token_index)
        img_mask = input_tokens == self.cfg.image_token_index
        pad_mask = input_tokens == self.cfg.pad_token_id
        # (n, seq_len) -> (n, seq_len, embed_dim)
        text_mask = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        img_mask = img_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        # Text positions keep their token embeddings, <image> placeholder positions
        # are overwritten with the projected patch embeddings, padding stays zero.
        final_embeddings = torch.where(text_mask, input_embeds, final_embeddings)
        final_embeddings = final_embeddings.masked_scatter(img_mask, scaled_img)
        final_embeddings = torch.where(pad_mask, torch.zeros_like(final_embeddings), final_embeddings)
        return final_embeddings
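
    # A minimal, self-contained sketch (not used by the model) of the masking trick
    # above: token id 7 stands in for image_token_index and 0 for pad_token_id.
    @staticmethod
    def _merge_demo():
        tokens = torch.tensor([[7, 7, 4, 5, 0]])                           # 2 image slots, 2 text tokens, 1 pad
        text_embeds = torch.arange(15, dtype=torch.float32).view(1, 5, 3)  # pretend token embeddings
        img_embeds = torch.full((1, 2, 3), -1.0)                           # pretend projected patch embeddings
        out = torch.zeros(1, 5, 3)
        text_mask = ((tokens != 0) & (tokens != 7)).unsqueeze(-1).expand(-1, -1, 3)
        img_mask = (tokens == 7).unsqueeze(-1).expand(-1, -1, 3)
        out = torch.where(text_mask, text_embeds, out)
        out = out.masked_scatter(img_mask, img_embeds)
        # Rows 0-1 are now -1 (image), rows 2-3 keep their text embeddings, row 4 stays 0 (pad).
        return out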

    def _create_position_ids_and_attention_mask(self,
                                                device: str = '',
                                                dtype: torch.dtype = torch.float32,
                                                batch_size: int = 32,
                                                seq_len: int = 1,
                                                attention_mask: Optional[torch.Tensor] = None,
                                                kv_cache: Optional[KVCache] = None):
        # Create the attention mask. The image tokens and the prompt prefix attend to
        # each other without causal masking, so the additive mask is all zeros.
        if kv_cache is None or kv_cache.num_items() == 0:
            # Prefill: one query per prompt position.
            causal_mask = torch.full((batch_size, seq_len, seq_len), 0, dtype=dtype, device=device)
            # Positions count the attended tokens; padded positions are set to 1.
            position_ids = attention_mask.cumsum(dim=-1).masked_fill_((attention_mask == 0), 1).to(device)
        else:
            # Decode: a single new query token attends to everything in the KV cache.
            assert seq_len == 1
            kv_len = kv_cache.num_items() + 1
            causal_mask = torch.full((batch_size, 1, kv_len), 0, dtype=dtype, device=device)
            position_ids = attention_mask.cumsum(dim=-1)[:, -1].to(device)
        # (n, seq_len, kv_len) -> (n, 1, seq_len, kv_len): add the head dimension.
        causal_mask = causal_mask.unsqueeze(1)
        return position_ids, causal_mask
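
    # Concrete shapes (sketch, assuming the 224px checkpoint with 256 image tokens,
    # a 10-token text prompt, and an all-ones attention mask that the caller extends
    # by one entry per generated token):
    #   prefill: causal_mask (n, 1, 266, 266), all zeros; position_ids 1..266
    #   decode:  causal_mask (n, 1, 1, kv_len), all zeros; position_ids = kv_len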

    @staticmethod
    def from_pretrained(model_dir):
        # Build the config from the checkpoint's config.json ...
        with open(os.path.join(model_dir, 'config.json'), "r") as f:
            model_config = json.loads(f.read())
        config = PaliGemmaConfig.from_dict(model_config)
        # ... then collect the weights from every *.safetensors shard.
        safetensor_files = Path(model_dir).glob("*.safetensors")
        weights = {}
        for file in safetensor_files:
            with safe_open(file, framework='pt', device="cpu") as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)
        model = PaliGemma(config)
        model.load_state_dict(weights, strict=False)
        model.tie_weights()
        return model
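
    # Example usage (sketch; the checkpoint directory is hypothetical and must
    # contain config.json plus the *.safetensors shards):
    #
    #   model = PaliGemma.from_pretrained('weights/paligemma-3b-pt-224')
    #   model = model.to('cuda').eval()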

    def forward(self, *args, **kwargs):
        kv_cache = kwargs.get('kv_cache')
        input_tokens = kwargs['input_ids']
        pixel_values = kwargs.get('pixel_values')
        attention_mask = kwargs['attention_mask']
        # input_tokens: (n, seq_len) -> input_embeds: (n, seq_len, embed_dim)
        input_embeds = self.language_model.model.embed_tokens(input_tokens)
        if pixel_values is not None:
            # (n, 3, H, W) -> (n, num_patches, vision_hidden_size) -> (n, num_patches, projection_dim)
            img_embeds = self.vision_tower(pixel_values.to(input_embeds.dtype))
            img_embeds = self.multi_modal_projector(img_embeds)
            # Splice the projected patch embeddings into the <image> placeholder positions.
            final_embeddings = self._merge_img_embeds_and_input_embeds(img_embeds=img_embeds,
                                                                       input_embeds=input_embeds,
                                                                       input_tokens=input_tokens)
        else:
            final_embeddings = input_embeds
        position_ids, causal_mask = self._create_position_ids_and_attention_mask(device=final_embeddings.device.type,
                                                                                 dtype=final_embeddings.dtype,
                                                                                 batch_size=final_embeddings.shape[0],
                                                                                 seq_len=final_embeddings.shape[1],
                                                                                 attention_mask=attention_mask,
                                                                                 kv_cache=kv_cache)
        outputs, kv_cache = self.language_model(
            input_embeds=final_embeddings,
            position_ids=position_ids,
            attention_mask=causal_mask,
            kv_cache=kv_cache
        )
        return outputs, kv_cache
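
# One prefill step (sketch): the input tensors are assumed to come from a
# PaliGemma processor/tokenizer that is not part of this file, KVCache() is
# assumed to take no constructor arguments, and the language model is assumed
# to return logits of shape (n, seq_len, vocab_size).
#
#   out, kv_cache = model(
#       input_ids=input_ids,            # (1, 256 image tokens + prompt tokens)
#       pixel_values=pixel_values,      # (1, 3, 224, 224)
#       attention_mask=torch.ones_like(input_ids),
#       kv_cache=KVCache(),
#   )
#   next_token = out[:, -1, :].argmax(dim=-1)   # greedy pick of the next token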