# llama_custom_vision.py
import torch
import torch.nn as nn
import math
from transformers import (
LlamaPreTrainedModel,
LlamaConfig,
LlamaTokenizer,
ViTModel,
ViTImageProcessor,
AutoTokenizer,
LlamaForCausalLM,
)
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaAttention,
    LlamaMLP,
    LlamaRMSNorm,
)
from typing import Optional, Tuple
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from IPython.display import clear_output, display
import torch.optim as optim
from torch.amp import autocast, GradScaler
from transformers import get_linear_schedule_with_warmup
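# This module assembles an image-captioning stack: a ViT encoder produces patch features,
# ImageProjection maps them into the LLaMA hidden size, and a custom LLaMA decoder whose
# layers carry an extra cross-attention block attends to those features while generating text.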
# Custom Cross-Attention Module
class CustomCrossAttention(LlamaAttention):
    """Cross-attention over encoder (image) features, reusing LlamaAttention's configuration."""
    def __init__(self, config, layer_idx, num_key_value_heads=None):
        super().__init__(config, layer_idx=layer_idx)
        self.layer_idx = layer_idx
        self.num_key_value_heads = num_key_value_heads or config.num_attention_heads
        # Define head bookkeeping explicitly so this class does not depend on parent
        # attributes that vary across transformers versions.
        self.num_heads = config.num_attention_heads
        self.head_dim = (
            getattr(config, "head_dim", None)
            or config.hidden_size // config.num_attention_heads
        )
        # Re-create the projections; keys and values are computed from the encoder states.
        self.k_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            config.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
        )
        self.q_proj = nn.Linear(
            config.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, config.hidden_size, bias=False
        )
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=False,
        use_cache=False,
    ):
        if encoder_hidden_states is None:
            raise ValueError(
                f"Cross-attention layer {self.layer_idx} requires encoder_hidden_states"
            )
        bsz, tgt_len, _ = hidden_states.size()
        src_len = encoder_hidden_states.size(1)
        # Project hidden_states to query states
        query_states = self.q_proj(hidden_states).view(
            bsz, tgt_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        if past_key_value is not None and past_key_value[0] is not None:
            # The encoder states do not change between decoding steps, so the cached
            # key/value projections can be reused as-is instead of being re-computed.
            key_states, value_states = past_key_value
        else:
            # Project encoder_hidden_states to key and value states
            key_states = self.k_proj(encoder_hidden_states).view(
                bsz, src_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            value_states = self.v_proj(encoder_hidden_states).view(
                bsz, src_len, self.num_key_value_heads, self.head_dim
            ).transpose(1, 2)
            # Expand key/value states to match num_heads (grouped-query attention)
            if self.num_key_value_heads != self.num_heads:
                key_states = key_states.repeat_interleave(
                    self.num_heads // self.num_key_value_heads, dim=1
                )
                value_states = value_states.repeat_interleave(
                    self.num_heads // self.num_key_value_heads, dim=1
                )
        next_past_key_value = (key_states, value_states) if use_cache else None
        # Scaled dot-product attention over the encoder states
        attn_weights = torch.matmul(
            query_states, key_states.transpose(-1, -2)
        ) / math.sqrt(self.head_dim)
        if encoder_attention_mask is not None:
            attn_weights = attn_weights + encoder_attention_mask
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = (
            attn_output.transpose(1, 2).contiguous().view(bsz, tgt_len, -1)
        )
        attn_output = self.o_proj(attn_output)
        # Always return a fixed-size tuple so callers can index position 1 (attention
        # weights) and position 2 (cache) without length checks.
        return (
            attn_output,
            attn_weights if output_attentions else None,
            next_past_key_value,
        )
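# Shape sketch (comment only; the small LlamaConfig below is purely illustrative):
#   cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4,
#                     num_hidden_layers=2, vocab_size=128)
#   xattn = CustomCrossAttention(cfg, layer_idx=0)
#   out, weights, cache = xattn(torch.randn(1, 5, 64),
#                               encoder_hidden_states=torch.randn(1, 197, 64), use_cache=True)
#   # out: (1, 5, 64); cache holds encoder keys/values of shape (1, num_heads, 197, head_dim)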
# Custom LLaMA Decoder Layer
class CustomLlamaDecoderLayer(nn.Module):
    """LLaMA decoder layer extended with a cross-attention block over image features."""
    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.self_attn = LlamaAttention(config, layer_idx=layer_idx)
        # Instantiate the MLP directly instead of building a throwaway decoder layer for it.
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn = CustomCrossAttention(config, layer_idx=layer_idx)
        self.cross_attn_layer_norm = LlamaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
def forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
use_cache=False,
):
if encoder_hidden_states is None:
raise ValueError(f"Cross-attention layer {self.layer_idx} requires encoder_hidden_states")
#print(f"Layer {self.layer_idx}: Received encoder_hidden_states with shape {encoder_hidden_states.shape}")
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
        # Self-Attention. This relies on a transformers version in which LlamaAttention
        # applies rotary position embeddings internally from position_ids.
        self_attn_past_key_value = (
            past_key_value[:2] if past_key_value is not None else None
        )
        self_attn_outputs = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=self_attn_past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
hidden_states = residual + self_attn_outputs[0]
# Cross-Attention
residual = hidden_states
hidden_states = self.cross_attn_layer_norm(hidden_states)
cross_attn_past_key_value = (
past_key_value[2:] if past_key_value is not None else None
)
cross_attn_outputs = self.cross_attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
position_ids=position_ids,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + cross_attn_outputs[0]
# Feed Forward Network
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
        present_key_value = None
        if use_cache:
            # The layer cache is a 4-tuple (self_k, self_v, cross_k, cross_v); fall back to
            # None placeholders if either sub-module did not return a cache.
            self_pkv = self_attn_outputs[2] if len(self_attn_outputs) > 2 else None
            cross_pkv = cross_attn_outputs[2] if len(cross_attn_outputs) > 2 else None
            if self_pkv is not None and cross_pkv is not None:
                present_key_value = self_pkv + cross_pkv
            elif self_pkv is not None:
                present_key_value = self_pkv + (None, None)
            elif cross_pkv is not None:
                present_key_value = (None, None) + cross_pkv
outputs = (hidden_states,)
if use_cache:
outputs += (present_key_value,)
else:
outputs += (None,)
if output_attentions:
attn_weights = {}
attn_weights["self_attn"] = self_attn_outputs[1]
attn_weights["cross_attn"] = cross_attn_outputs[1]
outputs += (attn_weights,)
return outputs
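# The per-layer cache returned with use_cache=True is a 4-tuple (self_k, self_v, cross_k, cross_v);
# the layer slices it back into its two attention blocks via past_key_value[:2] and past_key_value[2:].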
# Custom LLaMA Model
class CustomLlamaModel(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
)
self.layers = nn.ModuleList(
[
CustomLlamaDecoderLayer(config, layer_idx=i)
for i in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.post_init()
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, hidden_states, past_key_values_length
    ):
        # Build the additive decoder mask:
        # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len],
        # with 0 where attention is allowed and a large negative value where it is not.
        device = hidden_states.device
        dtype = hidden_states.dtype
        tgt_len = input_shape[-1]
        # Causal mask
        causal_mask = self._make_causal_mask(
            input_shape, dtype, device, past_key_values_length
        )
        if attention_mask is not None:
            expanded_attn_mask = self._expand_mask(
                attention_mask, dtype, tgt_len=tgt_len
            )
            combined_attention_mask = expanded_attn_mask + causal_mask
        else:
            combined_attention_mask = causal_mask
        return combined_attention_mask
    def _expand_mask(self, mask, dtype, tgt_len=None):
        """
        Expands a padding mask from `[batch_size, seq_len]` to an additive mask of shape
        `[batch_size, 1, tgt_seq_len, src_seq_len]` (0 for real tokens, dtype-min for padding).
        """
        batch_size, src_len = mask.size()
        tgt_len = tgt_len if tgt_len is not None else src_len
        expanded_mask = mask[:, None, None, :].expand(batch_size, 1, tgt_len, src_len).to(dtype)
        inverted_mask = 1.0 - expanded_mask
        return inverted_mask.masked_fill(
            inverted_mask.to(torch.bool), torch.finfo(dtype).min
        )
    def _make_causal_mask(self, input_shape, dtype, device, past_key_values_length):
        batch_size, tgt_len = input_shape
        total_len = tgt_len + past_key_values_length
        # Lower-triangular pattern: each new token may attend to itself, earlier new tokens,
        # and all cached positions.
        allowed = torch.tril(torch.ones((total_len, total_len), device=device)).bool()
        if past_key_values_length > 0:
            allowed = allowed[past_key_values_length:, :]
        # Convert to an additive mask: 0 where allowed, dtype-min where disallowed.
        mask = torch.zeros((tgt_len, total_len), dtype=dtype, device=device)
        mask = mask.masked_fill(~allowed, torch.finfo(dtype).min)
        return mask[None, None, :, :]
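    # For example, with tgt_len = 3 and no cache, _make_causal_mask yields (additively):
    #   [[0, m, m],
    #    [0, 0, m],
    #    [0, 0, 0]]   where m = torch.finfo(dtype).min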
def forward(
self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
use_cache=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=True,
):
if input_ids is not None:
batch_size, seq_length = input_ids.size()
else:
raise ValueError("You have to specify input_ids")
if past_key_values is None:
past_key_values = [None] * len(self.layers)
past_key_values_length = 0
else:
# Check if past_key_values[0] and past_key_values[0][0] are not None
if past_key_values[0] is not None and past_key_values[0][0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
else:
past_key_values_length = 0
# Generate position_ids if None
if position_ids is None:
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length,
dtype=torch.long, device=input_ids.device
)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
hidden_states = self.embed_tokens(input_ids)
        # Prepare attention mask (it must cover cached positions as well as the new tokens)
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length + past_key_values_length), device=input_ids.device
            )
        combined_attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
        )
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=combined_attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]
            if use_cache:
                past_key_values[idx] = layer_outputs[1]
            if output_attentions:
                all_self_attns += (layer_outputs[2],)
hidden_states = self.norm(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
        if not return_dict:
            outputs = (hidden_states,)
            if use_cache:
                outputs += (past_key_values,)
            if output_hidden_states:
                outputs += (all_hidden_states,)
            if output_attentions:
                outputs += (all_self_attns,)
            return outputs
        return {
            "last_hidden_state": hidden_states,
            "past_key_values": past_key_values if use_cache else None,
            "hidden_states": all_hidden_states,
            "attentions": all_self_attns,
        }
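# Minimal usage sketch (comment only; the tiny config is illustrative, not a real checkpoint):
#   cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4,
#                     num_hidden_layers=2, vocab_size=128, pad_token_id=0)
#   model = CustomLlamaModel(cfg)
#   out = model(input_ids=torch.randint(0, 128, (1, 5)),
#               encoder_hidden_states=torch.randn(1, 197, 64))
#   # out["last_hidden_state"].shape == (1, 5, 64)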
# Custom LLaMA for Conditional Generation
class CustomLlamaForConditionalGeneration(LlamaPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.model = CustomLlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=True,
position_ids=None, # Added position_ids parameter
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
position_ids=position_ids, # Pass position_ids to the model
)
        hidden_states = outputs["last_hidden_state"] if return_dict else outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            # Labels are assumed to already be aligned with the logits (any shifting for
            # next-token prediction happens in the data pipeline); -100 marks ignored positions.
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            # Flatten to [batch_size * seq_len, vocab_size] and [batch_size * seq_len]
            logits_flat = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)
            loss = loss_fct(logits_flat, labels_flat)
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return {
"loss": loss,
"logits": logits,
"past_key_values": outputs["past_key_values"],
"hidden_states": outputs["hidden_states"],
}
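    # Training-style call (sketch; `ids`, `mask`, and `image_feats` are illustrative names):
    #   out = model(input_ids=ids, attention_mask=mask,
    #               encoder_hidden_states=image_feats, labels=ids)
    #   out["loss"].backward()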
    def generate_caption(
        self,
        pixel_values,
        max_length=16,
        temperature=0.7,
        top_k=50,
        top_p=0.9,
    ):
        # Note: this method assumes the vision components (vit_model, image_proj,
        # image_layer_norm, scale, image_activation, image_residual) have been attached to
        # self.model and a tokenizer to self, mirroring the attribute names used by
        # ImageCaptioningModel below; they are not created by this class itself.
        self.eval()
        with torch.no_grad():
            # Encode the image
            encoder_outputs = self.model.vit_model(pixel_values=pixel_values)
            encoder_hidden_states = self.model.image_proj(encoder_outputs.last_hidden_state)
            encoder_hidden_states = self.model.image_layer_norm(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states * self.model.scale
            encoder_hidden_states = self.model.image_activation(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states + self.model.image_residual(encoder_hidden_states)
            # Initialize generated_ids with bos_token_id
            generated_ids = torch.full(
                (pixel_values.size(0), 1),
                self.tokenizer.bos_token_id,
                dtype=torch.long,
                device=pixel_values.device,
            )
            past_key_values = None
            for _ in range(max_length):
                if past_key_values is None:
                    # First step: feed the whole prompt with positions 0..len-1.
                    input_ids = generated_ids
                    position_ids = torch.arange(
                        generated_ids.size(1), dtype=torch.long, device=pixel_values.device
                    ).unsqueeze(0).expand(generated_ids.size(0), -1)
                else:
                    # Later steps: feed only the newest token; the cache covers the rest.
                    input_ids = generated_ids[:, -1:]
                    position_ids = torch.full(
                        (generated_ids.size(0), 1), generated_ids.size(1) - 1,
                        dtype=torch.long, device=pixel_values.device,
                    )
                # Call self (not self.model) so the LM head produces logits.
                outputs = self(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(generated_ids),
                    encoder_hidden_states=encoder_hidden_states,
                    use_cache=True,
                    past_key_values=past_key_values,
                    return_dict=True,
                    position_ids=position_ids,
                )
                logits = outputs["logits"][:, -1, :] / temperature
                past_key_values = outputs["past_key_values"]
                # Apply top-k and top-p filtering
                filtered_logits = self.top_k_top_p_filtering(
                    logits, top_k=top_k, top_p=top_p
                )
                probabilities = nn.functional.softmax(filtered_logits, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)
                generated_ids = torch.cat((generated_ids, next_token_id), dim=-1)
                # Early stopping on EOS assumes a batch size of 1.
                if generated_ids.size(0) == 1 and next_token_id.item() == self.tokenizer.eos_token_id:
                    break
            captions = [
                self.tokenizer.decode(generated_id, skip_special_tokens=True)
                for generated_id in generated_ids
            ]
        return captions
def top_k_top_p_filtering(self, logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
"""
# Clone logits to avoid in-place modifications
logits = logits.clone()
# Batch size and vocabulary size
batch_size, vocab_size = logits.size()
# Top-K Filtering
if top_k > 0:
top_k = min(max(top_k, 1), vocab_size) # Safety check
# Get top-k indices
top_k_values, _ = torch.topk(logits, top_k, dim=-1)
min_top_k = top_k_values[:, -1, None]
# Create a mask for logits less than the top-k threshold
indices_to_remove = logits < min_top_k
logits[indices_to_remove] = filter_value
# Nucleus (Top-p) Filtering
if top_p < 1.0:
# Sort logits in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(
nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
# Create a mask for cumulative probabilities exceeding top_p
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the mask to include the first token above the threshold
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
# Scatter the mask back to the original indexing
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
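    # For instance (sketch), with logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) and top_k=2,
    # the two smallest entries are set to -inf, so only indices 2 and 3 can be sampled.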
# Multi-Layer Image Projection with Residual and LayerNorm
class ImageProjection(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super().__init__()
self.proj = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, output_dim),
nn.GELU()
)
def forward(self, x):
return self.proj(x)
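# The projection maps ViT patch embeddings (e.g. 768-dim for a ViT-Base encoder) into the
# LLaMA hidden size so they can serve as encoder_hidden_states for the cross-attention layers.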
class ImageCaptioningModel(nn.Module):
def __init__(
self, vit_model, llama_config, llama_model, tokenizer,
):
super().__init__()
self.vit_model = vit_model
self.tokenizer = tokenizer
self.llama_model = llama_model
# Enhanced Image Projection Layer
self.image_proj = ImageProjection(
input_dim=vit_model.config.hidden_size,
hidden_dim=llama_config.hidden_size * 2, # Increased capacity
output_dim=llama_config.hidden_size
)
nn.init.xavier_uniform_(self.image_proj.proj[0].weight, gain=math.sqrt(2))
nn.init.zeros_(self.image_proj.proj[0].bias)
nn.init.xavier_uniform_(self.image_proj.proj[2].weight, gain=math.sqrt(2))
nn.init.zeros_(self.image_proj.proj[2].bias)
# Layer Normalization
self.image_layer_norm = nn.LayerNorm(llama_config.hidden_size)
        # Learned per-channel scale for the projected image features (initialized to 5.0)
        self.scale = nn.Parameter(torch.ones(1, 1, llama_config.hidden_size) * 5.0)
# Non-linear activation
self.image_activation = nn.GELU()
# Residual Connection
self.image_residual = nn.Linear(llama_config.hidden_size, llama_config.hidden_size)
nn.init.xavier_uniform_(self.image_residual.weight, gain=1.0)
nn.init.zeros_(self.image_residual.bias)
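    # In forward, the projected image features pass through LayerNorm, the learned scale,
    # GELU, and a linear refinement before being handed to the decoder as encoder_hidden_states.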
def forward(self, pixel_values, input_ids=None, attention_mask=None, labels=None):
# Encode the image
encoder_outputs = self.vit_model(pixel_values=pixel_values)
encoder_hidden_states = self.image_proj(encoder_outputs.last_hidden_state)
encoder_hidden_states = self.image_layer_norm(encoder_hidden_states) # Apply LayerNorm
encoder_hidden_states = encoder_hidden_states * self.scale # Apply increased scaling
encoder_hidden_states = self.image_activation(encoder_hidden_states) # Apply GELU
encoder_hidden_states = encoder_hidden_states + self.image_residual(encoder_hidden_states) # Apply residual
# Generate position_ids for the decoder if needed
if input_ids is not None:
batch_size, seq_length = input_ids.size()
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
else:
position_ids = None
outputs = self.llama_model(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None, # Modify if you have encoder attention mask
labels=labels,
use_cache=False,
return_dict=True,
position_ids=position_ids, # Pass position_ids to the llama_model
)
return outputs
def top_k_top_p_filtering(self, logits, top_k=0, top_p=1.0, filter_value=-float("Inf")):
"""
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
"""
# Clone logits to avoid in-place modifications
logits = logits.clone()
# Batch size and vocabulary size
batch_size, vocab_size = logits.size()
# Top-K Filtering
if top_k > 0:
top_k = min(max(top_k, 1), vocab_size) # Safety check
# Get top-k indices
top_k_values, _ = torch.topk(logits, top_k, dim=-1)
min_top_k = top_k_values[:, -1, None]
# Create a mask for logits less than the top-k threshold
indices_to_remove = logits < min_top_k
logits[indices_to_remove] = filter_value
# Nucleus (Top-p) Filtering
if top_p < 1.0:
# Sort logits in descending order
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(
nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
# Create a mask for cumulative probabilities exceeding top_p
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the mask to include the first token above the threshold
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
# Scatter the mask back to the original indexing
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=-1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
def generate_caption(
self,
pixel_values,
max_length=16,
temperature=0.7,
top_k=50,
top_p=0.9,
):
self.eval()
with torch.no_grad():
# Encode the image
encoder_outputs = self.vit_model(pixel_values=pixel_values)
encoder_hidden_states = self.image_proj(encoder_outputs.last_hidden_state)
encoder_hidden_states = self.image_layer_norm(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states * self.scale
encoder_hidden_states = self.image_activation(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.image_residual(encoder_hidden_states)
# Initialize generated_ids with bos_token_id
generated_ids = torch.full(
(pixel_values.size(0), 1),
self.tokenizer.bos_token_id,
dtype=torch.long,
device=pixel_values.device,
)
            past_key_values = None
            for _ in range(max_length):
                if past_key_values is None:
                    # First step: feed the whole prompt with positions 0..len-1.
                    input_ids = generated_ids
                    position_ids = torch.arange(
                        generated_ids.size(1), dtype=torch.long, device=pixel_values.device
                    ).unsqueeze(0).expand(generated_ids.size(0), -1)
                else:
                    # Later steps: feed only the newest token; cached keys/values cover the rest.
                    input_ids = generated_ids[:, -1:]
                    position_ids = torch.full(
                        (generated_ids.size(0), 1), generated_ids.size(1) - 1,
                        dtype=torch.long, device=pixel_values.device,
                    )
                outputs = self.llama_model(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(generated_ids),
                    encoder_hidden_states=encoder_hidden_states,
                    use_cache=True,
                    past_key_values=past_key_values,
                    return_dict=True,
                    position_ids=position_ids,
                )
                logits = outputs["logits"][:, -1, :] / temperature
                past_key_values = outputs["past_key_values"]
                # Apply top-k and top-p filtering
                filtered_logits = self.top_k_top_p_filtering(
                    logits, top_k=top_k, top_p=top_p
                )
                probabilities = nn.functional.softmax(filtered_logits, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)
                generated_ids = torch.cat((generated_ids, next_token_id), dim=-1)
                # Early stopping on EOS assumes a batch size of 1.
                if generated_ids.size(0) == 1 and next_token_id.item() == self.tokenizer.eos_token_id:
                    break
captions = [
self.tokenizer.decode(generated_id, skip_special_tokens=True)
for generated_id in generated_ids
]
return captions
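# End-to-end assembly sketch (comment only; the checkpoint names are placeholders):
#   vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
#   tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
#   llama_config = LlamaConfig.from_pretrained("meta-llama/Llama-3.2-1B")
#   llama = CustomLlamaForConditionalGeneration(llama_config)
#   model = ImageCaptioningModel(vit, llama_config, llama, tokenizer)
#   # pixel_values come from ViTImageProcessor; then: captions = model.generate_caption(pixel_values)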