| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, List, Union, Tuple |
| | from transformers import Qwen2VLTextModel, Qwen2VLTextConfig, Qwen2VLPreTrainedModel, PretrainedConfig |
| | from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.modeling_utils import PreTrainedModel |
| | from transformers.modeling_outputs import ModelOutput |
| | from PIL import Image, ImageOps |
| | from encoder import build_sam_vit_b, build_clip_l, MlpProjector |
| | from addict import Dict as ADict |
| | import os |
| | import math |
| | from data import ( |
| | format_messages, |
| | load_pil_images, |
| | text_encode, |
| | BasicImageTransform, |
| | dynamic_preprocess, |
| | re_match, |
| | process_image_with_refs, |
| | NoEOSTextStreamer, |
| | ) |
| | from tqdm import tqdm |
| | from dataclasses import dataclass |
| |
|
| |
|
class DeepQwenVLConfig(PretrainedConfig):
    """
    Configuration class for DeepQwenVL model.

    This config wraps both the Qwen2VL text config and DeepSeek vision config.
    When loading from a Qwen2-VL checkpoint, it will use the checkpoint's config
    directly for the text model.

    Args:
        deepseek_vision_hidden_size: Width of the concatenated DeepSeek vision
            features (CLIP + SAM) fed into the projector.
        projector_type: Projector architecture identifier (only "mlp" is used here).
        projector_input_dim: Input width of the projector (matches the vision features).
        projector_output_dim: Output width of the projector; defaults to ``hidden_size``.
        projector_hidden_dim: Hidden width of the projector; defaults to
            ``projector_output_dim``.
        image_newline_dim / view_separator_dim: Dims of the learned separator
            embeddings; both default to ``hidden_size``.
        hidden_size .. vocab_size: Standard Qwen2-VL text-model hyperparameters.
        bos_token_id .. vision_token_id: Qwen2-VL special-token ids.
        rope_scaling: RoPE scaling dict; defaults to Qwen2-VL's mrope sections.
    """
    model_type = "deepqwen_vl"

    def __init__(
        self,
        deepseek_vision_hidden_size: int = 2048,

        # Projector (DeepSeek vision features -> text embedding space).
        projector_type: str = "mlp",
        projector_input_dim: int = 2048,
        projector_output_dim: Optional[int] = None,
        projector_hidden_dim: Optional[int] = None,

        # Learned separator embeddings.
        image_newline_dim: Optional[int] = None,
        view_separator_dim: Optional[int] = None,

        # Qwen2-VL text-model hyperparameters.
        hidden_size: int = 1536,
        intermediate_size: int = 8960,
        num_hidden_layers: int = 28,
        num_attention_heads: int = 12,
        num_key_value_heads: int = 2,
        hidden_act: str = "silu",
        max_position_embeddings: int = 32768,
        initializer_range: float = 0.02,
        rms_norm_eps: float = 1e-6,
        use_cache: bool = True,
        tie_word_embeddings: bool = True,
        rope_theta: float = 1000000.0,
        attention_dropout: float = 0.0,
        vocab_size: int = 151936,

        # Qwen2-VL special-token ids.
        bos_token_id: int = 151643,
        eos_token_id: int = 151645,
        pad_token_id: int = 151643,
        image_token_id: int = 151655,
        video_token_id: int = 151656,
        vision_start_token_id: int = 151652,
        vision_end_token_id: int = 151653,
        vision_token_id: int = 151654,

        rope_scaling: Optional[dict] = None,

        **kwargs
    ):
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

        self.deepseek_vision_hidden_size = deepseek_vision_hidden_size

        # Projector settings; unset dims fall back to the text hidden size so the
        # projected vision tokens can be spliced directly into the embeddings.
        self.projector_type = projector_type
        self.projector_input_dim = projector_input_dim
        self.projector_output_dim = projector_output_dim if projector_output_dim else hidden_size
        self.projector_hidden_dim = projector_hidden_dim if projector_hidden_dim else self.projector_output_dim

        # Separator embedding dims (default: text hidden size).
        self.image_newline_dim = image_newline_dim if image_newline_dim else hidden_size
        self.view_separator_dim = view_separator_dim if view_separator_dim else hidden_size

        # Text-model hyperparameters.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vocab_size = vocab_size

        # Vision-related special-token ids.
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_start_token_id = vision_start_token_id
        self.vision_end_token_id = vision_end_token_id
        self.vision_token_id = vision_token_id

        # Default to Qwen2-VL's multimodal RoPE (mrope) section split.
        if rope_scaling is None:
            rope_scaling = {"type": "mrope", "mrope_section": [16, 24, 24]}
        self.rope_scaling = rope_scaling

    def to_text_config(self) -> Qwen2VLTextConfig:
        """Convert to Qwen2VLTextConfig for the text model."""
        return Qwen2VLTextConfig(
            hidden_size=self.hidden_size,
            intermediate_size=self.intermediate_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            num_key_value_heads=self.num_key_value_heads,
            hidden_act=self.hidden_act,
            max_position_embeddings=self.max_position_embeddings,
            initializer_range=self.initializer_range,
            rms_norm_eps=self.rms_norm_eps,
            use_cache=self.use_cache,
            tie_word_embeddings=self.tie_word_embeddings,
            rope_theta=self.rope_theta,
            attention_dropout=self.attention_dropout,
            vocab_size=self.vocab_size,
            bos_token_id=self.bos_token_id,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.pad_token_id,
            rope_scaling=self.rope_scaling,
        )
| |
|
| |
|
@dataclass
class DeepQwenOutputWithPast(ModelOutput):
    """Backbone output container (mirrors HF BaseModelOutputWithPast)."""
    # Final hidden states from the text model.
    last_hidden_state: torch.FloatTensor = None
    # KV cache returned when use_cache is enabled.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Per-layer hidden states, present only when output_hidden_states=True.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Per-layer attention weights, present only when output_attentions=True.
    attentions: Optional[tuple[torch.FloatTensor]] = None
| |
|
@dataclass
class DeepQwenCausalLMOutputWithPast(ModelOutput):
    """Causal-LM output container (mirrors HF CausalLMOutputWithPast)."""
    # Language-modeling loss, populated only when labels are provided.
    loss: Optional[torch.FloatTensor] = None
    # Vocabulary logits from the LM head.
    logits: Optional[torch.FloatTensor] = None
    # KV cache returned when use_cache is enabled.
    past_key_values: Optional[list[torch.FloatTensor]] = None
    # Per-layer hidden states, present only when output_hidden_states=True.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Per-layer attention weights, present only when output_attentions=True.
    attentions: Optional[tuple[torch.FloatTensor]] = None
| |
|
| |
|
class VisionProjector(nn.Module):
    """
    Two-stage vision projector bridging DeepSeek features to Qwen embeddings.

    Pipeline:
        deepseek_proj: Linear(input_dim -> hidden_dim)  [FROZEN, from DeepSeek ckpt]
        SiLU
        norm: LayerNorm(hidden_dim)                     [TRAINABLE]
        adapter: Linear(hidden_dim -> output_dim)       [TRAINABLE]

    Keeping DeepSeek's pretrained projection intact preserves its learned
    vision-text alignment; only the norm + adapter adapt features to the
    target embedding space (two layers total, LLaVA-style).
    """

    def __init__(self, input_dim: int = 2048, hidden_dim: int = 1280, output_dim: int = 1536):
        super().__init__()
        # Pretrained DeepSeek projection; weights are overwritten from checkpoint.
        self.deepseek_proj = nn.Linear(input_dim, hidden_dim)
        # Trainable adaptation head.
        self.norm = nn.LayerNorm(hidden_dim)
        self.adapter = nn.Linear(hidden_dim, output_dim)
        self._init_adapter_weights()

    def _init_adapter_weights(self):
        """Initialize only the trainable tail; deepseek_proj comes from checkpoint."""
        nn.init.ones_(self.norm.weight)
        nn.init.zeros_(self.norm.bias)
        # Small-variance init keeps the adapter close to a no-op at start.
        nn.init.normal_(self.adapter.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.adapter.bias)

    def forward(self, x):
        """Project vision features: frozen linear -> SiLU -> LayerNorm -> adapter."""
        return self.adapter(self.norm(F.silu(self.deepseek_proj(x))))
| |
|
class DeepQwenVLPreTrainedModel(PreTrainedModel):
    """HF base class wiring config, attention backends, and weight init."""
    config_class = DeepQwenVLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_static_cache = True
    _supports_attention_backend = True

    # Vision-stack weights are loaded separately, so their absence in a text
    # checkpoint is expected and should not warn.
    _keys_to_ignore_on_load_missing = [
        "sam_model",
        "vision_model",
        "projector",
        "image_newline",
        "view_separator",
    ]

    def _init_weights(self, module):
        """Initialize the weights."""
        # Fall back to 0.02 for configs missing initializer_range.
        std = getattr(self.config, 'initializer_range', 0.02)
        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
| |
|
| |
|
class DeepQwenVLModel(Qwen2VLTextModel):
    """
    DeepQwenVL Model that combines DeepSeek's vision encoders with Qwen2VL's text model.

    Accepts either:
    - A DeepQwenVLConfig
    - A Qwen2VLTextConfig (for compatibility with from_pretrained from Qwen checkpoints)
    - A generic PretrainedConfig (will extract necessary fields)
    """
    config_class = DeepQwenVLConfig

    def __init__(self, config):
        # Normalize whichever config flavor we received into a text config plus
        # the two widths the vision stack needs.
        if isinstance(config, DeepQwenVLConfig):
            text_config = config.to_text_config()
            output_hidden_size = config.projector_output_dim
            vision_dim = config.deepseek_vision_hidden_size
        elif isinstance(config, Qwen2VLTextConfig):
            text_config = config
            output_hidden_size = config.hidden_size
            vision_dim = 2048  # DeepSeek concatenated vision-feature width
        else:
            # Generic config from a checkpoint: pull fields defensively.
            text_config = config
            output_hidden_size = getattr(config, 'hidden_size', 1536)
            vision_dim = getattr(config, 'deepseek_vision_hidden_size', 2048)

        super(DeepQwenVLModel, self).__init__(text_config)

        # Keep the originally supplied (possibly wrapper) config on the instance.
        self.config = config
        self.output_hidden_size = output_hidden_size

        # DeepSeek vision stack: SAM-ViT-B plus CLIP-L (built by encoder.py).
        self.sam_model = build_sam_vit_b()
        self.vision_model = build_clip_l()

        self.deepseek_vision_dim = vision_dim
        # Width of DeepSeek's pretrained projector hidden layer.
        self.deepseek_hidden_dim = 1280

        self.projector = VisionProjector(
            input_dim=self.deepseek_vision_dim,
            hidden_dim=self.deepseek_hidden_dim,
            output_dim=output_hidden_size
        )

        # Learned separator embeddings, scaled like token embeddings (std = 1/sqrt(dim)).
        embed_std = 1 / torch.sqrt(torch.tensor(output_hidden_size, dtype=torch.float32))
        self.image_newline = nn.Parameter(torch.randn(output_hidden_size) * embed_std)
        self.view_separator = nn.Parameter(torch.randn(output_hidden_size) * embed_std)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Encode images with the (frozen) DeepSeek vision stack, splice the
        projected vision tokens into the text embeddings at the positions
        flagged by ``images_seq_mask``, then run the Qwen2VL text model.

        ``images`` is expected to be a sequence of (patches, image_ori) pairs —
        local crop tiles and the padded global view — one per batch element;
        ``images_spatial_crop`` carries the (width, height) tile grid per image.
        # NOTE(review): shapes above inferred from indexing (image[0]/image[1]);
        # confirm against the caller in `infer`.
        """
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        sam_model = getattr(self, 'sam_model', None)
        vision_model = getattr(self, 'vision_model', None)

        # Only run the vision path on a real prefill (seq len > 1) or during
        # training, and only when the global view is non-zero (zero tensors are
        # the text-only sentinel fed by `infer`).
        should_process_images = (
            sam_model is not None
            and images is not None
            and images_seq_mask is not None
            and (input_ids.shape[1] != 1 or self.training)
            and torch.sum(images[0][1]).item() != 0
        )

        if should_process_images:
            idx = 0
            for image, crop_shape in zip(images, images_spatial_crop):
                images_in_this_batch = []
                patches = image[0]    # local crop tiles (all-zero when no tiling)
                image_ori = image[1]  # padded global view

                if torch.sum(patches).item() != 0:
                    # --- local (tiled) features: frozen encoders, no grads ---
                    with torch.no_grad():
                        local_features_1 = sam_model(patches)
                        local_features_2 = vision_model(patches, local_features_1)
                        # Drop CLIP's CLS token and concat SAM features channel-wise.
                        local_features = torch.cat((local_features_2[:, 1:], local_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        local_features = local_features.detach()
                    # Projector stays outside no_grad so it can be trained.
                    local_features = self.projector(local_features)

                    # --- global features, same recipe ---
                    with torch.no_grad():
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = global_features.detach()
                    global_features = self.projector(global_features)

                    # Assumes a square token grid (hw is a perfect square) — TODO confirm.
                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)
                    _2, hw2, n_dim2 = local_features.shape
                    h2 = w2 = int(hw2 ** 0.5)
                    width_crop_num, height_crop_num = crop_shape[0], crop_shape[1]

                    # Global view: append a learned newline token after each row.
                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                    )
                    global_features = global_features.view(-1, n_dim)

                    # Local tiles: stitch the tile grid back into one large
                    # feature map, then append a newline token per stitched row.
                    local_features = local_features.view(
                        height_crop_num, width_crop_num, h2, w2, n_dim2
                    ).permute(0, 2, 1, 3, 4).reshape(height_crop_num*h2, width_crop_num*w2, n_dim2)
                    local_features = torch.cat(
                        [local_features, self.image_newline[None, None, :].expand(height_crop_num * h2, 1, n_dim2)], dim=1
                    )
                    local_features = local_features.view(-1, n_dim2)

                    # Order: local tiles, then global view, then a view separator.
                    global_local_features = torch.cat([local_features, global_features, self.view_separator[None, :]], dim=0)
                    images_in_this_batch.append(global_local_features)
                else:
                    # No tiles: global view only.
                    with torch.no_grad():
                        global_features_1 = sam_model(image_ori)
                        global_features_2 = vision_model(image_ori, global_features_1)
                        global_features = torch.cat((global_features_2[:, 1:], global_features_1.flatten(2).permute(0, 2, 1)), dim=-1)
                        global_features = global_features.detach()
                    global_features = self.projector(global_features)

                    _, hw, n_dim = global_features.shape
                    h = w = int(hw ** 0.5)
                    global_features = global_features.view(h, w, n_dim)
                    global_features = torch.cat(
                        [global_features, self.image_newline[None, None, :].expand(h, 1, n_dim)], dim=1
                    )
                    global_features = global_features.view(-1, n_dim)
                    global_local_features = torch.cat([global_features, self.view_separator[None, :]], dim=0)
                    images_in_this_batch.append(global_local_features)

                if images_in_this_batch:
                    images_in_this_batch = torch.cat(images_in_this_batch, dim=0)
                    # In-place write of vision tokens into the masked embedding slots.
                    # NOTE(review): hard-coded .cuda() breaks CPU / multi-device runs.
                    inputs_embeds[idx].masked_scatter_(images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch)
                idx += 1

        # Embeddings are passed directly, so input_ids must be None here.
        outputs = super().forward(
            input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, use_cache=use_cache, position_ids=position_ids,
            output_attentions=output_attentions, output_hidden_states=output_hidden_states,
            return_dict=return_dict, cache_position=cache_position
        )

        # return_dict defaults to None (falsy) -> tuple; callers that want the
        # dataclass must pass return_dict=True explicitly.
        return DeepQwenOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        ) if return_dict else outputs.to_tuple()
| |
|
| |
|
class DeepQwenVLForCausalLM(DeepQwenVLModel, GenerationMixin):
    """
    DeepQwenVL Model for causal language modeling with vision capabilities.

    Combines DeepSeek's vision encoders (SAM + CLIP) with Qwen2VL's text model,
    adding a (tied) LM head, generation plumbing, checkpoint-loading helpers for
    the DeepSeek vision stack, and an end-to-end `infer` convenience method for
    OCR-style document processing.
    """
    config_class = DeepQwenVLConfig
    _tied_weights_keys = ["lm_head.weight"]

    _keys_to_ignore_on_load_missing = []

    def __init__(self, config):
        """
        Initialize the model.

        Args:
            config: Can be DeepQwenVLConfig, Qwen2VLTextConfig, or a generic config
                from a Qwen2-VL checkpoint.
        """
        super().__init__(config)

        # Defensive fallbacks for checkpoint configs missing these fields.
        hidden_size = getattr(config, 'hidden_size', 1536)
        vocab_size = getattr(config, 'vocab_size', 151936)

        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # HF post-init: weight init hooks and embedding/LM-head tying.
        self.post_init()

    def get_output_embeddings(self):
        """Return the LM head (None if not yet constructed)."""
        return getattr(self, 'lm_head', None)

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        labels: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        images_seq_mask: Optional[torch.FloatTensor] = None,
        images_spatial_crop: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> "DeepQwenCausalLMOutputWithPast":
        """
        Run the multimodal backbone, project hidden states to vocabulary logits,
        and optionally compute the causal-LM loss when ``labels`` is given.
        """
        # Force a dict from the backbone so attribute access below is safe.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            images=images,
            images_seq_mask=images_seq_mask,
            images_spatial_crop=images_spatial_crop,
            return_dict=True,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Upcast for numerically stable loss / sampling.
        logits = logits.float()

        loss = None
        if labels is not None:
            # HF-provided causal-LM loss (handles the label shift internally).
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size)

        return DeepQwenCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        images=None,
        images_seq_mask=None,
        images_spatial_crop=None,
        **kwargs,
    ):
        """Extend HF's generation inputs with the vision tensors."""
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        model_inputs["images"] = images
        model_inputs["images_seq_mask"] = images_seq_mask
        model_inputs["images_spatial_crop"] = images_spatial_crop
        # Let the text model derive positions itself on every step.
        model_inputs["position_ids"] = None

        # After the prefill step the vision tokens are already in the KV cache,
        # so drop the image inputs for all subsequent decode steps.
        if cache_position is not None and cache_position[0] != 0:
            model_inputs["images"] = None
            model_inputs["images_seq_mask"] = None
            model_inputs["images_spatial_crop"] = None

        return model_inputs

    def reinitialize_projector(self, vis_mlp=None, device=None, dtype=None):
        """
        Reinitialize the projector, image_newline, and view_separator.
        Call this after from_pretrained when loading from a Qwen checkpoint.

        Args:
            vis_mlp: If not None, build the 2-layer VisionProjector; otherwise a
                single Linear projector is used.
            device: Target device; inferred from existing parameters when None.
            dtype: Target dtype; defaults to bfloat16.
        """
        if device is None:
            # Use the first materialized (non-meta) parameter's device.
            for param in self.parameters():
                if param.device.type != 'meta':
                    device = param.device
                    break
            if device is None:
                device = 'cpu'
        if dtype is None:
            dtype = torch.bfloat16

        input_dim = self.deepseek_vision_dim
        output_dim = self.output_hidden_size

        if vis_mlp is not None:
            self.projector = VisionProjector(input_dim=input_dim, output_dim=output_dim).to(device=device, dtype=dtype)
        else:
            self.projector = nn.Linear(in_features=input_dim, out_features=output_dim).to(device=device, dtype=dtype)
            nn.init.normal_(self.projector.weight, mean=0.0, std=0.01)
            if self.projector.bias is not None:
                nn.init.zeros_(self.projector.bias)

        # Separators scaled like token embeddings (std = 1/sqrt(dim)).
        embed_std = 1 / torch.sqrt(torch.tensor(output_dim, dtype=torch.float32))
        self.image_newline = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std.item()
        )
        self.view_separator = nn.Parameter(
            torch.randn(output_dim, device=device, dtype=dtype) * embed_std.item()
        )

        print(f"Projector reinitialized on {device} with dtype {dtype}")

    def load_pretrained_vision(self, pretrained_path: str):
        """
        Load SAM and CLIP weights from a single-shard safetensors checkpoint.

        Raises:
            ImportError: if safetensors is not installed.
            AssertionError: if the checkpoint directory does not exist.
        """
        try:
            from safetensors import safe_open
        except ImportError:
            raise ImportError("Please install safetensors to load the pretrained vision model.")

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        # NOTE(review): shard filename is hard-coded; multi-shard checkpoints
        # will not load through this helper.
        vision_weights = {}
        with safe_open(f"{pretrained_path}/model-00001-of-000001.safetensors", framework="pt", device="cpu") as f:
            for k in f.keys():
                vision_weights[k] = f.get_tensor(k)

        prefixes = {
            "sam_model": "model.sam_model.",
            "vision_model": "model.vision_model.",
        }

        try:
            for p in prefixes.keys():
                # Re-key checkpoint tensors to the submodule's namespace.
                state_dict = {}
                for k, v in vision_weights.items():
                    if k.startswith(prefixes[p]):
                        new_key = k[len(prefixes[p]):]
                        state_dict[new_key] = v

                getattr(self, p).load_state_dict(state_dict, strict=False)

            print("Pretrained vision model loaded successfully.")
        except Exception as e:
            print("Error loading pretrained vision model:", e)
            raise e

    def load_deepseek_projector(self, pretrained_path: str):
        """
        Load DeepSeek's projector weights into the deepseek_proj layer.

        DeepSeek checkpoint has:
        - projector.weight: shape (1280, 2048)
        - projector.bias: shape (1280,)

        These get loaded into self.projector.deepseek_proj
        """
        try:
            from safetensors import safe_open
        except ImportError:
            raise ImportError("Please install safetensors to load DeepSeek projector.")

        assert os.path.exists(pretrained_path), f"Pretrained path {pretrained_path} does not exist."

        safetensor_files = [f for f in os.listdir(pretrained_path) if f.endswith('.safetensors')]
        if not safetensor_files:
            raise FileNotFoundError(f"No safetensors files found in {pretrained_path}")

        # NOTE(review): only the first shard is scanned; assumes the projector
        # tensors live there.
        safetensor_path = os.path.join(pretrained_path, safetensor_files[0])

        projector_weights = {}
        with safe_open(safetensor_path, framework="pt", device="cpu") as f:
            for k in f.keys():
                if 'projector' in k:
                    projector_weights[k] = f.get_tensor(k)

        # Checkpoints may or may not carry the "model." prefix.
        if 'projector.weight' in projector_weights:
            self.projector.deepseek_proj.weight.data = projector_weights['projector.weight']
            self.projector.deepseek_proj.bias.data = projector_weights['projector.bias']
            print(f"Loaded DeepSeek projector weights: {self.projector.deepseek_proj.weight.shape}")
            print(f"  Weight mean: {self.projector.deepseek_proj.weight.mean().item():.6f}")
            print(f"  Weight std: {self.projector.deepseek_proj.weight.std().item():.6f}")
        elif 'model.projector.weight' in projector_weights:
            self.projector.deepseek_proj.weight.data = projector_weights['model.projector.weight']
            self.projector.deepseek_proj.bias.data = projector_weights['model.projector.bias']
            print(f"Loaded DeepSeek projector weights (model. prefix)")
        else:
            print(f"Warning: Could not find projector weights. Available keys: {list(projector_weights.keys())}")

    def disable_torch_init(self):
        """
        Disable the redundant torch default initialization to accelerate model creation.

        WARNING: monkeypatches nn.Linear / nn.LayerNorm globally for the whole
        process, not just this model.
        """
        import torch
        setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
        setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

    def infer(
        self,
        tokenizer,
        prompt='',
        image_file='',
        output_path='',
        base_size=1024,
        image_size=640,
        crop_mode=True,
        test_compress=False,
        save_results=False,
        eval_mode=False
    ):
        """
        End-to-end single-image inference.

        Builds a chat-formatted prompt around one image, tokenizes it with
        interleaved image-pad tokens, runs generation on CUDA, and optionally:
        - returns the decoded text (eval_mode=True),
        - prints a vision-token compression report (test_compress=True),
        - writes result.mmd / annotated images to output_path (save_results=True).

        Args:
            tokenizer: HF tokenizer with a chat template.
            prompt: User text prompt.
            image_file: Path/URL of the input image.
            output_path: Directory for saved artifacts (created if missing).
            base_size: Side of the padded global view.
            image_size: Side of each local crop tile.
            crop_mode: Enable dynamic tiling for large images.
        """
        self.disable_torch_init()

        os.makedirs(output_path, exist_ok=True)
        os.makedirs(f'{output_path}/images', exist_ok=True)
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": f"{image_file}",
                    },
                    {"type": "text", "text": f"{prompt}"},
                ],
            }
        ]

        formatted_prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)

        patch_size = 16
        downsample_ratio = 4
        images = load_pil_images(conversation)

        valid_img_tokens = 0

        image_draw = images[0].copy()

        w, h = image_draw.size
        # Aspect penalty: 1.0 for square images, shrinking toward 0 as the
        # aspect ratio becomes more extreme.
        ratio = 1 - ((max(w, h) - min(w, h)) / (max(w, h)))

        image_transform = BasicImageTransform(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), normalize=True)

        image_token = '<|image_pad|>'
        image_token_id = 151655  # Qwen2-VL <|image_pad|>; must match config.image_token_id
        text_splits = formatted_prompt.split(image_token)

        images_list, images_crop_list, images_seq_mask = [], [], []
        tokenized_str = []
        images_spatial_crop = []
        for text_sep, image in zip(text_splits, images):
            # Text preceding this image.
            tokenized_sep = text_encode(tokenizer, text_sep, bos=False, eos=False)
            tokenized_str += tokenized_sep
            images_seq_mask += [False] * len(tokenized_sep)

            if crop_mode:
                # Small images need no tiling; large ones get a dynamic tile grid.
                if image.size[0] <= 640 and image.size[1] <= 640:
                    crop_ratio = [1, 1]
                else:
                    images_crop_raw, crop_ratio = dynamic_preprocess(image)

                # Padded global view at base resolution.
                global_view = ImageOps.pad(image, (base_size, base_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)

                images_list.append(image_transform(global_view).to(torch.bfloat16))

                width_crop_num, height_crop_num = crop_ratio
                images_spatial_crop.append([width_crop_num, height_crop_num])

                if width_crop_num > 1 or height_crop_num > 1:
                    # Process the local (tiled) views.
                    for i in range(len(images_crop_raw)):
                        images_crop_list.append(image_transform(images_crop_raw[i]).to(torch.bfloat16))

                    if image_size == 640:
                        valid_img_tokens += len(images_crop_list) * 100

                # Vision tokens per row after the encoder's 4x downsample.
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                num_queries_base = math.ceil((base_size // patch_size) / downsample_ratio)

                # Add image-pad placeholders: one extra token per row (newline)
                # plus a trailing view separator; same layout the backbone
                # produces when it splices projected vision features.
                tokenized_image = ([image_token_id] * num_queries_base + [image_token_id]) * num_queries_base
                tokenized_image += [image_token_id]
                if width_crop_num > 1 or height_crop_num > 1:
                    tokenized_image += ([image_token_id] * (num_queries * width_crop_num) + [image_token_id]) * (
                        num_queries * height_crop_num)
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

            else:
                # Global view only (no tiling).
                if image_size <= 640:
                    image = image.resize((image_size, image_size))
                global_view = ImageOps.pad(image, (image_size, image_size),
                                           color=tuple(int(x * 255) for x in image_transform.mean))
                images_list.append(image_transform(global_view).to(torch.bfloat16))

                if base_size == 1024:
                    valid_img_tokens += int(256 * ratio)
                elif base_size == 1280:
                    valid_img_tokens += int(400 * ratio)
                elif base_size == 640:
                    valid_img_tokens += int(100 * 1)
                elif base_size == 512:
                    valid_img_tokens += int(64 * 1)

                width_crop_num, height_crop_num = 1, 1
                images_spatial_crop.append([width_crop_num, height_crop_num])

                # Add image-pad placeholders for the single global view.
                num_queries = math.ceil((image_size // patch_size) / downsample_ratio)
                tokenized_image = ([image_token_id] * num_queries + [image_token_id]) * num_queries
                tokenized_image += [image_token_id]
                tokenized_str += tokenized_image
                images_seq_mask += [True] * len(tokenized_image)

        # Process the text that follows the last image.
        tokenized_sep = text_encode(tokenizer, text_splits[-1], bos=False, eos=False)
        tokenized_str += tokenized_sep
        images_seq_mask += [False] * len(tokenized_sep)

        input_ids = torch.LongTensor(tokenized_str)
        images_seq_mask = torch.tensor(images_seq_mask, dtype=torch.bool)

        if len(images_list) == 0:
            # Text-only sentinel: zero tensors make the backbone skip vision.
            images_ori = torch.zeros((1, 3, image_size, image_size))
            images_spatial_crop = torch.zeros((1, 2), dtype=torch.long)
            images_crop = torch.zeros((1, 3, base_size, base_size))
        else:
            images_ori = torch.stack(images_list, dim=0)
            images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
            if images_crop_list:
                images_crop = torch.stack(images_crop_list, dim=0)
            else:
                images_crop = torch.zeros((1, 3, base_size, base_size))

        # Stream tokens to stdout in interactive mode only.
        streamer = None if eval_mode else NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
        with torch.autocast("cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                output_ids = self.generate(
                    input_ids.unsqueeze(0).cuda(),
                    images=[(images_crop.cuda(), images_ori.cuda())],
                    images_seq_mask=images_seq_mask.unsqueeze(0).cuda(),
                    images_spatial_crop=images_spatial_crop,
                    temperature=0.5,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer,
                    max_new_tokens=8192,
                    # Eval mode tolerates longer repeats before blocking n-grams.
                    no_repeat_ngram_size=35 if eval_mode else 20,
                    use_cache=True
                )

        # output_ids includes the prompt; new tokens start here.
        num_prompt_tokens = input_ids.shape[0]

        has_image = any(
            (isinstance(item, dict) and item.get('type') == 'image')
            for msg in conversation
            for item in (msg.get('content', []) if isinstance(msg.get('content'), list) else [])
        )

        if has_image and eval_mode:
            outputs = tokenizer.decode(output_ids[0, num_prompt_tokens:], skip_special_tokens=False)
            stop_str = tokenizer.eos_token or '<|im_end|>'
            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()
            return outputs

        if has_image and test_compress:
            outputs = tokenizer.decode(output_ids[0, num_prompt_tokens:], skip_special_tokens=False)
            pure_texts_outputs_token_length = len(text_encode(tokenizer, outputs, bos=False, eos=False))
            print('='*50)
            print('image size: ', (w, h))
            print('valid image tokens: ', int(valid_img_tokens))
            print('output texts tokens (valid): ', pure_texts_outputs_token_length)
            print('compression ratio: ', round(pure_texts_outputs_token_length/valid_img_tokens, 2))
            print('='*50)

        if has_image and save_results:
            outputs = tokenizer.decode(output_ids[0, num_prompt_tokens:], skip_special_tokens=False)
            stop_str = tokenizer.eos_token or '<|im_end|>'

            print('='*15 + 'save results:' + '='*15)

            if outputs.endswith(stop_str):
                outputs = outputs[:-len(stop_str)]
            outputs = outputs.strip()

            matches_ref, matches_images, matches_other = re_match(outputs)
            result = process_image_with_refs(image_draw, matches_ref, output_path)

            for idx, a_match_image in enumerate(tqdm(matches_images, desc="image")):
                # Rewrite each matched image ref as a markdown link to the
                # cropped image saved under output_path/images.
                # BUGFIX: this replacement string was corrupted in the source.
                outputs = outputs.replace(a_match_image, '![](images/' + str(idx) + '.jpg)\n')

            for idx, a_match_other in enumerate(tqdm(matches_other, desc="other")):
                outputs = outputs.replace(a_match_other, '').replace('\\coloneqq', ':=').replace('\\eqqcolon', '=:')

            with open(f'{output_path}/result.mmd', 'w', encoding='utf-8') as afile:
                afile.write(outputs)

            if 'line_type' in outputs:
                import matplotlib.pyplot as plt
                # SECURITY: eval() on model-generated text can execute arbitrary
                # code; consider ast.literal_eval for this dict payload.
                lines = eval(outputs)['Line']['line']

                line_type = eval(outputs)['Line']['line_type']
                endpoints = eval(outputs)['Line']['line_endpoint']

                fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
                ax.set_xlim(-15, 15)
                ax.set_ylim(-15, 15)

                for idx, line in enumerate(lines):
                    try:
                        # Each line is "(x0, y0) -- (x1, y1)".
                        p0 = eval(line.split(' -- ')[0])
                        p1 = eval(line.split(' -- ')[-1])

                        if line_type[idx] == '--':
                            # BUGFIX: dashed line type previously drew solid.
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k', linestyle='--')
                        else:
                            ax.plot([p0[0], p1[0]], [p0[1], p1[1]], linewidth=0.8, color='k')

                        ax.scatter(p0[0], p0[1], s=5, color='k')
                        ax.scatter(p1[0], p1[1], s=5, color='k')
                    except Exception:
                        # Skip malformed line specs rather than aborting the render.
                        pass

                for endpoint in endpoints:
                    # Each endpoint is "LABEL: (x, y)".
                    label = endpoint.split(': ')[0]
                    (x, y) = eval(endpoint.split(': ')[1])
                    ax.annotate(label, (x, y), xytext=(1, 1), textcoords='offset points',
                                fontsize=5, fontweight='light')

                plt.savefig(f'{output_path}/geo.jpg')
                plt.close()

            result.save(f"{output_path}/result_with_boxes.jpg")
| |
|
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|