| import torch |
| import torch.nn as nn |
| from typing import List, Tuple, Optional, Any, Dict |
| from dataclasses import dataclass |
|
|
| from transformers import Qwen2_5_VLForConditionalGeneration |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig |
| from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler |
| from transformers.models.idefics2.configuration_idefics2 import Idefics2PerceiverConfig |
| from transformers.utils import ModelOutput |
| from transformers.processing_utils import Unpack |
|
|
@dataclass
class TRASEROutput(ModelOutput):
    """Output container returned by :meth:`TRASER.forward`.

    Mirrors the usual causal-LM output fields, plus the mRoPE
    ``rope_deltas`` offsets that the model caches on ``self.model``
    between prefill and decode steps.
    """

    # Language-modeling loss; populated only when `labels` are passed to forward().
    loss: Optional[torch.FloatTensor] = None
    # Token logits produced by `lm_head` over the final hidden states.
    logits: Optional[torch.FloatTensor] = None
    # KV cache passed through from the underlying language model.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Per-layer hidden states (present when `output_hidden_states` is set).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention maps (present when `output_attentions` is set).
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Rotary position-id offsets, echoed from `self.model.rope_deltas`.
    rope_deltas: Optional[torch.LongTensor] = None
|
|
class TRASER(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation model extended with Idefics2
    perceiver resamplers.

    On top of the stock Qwen2.5-VL stack this attaches:
      * ``self.perceiver_resampler`` — a temporal resampler (always built);
      * ``self.second_perceiver_resampler`` — an optional "object" resampler,
        gated on ``config.object_resampler`` (defaults to enabled).

    NOTE(review): the resamplers are constructed here but are never invoked
    by any method visible in this file — presumably they are called from
    feature-merging code elsewhere in the project; confirm before relying
    on them.
    """

    def __init__(self, config: Qwen2_5_VLConfig, **kwargs):
        """Initialize the base Qwen2.5-VL model, then attach the resamplers.

        Extra ``kwargs`` are copied onto ``config`` as new attributes, but
        only when the config does not already define them — existing config
        values win over the kwargs.
        """
        super().__init__(config)

        # Fold unknown keyword overrides into the config without clobbering
        # attributes the config already carries.
        for k, v in kwargs.items():
            if not hasattr(config, k):
                setattr(config, k, v)

        self.config = config
        self._build_perceiver(dtype=config.torch_dtype, attn_impl=config._attn_implementation)
        self.post_init()


    def _build_perceiver(self, dtype: torch.dtype, attn_impl: str) -> None:
        """Construct the Idefics2 perceiver resampler(s) as submodules.

        Args:
            dtype: weight dtype forwarded to the perceiver config
                (taken from ``config.torch_dtype`` by ``__init__``).
            attn_impl: attention implementation name forwarded to the
                perceiver config (from ``config._attn_implementation``).
        """
        # Fall back to defaults when the config lacks the custom TRASER
        # attributes (hidden_size 2048, 64 temporal latents, depth 3).
        h = int(getattr(self.config, "hidden_size", 2048))
        n_latents = int(getattr(self.config, "temporal_resampler_n_latents", 64))
        depth = int(getattr(self.config, "resampler_depth", 3))

        # Temporal resampler: always built.
        perceiver_cfg = Idefics2PerceiverConfig(
            hidden_size=h,
            resampler_n_latents=n_latents,
            resampler_depth=depth,
            _attn_implementation=attn_impl,
            torch_dtype=dtype,
        )
        self.perceiver_resampler = Idefics2PerceiverResampler(perceiver_cfg)

        # Optional second ("object") resampler — enabled unless the config
        # explicitly turns it off. Shares hidden size and depth with the
        # temporal resampler but has its own latent count (default 32).
        if getattr(self.config, "object_resampler", True):
            second_n_latents = int(getattr(self.config, "object_resampler_n_latents", 32))

            second_perceiver_cfg = Idefics2PerceiverConfig(
                hidden_size=h,
                resampler_n_latents=second_n_latents,
                resampler_depth=depth,
                _attn_implementation=attn_impl,
                torch_dtype=dtype,
            )
            self.second_perceiver_resampler = Idefics2PerceiverResampler(second_perceiver_cfg)


    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        second_per_grid_ts=None,
        **kwargs,
    ):
        """Delegate to the parent implementation, then adjust the inputs for
        the prefill/decode split used by :meth:`forward`.

        After the parent builds ``model_inputs``:
          * the caller-supplied ``position_ids`` is re-injected (the parent
            may have replaced it);
          * on decode steps (``cache_position[0] != 0``) the visual tensors
            and ``position_ids`` are nulled out — the pixels were already
            consumed during prefill, and :meth:`forward` rebuilds decode
            position ids itself from the cached rope deltas.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            use_cache=use_cache,
            **kwargs,
        )

        # Force the caller's position_ids back into the dict unconditionally.
        model_inputs["position_ids"] = position_ids
        if cache_position is not None and cache_position[0] != 0:
            # Past the first step: drop visual inputs (already encoded into
            # the KV cache) and let forward() derive position ids.
            model_inputs["pixel_values"] = None
            model_inputs["pixel_values_videos"] = None
            model_inputs["position_ids"] = None
        return model_inputs


    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> TRASEROutput:
        """Run the language model with an explicit prefill/decode split.

        Prefill path (``inputs_embeds`` supplied and the KV cache absent or
        empty): the precomputed embeddings are fed straight to the language
        model using the caller's ``position_ids``. Assumes the caller has
        already merged visual features into ``inputs_embeds`` — TODO confirm
        against the calling code.

        Decode path (everything else): ``input_ids`` are embedded here and
        3-axis mRoPE position ids are rebuilt from ``cache_position`` plus
        the rope deltas cached on ``self.model`` during prefill.

        Returns:
            TRASEROutput: ``loss`` is populated only when ``labels`` is
            given; ``rope_deltas`` echoes ``self.model.rope_deltas``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        # Allow the caller to override the cached rope deltas explicitly.
        if rope_deltas is not None:
            self.model.rope_deltas = rope_deltas

        # Prefill iff precomputed embeddings are supplied AND the cache is
        # either absent or reports zero cached tokens.
        is_prefill = (inputs_embeds is not None) and (
            past_key_values is None or (hasattr(past_key_values, "get_seq_length") and past_key_values.get_seq_length() == 0)
        )

        if is_prefill:
            outputs = self.model.language_model(
                input_ids=None,
                inputs_embeds=inputs_embeds,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                return_dict=True,
            )
        else:
            # Decode step: embed the new token ids ourselves.
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
            batch_size, seq_length, _ = inputs_embeds.shape
            # Absolute offset = current cache position + the mRoPE delta
            # cached during prefill. NOTE(review): assumes
            # self.model.rope_deltas was set by a prior prefill (or via the
            # rope_deltas argument) whenever cache_position is not None —
            # otherwise this line raises; confirm against the call sites.
            delta = (
                (cache_position[0] + self.model.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            pos = torch.arange(seq_length, device=inputs_embeds.device).view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                # Repeat deltas so they align with a batch that may have been
                # expanded to a multiple of the cached deltas' batch size.
                delta = delta.repeat_interleave(max(1, batch_size // delta.shape[0]), dim=0)
            # Broadcast to the 3 mRoPE axes — presumably (temporal, height,
            # width) as in Qwen2.5-VL; all three share the same 1-D ids here.
            pos = pos.add(delta).unsqueeze(0).expand(3, -1, -1)

            # NOTE(review): extra **kwargs are forwarded only on this decode
            # path, not during prefill — confirm that is intentional.
            outputs = self.model.language_model(
                input_ids=None,
                position_ids=pos,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = outputs.last_hidden_state
        # Project final hidden states to vocabulary logits.
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return TRASEROutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.model.rope_deltas,
        )