import math
from typing import Optional, Union, Tuple, List
from dataclasses import dataclass

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VisionTransformerPretrainedModel,
    Qwen2_5_VLModel,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLCausalLMOutputWithPast,
)

from .configuration_v1 import V1Config


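# Identity initialization for the copy projection heads: the weight is set to a (scaled)
# identity matrix so a freshly added head starts out as a near pass-through mapping.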
def init_identity(layer, scale: float = 1.0):
    if isinstance(layer, nn.Linear):
        with torch.no_grad():
            rows, cols = layer.weight.shape
            identity_matrix = torch.eye(rows, cols) * scale
            layer.weight.copy_(identity_matrix)
            # `nn.Linear` always has a `bias` attribute; it is None when bias=False.
            if layer.bias is not None:
                layer.bias.fill_(0)


@dataclass
class V1CausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    z_loss: Optional[torch.Tensor] = None
    gen_loss: Optional[torch.Tensor] = None
    copy_loss: Optional[torch.Tensor] = None


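# Qwen2.5-VL conditional generation extended with an optional pointer-style "copy" head:
# next-token logits over the base vocabulary are concatenated with logits that point
# back at individual image tokens in the prompt.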
class V1ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    config_class = V1Config

    def __init__(self, config):
        super().__init__(config)
        self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(
            config.vision_config
        )
        self.model = Qwen2_5_VLModel(config)
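        # Scaling applied both to the identity init of the copy heads and to the copy
        # logits, analogous to the 1/sqrt(d_k) factor in scaled dot-product attention.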
        self.copy_init_scale = 1 / math.sqrt(self.config.hidden_size)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_deltas = None

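        # Copy heads: either a single projection shared by queries and keys
        # (tie_copy_heads) or separate query/key projections, plus an optional scalar
        # gate that mixes generation and copy logits.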
        if self.config.do_copy:
            if self.config.tie_copy_heads:
                self._copy_head = nn.Linear(config.hidden_size, config.copy_hidden_size)
            else:
                self._copy_q_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
                self._copy_k_head = nn.Linear(
                    config.hidden_size, config.copy_hidden_size
                )
            if self.config.use_gate:
                self.gate = nn.Linear(config.hidden_size, 1, bias=False)

        self.post_init()

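    # Expected to be called once after pretrained weights are loaded (e.g. after
    # `from_pretrained`). Zeroing the gate makes sigmoid(gate) = 0.5 everywhere, so the
    # lm_head weights are doubled to keep the effective generation logits unchanged.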
    @torch.no_grad()
    def after_loading(self):
        if self.config.do_copy:
            self.init_heads()
        if self.config.use_gate:
            self.lm_head.weight.data = self.lm_head.weight.data * 2
            self.gate.weight.data.fill_(0)

    @property
    def copy_q_head(self):
        return self._copy_head if self.config.tie_copy_heads else self._copy_q_head

    @property
    def copy_k_head(self):
        return self._copy_head if self.config.tie_copy_heads else self._copy_k_head

    def init_heads(self):
        if hasattr(self, "_copy_head"):
            init_identity(self._copy_head, self.copy_init_scale)
        if hasattr(self, "_copy_k_head"):
            init_identity(self._copy_k_head, self.copy_init_scale)
        if hasattr(self, "_copy_q_head"):
            init_identity(self._copy_q_head, self.copy_init_scale)

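    # Replaces the embeddings of copy/pointer tokens (ids >= copy_token_start) with the
    # cached embedding of the image token they point at; all other positions keep their
    # original embedding.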
    def copy_representations(
        self,
        inputs_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor,
        copy_values: Optional[torch.FloatTensor] = None,
    ):
        if copy_values is None:
            mask = input_ids == self.config.image_token_id
            copy_values, _ = self.extract_image_tokens(inputs_embeds, mask)
        assert copy_values is not None
        copy_values = copy_values.to(inputs_embeds.device)
        input_ids = input_ids.to(inputs_embeds.device)

        input_ids = input_ids.clone()
        input_ids = input_ids - self.config.copy_token_start
        copy_mask = input_ids >= 0
        input_ids[~copy_mask] = 0

        extracted = copy_values.gather(
            1, input_ids[..., None].repeat(1, 1, copy_values.shape[-1])
        )
        copy_mask = copy_mask.to(extracted.dtype)[..., None]
        return copy_mask * extracted + (1 - copy_mask) * inputs_embeds

    def extract_image_tokens(self, features: torch.FloatTensor, mask: torch.Tensor):
        out_feat, out_mask = extract_image_tokens_right_pad(features, mask)
        return out_feat, out_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

        >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
        >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

        >>> messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": "What is shown in this image?"},
                ],
            },
        ]
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        >>> inputs = processor(text=[text], images=[image], return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
        ```"""

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

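        # Pointer tokens (ids >= copy_token_start) are out of range for the embedding
        # table, so they are remapped to the placeholder region token before embedding;
        # the original ids are kept in `input_ids_with_ptrs` for the copy lookup below.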
        input_ids = input_ids.clone()
        input_ids_with_ptrs = input_ids.clone()
        input_ids[input_ids >= self.config.copy_token_start] = (
            self.config.region_token_id
        )

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)

                mask = input_ids == self.config.image_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                image_mask = mask_expanded.to(inputs_embeds.device)

                image_embeds = image_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                raise NotImplementedError("video inputs are not supported yet.")
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )

                mask = input_ids == self.config.video_token_id
                mask_unsqueezed = mask.unsqueeze(-1)
                mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
                video_mask = mask_expanded.to(inputs_embeds.device)

                video_embeds = video_embeds.to(
                    inputs_embeds.device, inputs_embeds.dtype
                )
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

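        # The copy keys/values and their masks are stashed as two extra pseudo-layers at
        # the end of the KV cache during prefill, so later decoding steps can reuse them
        # without re-encoding the image tokens.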
        if self.config.do_copy:
            copy_keys, copy_keys_mask = None, None
            copy_values, copy_values_mask = None, None

            has_cache = bool(past_key_values)
            if has_cache:
                copy_keys, copy_values = past_key_values[len(past_key_values) - 2]
                copy_keys_mask, copy_values_mask = past_key_values[
                    len(past_key_values) - 1
                ]

                copy_keys_mask = copy_keys_mask[..., 0]
                copy_values_mask = copy_values_mask[..., 0]

            inputs_embeds = self.copy_representations(
                inputs_embeds, input_ids_with_ptrs, copy_values
            )

        if position_ids is None and (
            attention_mask is None or attention_mask.ndim == 2
        ):
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

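        # `position_ids` has shape (3, batch, seq): one index per M-RoPE axis
        # (temporal, height, width), as in the upstream Qwen2.5-VL implementation.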
        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]

        gen_logits = self.lm_head(hidden_states)

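        # Copy path: project the hidden states to copy queries, project the cached image
        # tokens to copy keys, and score every position against every image token. The
        # resulting copy logits are appended after the truncated base vocabulary.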
        if self.config.do_copy:
            assert (
                self.config.copy_extraction_layer == -1
            ), f"copy_extraction_layer should be -1: {self.config.copy_extraction_layer}"
            copy_hidden_states = hidden_states
            copy_q_states = copy_hidden_states
            if self.config.normalize_copy_states:
                copy_q_states = F.normalize(copy_q_states, 2, -1)
            copy_q_states = self.copy_q_head(copy_q_states)

            present_key_values = outputs.past_key_values

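            # Prefill: extract keys/values for the image tokens and (optionally) append
            # them to the cache. Decode steps with a cache reuse the stored tensors.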
            if not has_cache:
                mask = input_ids == self.config.image_token_id
                copy_k_states = (
                    inputs_embeds
                    if self.config.use_embeddings_as_keys
                    else copy_hidden_states
                )
                if self.config.normalize_copy_states:
                    copy_k_states = F.normalize(copy_k_states, 2, -1)
                copy_k_states, copy_k_mask = self.extract_image_tokens(
                    self.copy_k_head(copy_k_states), mask
                )
                copy_v_states, copy_v_mask = self.extract_image_tokens(
                    inputs_embeds.detach(), mask
                )

                copy_memories = [
                    (copy_k_states.detach(), copy_v_states.detach()),
                    (copy_k_mask[..., None], copy_v_mask[..., None]),
                ]

                if use_cache:
                    start = len(present_key_values)
                    for i, mem in enumerate(copy_memories):
                        present_key_values.update(*mem, start + i)
            else:
                copy_k_states = copy_keys
                copy_k_mask = copy_keys_mask

            assert copy_k_states is not None
            assert copy_k_mask is not None
            assert (
                copy_k_states.shape[1] > 0
            ), f"zero image tokens on batch elements: {copy_k_mask.sum(dim=1)}"

            copy_logits = (copy_q_states @ copy_k_states.transpose(-1, -2)).to(
                gen_logits.device
            ) * self.copy_init_scale

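            # Optional sigmoid gate: interpolates between generation and copy logits.
            # With the zero init from `after_loading`, the gate starts at 0.5 and the
            # doubled lm_head keeps the initial generation logits effectively unchanged.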
            if hasattr(self, "gate"):
                gate = torch.sigmoid(self.gate(hidden_states))
                gen_logits = gen_logits * (1 - gate)
                copy_logits = copy_logits * gate

            copy_logits = copy_logits.masked_fill(
                ~copy_k_mask[:, None, :].to(copy_logits.device),
                torch.finfo(copy_logits.dtype).min,
            )
            logits = torch.cat(
                [gen_logits[..., : self.config.copy_token_start], copy_logits], dim=-1
            )
        else:
            logits = gen_logits
            loss = None
            z_loss = None
            gen_loss = None
            if labels is not None:
                gen_logits = gen_logits.float()
                shift_gen_logits = gen_logits[:, :-1, :].contiguous()
                shift_labels = labels[:, 1:].contiguous()
                gen_loss_fct = CrossEntropyLoss(reduction="none")
                gen_logits_flat = shift_gen_logits.view(-1, shift_gen_logits.shape[-1])
                gen_labels_flat = shift_labels.view(-1)

                gen_loss_all = gen_loss_fct(gen_logits_flat, gen_labels_flat)
                gen_loss = gen_loss_all.mean()

                loss = gen_loss

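                # z-loss: penalize the squared log-partition term, approximated by the
                # logsumexp over the top-k logits, to keep logit magnitudes in check.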
                if self.config.z_loss_weight > 0:
                    valid_mask = shift_labels >= 0

                    top_logits, _ = torch.topk(
                        shift_gen_logits, k=self.config.z_loss_top_k, dim=-1
                    )
                    lse = torch.logsumexp(top_logits, dim=-1)
                    z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight

                    loss = loss + z_loss
                    z_loss = z_loss.detach()

            return V1CausalLMOutputWithPast(
                loss=loss,
                z_loss=z_loss,
                gen_loss=gen_loss,
                copy_loss=None,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )

        loss = None
        z_loss = None
        gen_loss = None
        copy_loss = None
        if labels is not None:
            if self.config.separate_copy_loss:
                shift_gen_logits = gen_logits[:, :-1, :].contiguous().float()
                shift_copy_logits = copy_logits[:, :-1, :].contiguous().float()
                shift_labels = labels[:, 1:].contiguous()
                shift_logits = shift_copy_logits

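                # Generation and copy losses are computed on disjoint label positions:
                # labels below copy_token_start are ordinary vocabulary targets, labels
                # at or above it index image tokens for the copy head.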
                gen_mask = shift_labels < self.config.copy_token_start
                copy_mask = shift_labels >= self.config.copy_token_start

                if gen_mask.any():
                    gen_loss_fct = CrossEntropyLoss(reduction="none")

                    G = shift_gen_logits.shape[-1]
                    gen_logits_flat = shift_gen_logits.view(-1, G)
                    gen_labels_flat = shift_labels.view(-1)
                    gen_mask_flat = gen_mask.view(-1)

                    gen_logits_flat_masked = gen_logits_flat[gen_mask_flat]
                    gen_labels_flat_masked = gen_labels_flat[gen_mask_flat]

                    gen_loss_all = gen_loss_fct(
                        gen_logits_flat_masked, gen_labels_flat_masked
                    )
                    gen_loss = gen_loss_all.mean()

                if copy_mask.any():
                    copy_loss_fct = CrossEntropyLoss(reduction="none")
                    C = shift_copy_logits.shape[-1]
                    copy_logits_flat = shift_copy_logits.view(-1, C)
                    copy_labels_flat = (
                        shift_labels.view(-1) - self.config.copy_token_start
                    )
                    copy_mask_flat = copy_mask.view(-1)
                    copy_logits_flat_masked = copy_logits_flat[copy_mask_flat]
                    copy_labels_flat_masked = copy_labels_flat[copy_mask_flat]
                    copy_loss_all = copy_loss_fct(
                        copy_logits_flat_masked, copy_labels_flat_masked
                    )
                    copy_loss = copy_loss_all.mean()
            else:
                logits = logits.float()

                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                loss_fct = CrossEntropyLoss(label_smoothing=self.config.label_smoothing)
                total_vocab_size = logits.shape[-1]
                shift_logits = shift_logits.view(-1, total_vocab_size)
                shift_labels = shift_labels.view(-1)

                shift_labels = shift_labels.to(shift_logits.device)
                gen_loss = loss_fct(shift_logits, shift_labels)

            loss = 0.0
            if gen_loss is not None:
                loss += gen_loss
            if copy_loss is not None:
                loss += copy_loss

            if self.config.z_loss_weight > 0:
                valid_mask = shift_labels >= 0

                top_logits, _ = torch.topk(
                    shift_logits, k=self.config.z_loss_top_k, dim=-1
                )
                lse = torch.logsumexp(top_logits, dim=-1)
                z_loss = lse[valid_mask].pow(2).mean() * self.config.z_loss_weight

                loss = loss + z_loss
                z_loss = z_loss.detach()

            if gen_loss is not None:
                gen_loss = gen_loss.detach()
            if copy_loss is not None:
                copy_loss = copy_loss.detach()

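        # With `use_cfg`, pad the logits with the dtype minimum up to the full extended
        # vocabulary (base vocab + copy tokens) so every forward pass exposes a
        # fixed-size distribution to downstream consumers.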
        if self.config.use_cfg:
            extended_vocab_size = self.config.vocab_size + self.config.copy_token_num
            B, L, V = logits.shape
            pads = torch.full(
                (B, L, extended_vocab_size - V),
                torch.finfo(gen_logits.dtype).min,
                device=logits.device,
            ).to(logits.dtype)
            logits = torch.cat([logits, pads], dim=-1)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        logits = logits.float()
        return V1CausalLMOutputWithPast(
            loss=loss,
            z_loss=z_loss,
            gen_loss=gen_loss,
            copy_loss=copy_loss,
            logits=logits,
            past_key_values=present_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )


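# Gathers the rows of `features` where `mask` is True, keeping their original order, and
# right-pads each batch element to the largest per-batch count. Returns the padded
# features together with a boolean validity mask of shape (batch, max_count).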
def extract_image_tokens_right_pad(features: torch.FloatTensor, mask: torch.Tensor):
    X, M = features, mask.long()
    B, L, _ = X.shape
    device = X.device
    M = M.to(device)

    valid_counts = M.sum(dim=1)

    # Keep at least one slot even if some batch element has no masked tokens.
    R = valid_counts.max().clamp_min(1)

    # Stable sort so masked tokens keep their original left-to-right order.
    sorted_indices = torch.argsort(M, dim=1, descending=True, stable=True)
    batch_indices = torch.arange(B, device=device).unsqueeze(1).expand(B, L)

    X_sorted = X[batch_indices, sorted_indices]
    X_selected = X_sorted[:, :R, :]

    M2 = torch.arange(L, device=device).expand(B, L) < valid_counts.unsqueeze(1)
    M2 = M2[:, :R]

    # Zero out the right-padded positions.
    X_selected = torch.where(M2.unsqueeze(-1), X_selected, torch.zeros_like(X_selected))

    return X_selected, M2


__all__ = ["V1ForConditionalGeneration"]