import json
import os
import tempfile
from dataclasses import asdict
from typing import Optional

from models.vision_transformer import ViT
from models.language_model import LanguageModel
from models.modality_projector import ModalityProjector
from models.config import VLMConfig

import torch
import torch.nn as nn
import torch.nn.functional as F

from safetensors.torch import load_model, save_model

class VisionLanguageModel(nn.Module):
    def __init__(self, cfg: VLMConfig, load_backbone=True):
        super().__init__()
        self.cfg = cfg

        if load_backbone:
            print("Loading from backbone weights")
            self.vision_encoder = ViT.from_pretrained(cfg)
            self.decoder = LanguageModel.from_pretrained(cfg)
        else:
            self.vision_encoder = ViT(cfg)
            self.decoder = LanguageModel(cfg)
        self.MP = ModalityProjector(cfg)
        self.load_backbone = load_backbone

    def forward(self, input_ids, image, attention_mask=None, targets=None):
        image_embd = self.vision_encoder(image)
        image_embd = self.MP(image_embd)

        token_embd = self.decoder.token_embedding(input_ids)

        # Prepend the projected image embeddings to the token embeddings
        combined_embd = torch.cat((image_embd, token_embd), dim=1)

        # Adjust attention mask to account for image tokens
        if attention_mask is not None:
            # Create mask of 1s for image tokens (all image tokens should be attended to)
            batch_size = image_embd.size(0)
            img_seq_len = image_embd.size(1)
            image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)

            # Combine image and token attention masks
            attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)

        # These are hidden states rather than logits; the head is applied only when computing the loss
        logits = self.decoder(combined_embd, attention_mask)

        loss = None
        if targets is not None:
            # Apply the LM head, then keep only the token positions (drop the image prefix) for the loss
            logits = self.decoder.head(logits)
            logits = logits[:, image_embd.size(1):, :]
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-100)

        return logits, loss
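
    # Shape walk-through (an illustrative note, assuming batch size B and text length T):
    # the vision encoder plus modality projector produce (B, T_img, hidden), so the decoder
    # sees a sequence of length T_img + T. With targets given, the head is applied and the
    # image positions are sliced off, so the returned logits have shape (B, T, vocab_size)
    # alongside a scalar loss; without targets, the raw decoder output over the full
    # combined sequence is returned instead.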

    @torch.no_grad()
    def generate(self, input_ids, image, attention_mask=None, max_new_tokens=5):
        # Process image through vision encoder and projection
        image_embd = self.vision_encoder(image)
        image_embd = self.MP(image_embd)

        # Embed initial tokens
        token_embd = self.decoder.token_embedding(input_ids)

        # Concatenate image embeddings with token embeddings
        combined_embd = torch.cat((image_embd, token_embd), dim=1)

        batch_size = image_embd.size(0)
        img_seq_len = image_embd.size(1)

        # Adjust attention mask to account for image tokens
        if attention_mask is not None:
            # Create mask of 1s for image tokens (all image tokens should be attended to)
            image_attention_mask = torch.ones((batch_size, img_seq_len), device=attention_mask.device, dtype=attention_mask.dtype)
            attention_mask = torch.cat((image_attention_mask, attention_mask), dim=1)

        # Generate from the combined embeddings using the decoder's forward pass rather than
        # its generate method, so the image prefix stays part of the sequence at every step
        outputs = combined_embd
        generated_tokens = torch.zeros((batch_size, max_new_tokens), device=input_ids.device, dtype=input_ids.dtype)

        # Note: improvements such as KV caching could be implemented here
        for i in range(max_new_tokens):
            model_out = self.decoder(outputs, attention_mask)

            # Take the output at the last position (this is an embedding unless the decoder operates directly on tokens)
            last_token_logits = model_out[:, -1, :]

            # Apply the LM head to get logits (if the model is in embedding mode)
            if not self.decoder.lm_use_tokens:
                last_token_logits = self.decoder.head(last_token_logits)

            probs = torch.softmax(last_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(-1)

            # Convert the sampled token to an embedding and append it to the sequence
            next_embd = self.decoder.token_embedding(next_token)
            outputs = torch.cat((outputs, next_embd), dim=1)

            if attention_mask is not None:
                attention_mask = torch.cat((attention_mask, torch.ones((batch_size, 1), device=attention_mask.device, dtype=attention_mask.dtype)), dim=1)

        return generated_tokens
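
    # Example call (an illustrative sketch; the tokenizer and image preprocessing are
    # assumed to be set up elsewhere, e.g. with the repo's data processors):
    #   ids = tokenizer("Describe the image:", return_tensors="pt")["input_ids"]
    #   tokens = model.generate(ids, image, max_new_tokens=20)
    #   text = tokenizer.batch_decode(tokens, skip_special_tokens=True)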

    @classmethod
    def from_pretrained(
        cls, repo_id_or_path: str, *, revision: Optional[str] = None
    ) -> "VisionLanguageModel":
        """
        Load a VisionLanguageModel from a local directory or a repo on the Hugging Face Hub.

        Args:
            repo_id_or_path (str): The path to the local directory or the Hugging Face Hub repo ID.
            revision (str, optional): The revision (branch, tag, or commit) to download from the Hub.

        Returns:
            VisionLanguageModel: The loaded model.
        """
        # If a local folder exists => load from there
        if os.path.exists(repo_id_or_path):
            config_path = os.path.join(repo_id_or_path, "config.json")
            weights_path = os.path.join(repo_id_or_path, "model.safetensors")

            if not os.path.exists(config_path):
                raise ValueError(
                    f"Config file not found at {config_path}. Please provide a valid path."
                )
            if not os.path.exists(weights_path):
                raise ValueError(
                    f"Weights file not found at {weights_path}. Please provide a valid path."
                )
        # Otherwise, assume it's a Hugging Face Hub repo
        else:
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="config.json", revision=revision
            )
            weights_path = hf_hub_download(
                repo_id=repo_id_or_path, filename="model.safetensors", revision=revision
            )

        # Load config
        with open(config_path, "r") as f:
            cfg = VLMConfig(**json.load(f))

        # Initialize the model without loading the backbone weights
        model = cls(cfg, load_backbone=False)

        # Load the safetensors weights
        load_model(model, weights_path)

        # Done!
        return model
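
    # Usage sketch (illustrative; the local path is hypothetical). Both forms resolve to a
    # config.json plus model.safetensors pair, either on disk or via hf_hub_download:
    #   model = VisionLanguageModel.from_pretrained("lusxvr/nanoVLM-222M")
    #   model = VisionLanguageModel.from_pretrained("./checkpoints/nanoVLM-222M")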

    def save_pretrained(self, save_directory: str) -> None:
        """
        Save the model and configuration to a directory.

        Args:
            save_directory (str): The directory to save the model and config.
        """
        # Create the directory if it doesn't exist
        os.makedirs(save_directory, exist_ok=True)

        # Save config
        with open(os.path.join(save_directory, "config.json"), "w") as f:
            f.write(json.dumps(asdict(self.cfg), indent=4))

        # Save weights as safetensors
        save_model(self, os.path.join(save_directory, "model.safetensors"))
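
    # Round-trip sketch (illustrative; the directory name is hypothetical):
    #   model.save_pretrained("./my-nanovlm")   # writes config.json + model.safetensors
    #   reloaded = VisionLanguageModel.from_pretrained("./my-nanovlm")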

    def push_to_hub(self, repo_id: str, private: bool = False) -> None:
        """
        Push the model and configuration to the Hugging Face Hub.

        Args:
            repo_id (str): The repo ID on the Hugging Face Hub.
            private (bool): Whether to create the repo as private. Defaults to False.
        """
        from huggingface_hub import create_repo, upload_folder

        # Create the repo (or reuse it if it already exists)
        repo_url = create_repo(repo_id=repo_id, private=private, exist_ok=True)
        repo_id = repo_url.repo_id
        print("Created repo: ", repo_url)

        with tempfile.TemporaryDirectory() as save_path:
            # Save model and config to a temporary directory
            self.save_pretrained(save_path)

            # Save the model card
            with open(os.path.join(save_path, "README.md"), "w") as f:
                f.write(MODEL_CARD_TEMPLATE.format(repo_id=repo_id))

            # Upload the folder contents to the Hub
            return upload_folder(
                repo_id=repo_id,
                repo_type="model",
                folder_path=save_path,
                commit_message="Upload nanoVLM using push_to_hub",
            )
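
    # Example (illustrative; the repo ID is hypothetical):
    #   model.push_to_hub("your-username/nanoVLM-222M-ft")
    # uploads config.json, model.safetensors and the generated README.md in a single commit.
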
MODEL_CARD_TEMPLATE = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
library_name: nanovlm
license: mit
pipeline_tag: image-text-to-text
tags:
- vision-language
- multimodal
- research
---

**nanoVLM** is a minimal and lightweight Vision-Language Model (VLM) designed for efficient training and experimentation. Built in pure PyTorch, its entire model architecture and training logic fit within ~750 lines of code. It combines a ViT-based image encoder (SigLIP-B/16-224-85M) with a lightweight causal language model (SmolLM2-135M), resulting in a compact 222M-parameter model.

For more information, check out the base model at https://huggingface.co/lusxvr/nanoVLM-222M.

**Usage:**

Clone the nanoVLM repository: https://github.com/huggingface/nanoVLM. Follow the install instructions and run the following code:
```python
from models.vision_language_model import VisionLanguageModel
model = VisionLanguageModel.from_pretrained("{repo_id}")
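# A minimal generation sketch (illustrative). The helper names and config fields below
# (get_tokenizer, get_image_processor, cfg.lm_tokenizer, cfg.vit_img_size) are assumptions
# about the repository layout, not guaranteed by this model card:
# from PIL import Image
# from data.processors import get_tokenizer, get_image_processor
# tokenizer = get_tokenizer(model.cfg.lm_tokenizer)
# image_processor = get_image_processor(model.cfg.vit_img_size)
# ids = tokenizer("What is in this image?", return_tensors="pt")["input_ids"]
# img = image_processor(Image.open("example.jpg")).unsqueeze(0)
# out = model.generate(ids, img, max_new_tokens=20)
# print(tokenizer.batch_decode(out, skip_special_tokens=True))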
```
"""