# v1-7B / modeling_v1.py
# Author: kjunh
# Commit b7d5f5c: Support Transformers AutoModel
import math
from typing import Optional, Union, Tuple, List
from dataclasses import dataclass
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLModel,
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLCausalLMOutputWithPast,
)
from .configuration_v1 import V1Config
def init_identity(layer, scale: float = 1):
    """Initialise a linear layer in place as a scaled identity map.

    The weight is overwritten with ``torch.eye(out_features, in_features) * scale``
    (a rectangular identity when the layer is not square) and the bias, when the
    layer has one, is zeroed. Anything that is not an ``nn.Linear`` is left
    untouched.

    Args:
        layer: Module to initialise.
        scale: Value placed on the diagonal of the weight matrix.
    """
    if not isinstance(layer, nn.Linear):
        return
    with torch.no_grad():
        rows, cols = layer.weight.shape
        # `copy_` converts to the dtype and device of the existing weight.
        layer.weight.copy_(torch.eye(rows, cols) * scale)
        # `nn.Linear(..., bias=False)` still exposes a `bias` attribute (set to
        # None), so test the value rather than the attribute's existence.
        if layer.bias is not None:
            layer.bias.zero_()
@dataclass
class V1CausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    """Causal-LM output extended with the individual loss terms of the V1 model.

    All three extra fields default to ``None`` and are filled by
    ``V1ForConditionalGeneration.forward`` only when labels are supplied.
    """

    # Auxiliary z-loss term (already multiplied by `config.z_loss_weight`);
    # only set when that weight is positive.
    z_loss: Optional[torch.Tensor] = None
    # Cross-entropy over the generation vocabulary (or over the joint
    # vocabulary when the copy loss is not computed separately).
    gen_loss: Optional[torch.Tensor] = None
    # Cross-entropy over the copy (pointer) targets; only set when
    # `config.separate_copy_loss` is enabled and copy targets are present.
    copy_loss: Optional[torch.Tensor] = None
class V1ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation model with an optional copy head.

    When ``config.do_copy`` is enabled, token ids at or above
    ``config.copy_token_start`` are treated as pointers: pointer id
    ``copy_token_start + i`` refers to the ``i``-th image token of the sample.
    On input, pointers are replaced by the embedding of the referenced image
    token; on output, pointer logits are computed as a scaled dot product
    between projected hidden states and projected image-token states, and are
    concatenated after the first ``copy_token_start`` generation logits.
    """

    config_class = V1Config

    def __init__(self, config):
        """Build the vision tower, language model, LM head and copy/gate heads."""
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config
        )
        self.model = Qwen2_5_VLModel(config)
        # Also used as the temperature applied to the copy logits in `forward`.
        self.copy_init_scale = 1 / math.sqrt(self.config.hidden_size)
        # Qwen2.5-VL: the tokenizer vocabulary differs from the embedding size
        # (== vocab_size): 151665 vs. 152064.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None  # cache rope_deltas here

        if self.config.do_copy:
            if self.config.tie_copy_heads:
                # One projection shared by queries and keys.
                self._copy_head = nn.Linear(config.hidden_size, config.copy_hidden_size)
            else:
                self._copy_q_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
                self._copy_k_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
            # NOTE(review): the original indentation was lost; the gate is
            # assumed to belong to the copy path because `forward` only applies
            # it there - confirm against the upstream file.
            if self.config.use_gate:
                self.gate = nn.Linear(config.hidden_size, 1, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def after_loading(self):
        """Initialise the copy heads and gate; not invoked anywhere in this class."""
        if self.config.do_copy:
            self.init_heads()
            # NOTE(review): nesting under `do_copy` is assumed (see __init__).
            if self.config.use_gate:
                # A zero gate weight gives sigmoid(0) = 0.5, which halves the
                # generation logits in `forward`; doubling the LM head
                # compensates for that.
                self.lm_head.weight.data = self.lm_head.weight.data * 2
                self.gate.weight.data.fill_(0)

    @property
    def copy_q_head(self):
        """Projection applied to hidden states to form copy queries."""
        return self._copy_head if self.config.tie_copy_heads else self._copy_q_head

    @property
    def copy_k_head(self):
        """Projection applied to image-token states to form copy keys."""
        return self._copy_head if self.config.tie_copy_heads else self._copy_k_head

    def init_heads(self):
        """Set every existing copy head to a scaled identity projection."""
        for head_name in ("_copy_head", "_copy_k_head", "_copy_q_head"):
            if hasattr(self, head_name):
                init_identity(getattr(self, head_name), self.copy_init_scale)

    def copy_representations(
        self,
        inputs_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor,
        copy_values: Optional[torch.FloatTensor] = None,
    ):
        """Replace pointer-token embeddings by the image-token values they reference.

        Args:
            inputs_embeds: Input embeddings, indexed as (batch, sequence, channel).
            input_ids: Token ids that still contain pointer ids
                (``>= config.copy_token_start``).
            copy_values: Per-sample image-token values to copy from. When
                ``None`` they are extracted from ``inputs_embeds`` at the
                positions holding ``config.image_token_id``.

        Returns:
            A tensor shaped like ``inputs_embeds`` where each pointer position
            holds the referenced value and every other position is unchanged.
        """
        if copy_values is None:
            mask = input_ids == self.config.image_token_id
            copy_values, _ = self.extract_image_tokens(inputs_embeds, mask)  # initial
        assert copy_values is not None
        copy_values = copy_values.to(inputs_embeds.device)

        # The subtraction allocates a new tensor, so the caller's ids are not
        # modified by the in-place write below.
        pointer_ids = input_ids.to(inputs_embeds.device) - self.config.copy_token_start
        copy_mask = pointer_ids >= 0
        pointer_ids[~copy_mask] = 0  # dummy index for non-pointer positions

        extracted = copy_values.gather(
            1, pointer_ids[..., None].repeat(1, 1, copy_values.shape[-1])
        )
        copy_mask = copy_mask.to(extracted.dtype)[..., None]
        return copy_mask * extracted + (1 - copy_mask) * inputs_embeds

    def extract_image_tokens(self, features: torch.FloatTensor, mask: torch.Tensor):
        """Gather the masked positions of ``features`` into a right-padded batch."""
        out_feat, out_mask = extract_image_tokens_right_pad(features, mask)
        return out_feat, out_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""
        Run the model and, when `labels` are given, compute the training loss.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Token ids. Required even when `inputs_embeds` is passed, because image-token and
                pointer positions are located from the ids. Ids `>= config.copy_token_start` are pointers.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Indices should either be valid token /
                pointer ids or -100. Tokens with indices set to `-100` are ignored by the cross-entropy.
            pixel_values_videos (*optional*):
                Not supported; passing it together with `inputs_embeds=None` raises `NotImplementedError`.

        Returns:
            A `V1CausalLMOutputWithPast`. On the copy path (`config.do_copy`) a plain tuple is returned
            instead when `return_dict` is false; the non-copy path always returns the dataclass.

        Raises:
            ValueError: if `input_ids` is `None`.
            NotImplementedError: if video inputs are supplied.
        """
        if input_ids is None:
            raise ValueError(
                "input_ids is required: image and copy token positions are derived from it."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Keep a copy with the pointer ids; the working ids get every pointer
        # replaced by the region token so that they can be embedded.
        input_ids = input_ids.clone()
        input_ids_with_ptrs = input_ids.clone()
        input_ids[input_ids >= self.config.copy_token_start] = (
            self.config.region_token_id
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)
                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
            if pixel_values_videos is not None:
                raise NotImplementedError("video inputs are not supported yet.")
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # NOTE(review): placed at function level (original indentation lost),
        # since `has_cache` is needed later regardless of how the embeddings
        # were obtained.
        if self.config.do_copy:
            copy_keys, copy_keys_mask = None, None
            copy_values, copy_values_mask = None, None
            has_cache = bool(past_key_values)
            if has_cache:
                # The copy memories live in the two last cache entries.
                copy_keys, copy_values = past_key_values[len(past_key_values) - 2]
                copy_keys_mask, copy_values_mask = past_key_values[
                    len(past_key_values) - 1
                ]
                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_keys_mask = copy_keys_mask[..., 0]
                copy_values_mask = copy_values_mask[..., 0]
            inputs_embeds = self.copy_representations(
                inputs_embeds, input_ids_with_ptrs, copy_values
            )

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (
            attention_mask is None or attention_mask.ndim == 2
        ):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        # Always request a structured output from the inner model: the code
        # below reads `outputs.past_key_values`, `.hidden_states` and
        # `.attentions`, which do not exist on a plain tuple. The caller's
        # `return_dict` is honoured when the final result is built.
        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        gen_logits = self.lm_head(hidden_states)

        if self.config.do_copy:
            assert (
                self.config.copy_extraction_layer == -1
            ), f"copy_extraction_layer should be -1: {self.config.copy_extraction_layer}"
            copy_hidden_states = hidden_states

            copy_q_states = copy_hidden_states
            if self.config.normalize_copy_states:
                copy_q_states = F.normalize(copy_q_states, 2, -1)
            copy_q_states = self.copy_q_head(copy_q_states)

            present_key_values = outputs.past_key_values

            if not has_cache:
                # Pre-fill: build the copy keys/values from the image tokens.
                mask = input_ids == self.config.image_token_id
                copy_k_states = (
                    inputs_embeds
                    if self.config.use_embeddings_as_keys
                    else copy_hidden_states
                )
                if self.config.normalize_copy_states:
                    copy_k_states = F.normalize(copy_k_states, 2, -1)
                copy_k_states, copy_k_mask = self.extract_image_tokens(
                    self.copy_k_head(copy_k_states), mask
                )
                copy_v_states, copy_v_mask = self.extract_image_tokens(
                    inputs_embeds.detach(), mask
                )
                # we add channel dim to the mask for consistency in tensor shape in cache
                copy_memories = [
                    (copy_k_states.detach(), copy_v_states.detach()),
                    (copy_k_mask[..., None], copy_v_mask[..., None]),
                ]
                # NOTE(review): memories are stored only when `use_cache` is
                # truthy; presumably callers that decode step by step always
                # pass it explicitly - confirm.
                if use_cache:
                    # only update at the first iteration
                    start = len(present_key_values)
                    for i, mem in enumerate(copy_memories):
                        present_key_values.update(*mem, start + i)
            else:
                copy_k_states = copy_keys
                copy_k_mask = copy_keys_mask

            assert copy_k_states is not None
            assert copy_k_mask is not None
            assert (
                copy_k_states.shape[1] > 0
            ), f"zero image tokens on batch elements: {copy_k_mask.sum(dim=1)}"

            copy_logits = (copy_q_states @ copy_k_states.transpose(-1, -2)).to(
                gen_logits.device
            ) * self.copy_init_scale

            if hasattr(self, "gate"):
                gate = torch.sigmoid(self.gate(hidden_states))
                gen_logits = gen_logits * (1 - gate)
                copy_logits = copy_logits * gate

            # Padded key slots must never be selected.
            copy_logits = copy_logits.masked_fill(
                ~copy_k_mask[:, None, :].to(copy_logits.device),
                torch.finfo(copy_logits.dtype).min,
            )
            logits = torch.cat(
                [gen_logits[..., : self.config.copy_token_start], copy_logits], dim=-1
            )
        else:
            # Plain language-model path. This path always returns the
            # dataclass, whatever `return_dict` is.
            logits = gen_logits
            loss = None
            z_loss = None
            gen_loss = None
            if labels is not None:
                gen_logits = gen_logits.float()
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                gen_loss_fct = CrossEntropyLoss(reduction="none")
                gen_logits_flat = shift_gen_logits.view(-1, shift_gen_logits.shape[-1])
                gen_labels_flat = shift_labels.view(-1)
                gen_loss_all = gen_loss_fct(gen_logits_flat, gen_labels_flat)
                gen_loss = gen_loss_all.mean()
                loss = gen_loss
                if self.config.z_loss_weight > 0:
                    valid_mask = shift_labels >= 0
                    # top-k approx z_loss for better memory usage
                    top_logits, _ = torch.topk(
                        shift_gen_logits, k=self.config.z_loss_top_k, dim=-1
                    )
                    lse = torch.logsumexp(top_logits, dim=-1)
                    z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight
                    loss = loss + z_loss
                    z_loss = z_loss.detach()
            return V1CausalLMOutputWithPast(
                loss=loss,
                z_loss=z_loss,
                gen_loss=gen_loss,
                copy_loss=None,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )

        loss = None
        z_loss = None
        gen_loss = None
        copy_loss = None
        if labels is not None:
            if self.config.separate_copy_loss:
                # Shift labels and logits for next-token prediction
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_copy_logits = copy_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                shift_logits = shift_copy_logits  # consumed by the z-loss below

                # Build masks
                gen_mask = shift_labels < self.config.copy_token_start
                copy_mask = shift_labels >= self.config.copy_token_start

                # Generation loss
                if gen_mask.any():
                    gen_loss_fct = CrossEntropyLoss(reduction="none")
                    G = shift_gen_logits.shape[-1]
                    gen_logits_flat = shift_gen_logits.view(-1, G)
                    gen_labels_flat = shift_labels.view(-1)
                    gen_mask_flat = gen_mask.view(-1)
                    # mask logits
                    gen_logits_flat_masked = gen_logits_flat[gen_mask_flat]
                    gen_labels_flat_masked = gen_labels_flat[gen_mask_flat]
                    gen_loss_all = gen_loss_fct(
                        gen_logits_flat_masked, gen_labels_flat_masked
                    )
                    gen_loss = gen_loss_all.mean()

                # Copy loss (adjust label indices to match copy_logits range)
                if copy_mask.any():
                    copy_loss_fct = CrossEntropyLoss(reduction="none")
                    C = shift_copy_logits.shape[-1]
                    copy_logits_flat = shift_copy_logits.view(-1, C)
                    copy_labels_flat = (
                        shift_labels.view(-1) - self.config.copy_token_start
                    )
                    copy_mask_flat = copy_mask.view(-1)
                    copy_logits_flat_masked = copy_logits_flat[copy_mask_flat]
                    copy_labels_flat_masked = copy_labels_flat[copy_mask_flat]
                    copy_loss_all = copy_loss_fct(
                        copy_logits_flat_masked, copy_labels_flat_masked
                    )
                    copy_loss = copy_loss_all.mean()
            else:
                # Upcast to float if we need to compute the loss to avoid potential precision issues
                logits = logits.float()
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
                total_vocab_size = logits.shape[-1]  # gen + copy
                shift_logits = shift_logits.view(-1, total_vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                gen_loss = loss_fct(shift_logits, shift_labels)

            loss = 0.0
            if gen_loss is not None:
                loss += gen_loss
            if copy_loss is not None:
                loss += copy_loss

            if self.config.z_loss_weight > 0:
                valid_mask = shift_labels >= 0
                # top-k approx z_loss for better memory usage
                top_logits, _ = torch.topk(
                    shift_logits, k=self.config.z_loss_top_k, dim=-1
                )
                lse = torch.logsumexp(top_logits, dim=-1)
                z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight
                loss = loss + z_loss
                z_loss = z_loss.detach()

            # The reported components are for logging only.
            if gen_loss is not None:
                gen_loss = gen_loss.detach()
            if copy_loss is not None:
                copy_loss = copy_loss.detach()

        if self.config.use_cfg:
            # expand as max_size for logit processors
            extended_vocab_size = self.config.vocab_size + self.config.copy_token_num
            B, L, V = logits.shape
            pads = torch.full(
                (B, L, extended_vocab_size - V),
                torch.finfo(gen_logits.dtype).min,
                device=logits.device,
            ).to(logits.dtype)
            logits = torch.cat([logits, pads], dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        logits = logits.float()
        return V1CausalLMOutputWithPast(
            loss=loss,
            z_loss=z_loss,
            gen_loss=gen_loss,
            copy_loss=copy_loss,
            logits=logits,
            past_key_values=present_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )
def extract_image_tokens_right_pad(features: torch.FloatTensor, mask: torch.Tensor):
    """Gather the masked positions of each sample into a right-padded batch.

    Args:
        features: Tensor indexed as (batch, sequence, channel).
        mask: (batch, sequence) mask; non-zero / True marks positions to keep.

    Returns:
        A pair ``(selected, selected_mask)``. ``selected`` has shape
        (batch, R, channel), where ``R`` is the largest number of kept
        positions in any sample (at least 1); each row holds that sample's
        kept features in their original sequence order, followed by zeros.
        ``selected_mask`` is a (batch, R) bool tensor that is True on the real
        (non-padding) entries.
    """
    device = features.device
    batch_size, seq_len, _ = features.shape
    # bool is not supported for sort in CUDA, hence the integer mask
    flags = mask.long().to(device)

    # Number of kept positions per sample, shape [batch]
    valid_counts = flags.sum(dim=1)
    # Kept as a tensor (no `.item()`) for Torch Dynamo compatibility;
    # `clamp_min(1)` guarantees a non-empty result.
    max_valid = valid_counts.max().clamp_min(1)

    # Move kept positions to the front. The sort must be stable: the mask has
    # only two distinct values, and pointer ids index into the resulting order,
    # so ties have to keep their original sequence order.
    order = torch.sort(flags, dim=1, descending=True, stable=True).indices
    batch_idx = torch.arange(batch_size, device=device).unsqueeze(1)
    batch_idx = batch_idx.expand(batch_size, seq_len)
    selected = features[batch_idx, order][:, :max_valid, :]

    # True for the first `valid_counts[b]` slots of each row
    positions = torch.arange(seq_len, device=device).expand(batch_size, seq_len)
    selected_mask = (positions < valid_counts.unsqueeze(1))[:, :max_valid]

    # Zero out the padding slots
    selected = torch.where(
        selected_mask.unsqueeze(-1), selected, torch.zeros_like(selected)
    )
    return selected, selected_mask
# Public API of this module: only the model class is exported by `import *`.
__all__ = ["V1ForConditionalGeneration"]