# POINTS-GUI-G / modeling_points_gui.py
# Uploaded by YuanLiuuuuuu ("Upload folder using huggingface_hub"), commit 534f14c (verified).
import math
from typing import List, Optional, Tuple, Union, Any
import numpy as np
import torch
from PIL import Image
from transformers import (
    GenerationMixin,
    PreTrainedModel,
    PreTrainedTokenizer
)
# Version-dependent / optional dependencies: a failed import only prints an
# install hint here, so the missing name raises NameError later when used.
# The guarded imports stay ahead of the unguarded one at the bottom so their
# hints are printed before that import can abort module loading.
try:
    from transformers import Qwen3ForCausalLM
except ImportError:
    print('Please upgrade transformers to version 4.51.0 or higher')
try:
    from transformers.models.qwen2_vl.image_processing_qwen2_vl import (  # noqa
        Qwen2VLImageProcessor,
    )
    from transformers.models.qwen2_vl.modeling_qwen2_vl import PatchMerger
except ImportError:
    print('Please upgrade transformers to version 4.46.3 or higher')
from .configuration_points_gui import POINTSGUIConfig
try:
    from wepoints.models import Qwen2VisionTransformerForNavitPOINTS
except ImportError:
    print('Please install WePOINTS, and refer to https://github.com/WePOINTS/WePOINTS')
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast
class POINTSGUIModel(PreTrainedModel, GenerationMixin):
    """Vision-language chat model for POINTS-GUI.

    A Qwen3 causal LM (``self.llm``) is paired with a Qwen2-VL style vision
    transformer (``self.vision_encoder``). Visual features are projected to
    the LLM hidden size by a ``PatchMerger`` (``self.vision_projector``) and
    written into the token embeddings wherever the ``<|image_pad|>``
    placeholder token appears.

    Args:
        config (POINTSGUIConfig): The model config.
    """

    config_class = POINTSGUIConfig
    _no_split_modules = []
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True

    # Number of vision-encoder patches represented by one LLM image token.
    # Every token count in this class assumes the projector merges 2x2
    # patches -- TODO confirm against the PatchMerger's spatial_merge_size.
    _PATCHES_PER_IMAGE_TOKEN = 4

    def __init__(self, config: POINTSGUIConfig, **kwargs) -> None:
        super().__init__(config)
        # Both towers are pinned to FlashAttention-2; turning off the autoset
        # flag presumably stops transformers from re-selecting an attention
        # implementation later -- TODO confirm.
        config.llm_config._attn_implementation = "flash_attention_2"
        config._attn_implementation_autoset = False
        self.llm = Qwen3ForCausalLM(config.llm_config)
        self.vision_encoder = Qwen2VisionTransformerForNavitPOINTS._from_config(  # noqa
            config.vision_config, attn_implementation="flash_attention_2"
        )
        # Projects 1280-wide vision features to the LLM hidden size, folding
        # _PATCHES_PER_IMAGE_TOKEN patches into each output row.
        self.vision_projector = PatchMerger(config.llm_config.hidden_size,
                                            context_dim=1280).to(torch.bfloat16)

    def process_images(self, images: torch.Tensor,
                       image_grid_thws: List[list]) -> torch.Tensor:
        """Encode images and project them into the LLM embedding space.

        Args:
            images (torch.Tensor): Flattened pixel patches of all images, as
                produced by the image processor's ``pixel_values``.
            image_grid_thws (List[list]): Per-image ``(t, h, w)`` patch grid
                sizes, forwarded to the vision encoder as ``grid_thw``.

        Returns:
            torch.Tensor: One row per LLM image token, concatenated over all
            images.
        """
        image_features = self.vision_encoder(images, grid_thw=image_grid_thws)
        return self.vision_projector(image_features)

    @staticmethod
    def _open_image(source, max_pixels: Optional[int] = None) -> Image.Image:
        """Return ``source`` as an RGB image, downscaled to <= ``max_pixels``.

        Args:
            source: A filesystem path / file object, or an already-loaded
                ``PIL.Image.Image``.
            max_pixels (int, optional): Upper bound on width * height. ``None``
                disables resizing.
        """
        if isinstance(source, Image.Image):
            image = source.convert('RGB')
        else:
            # Open in a context manager so the underlying file handle is
            # released; the converted image is an independent copy.
            with Image.open(source) as opened:
                image = opened.convert('RGB')
        if max_pixels is not None:
            width, height = image.size
            if width * height > max_pixels:
                # The same factor on both axes preserves the aspect ratio.
                beta = math.sqrt((height * width) / max_pixels)
                image = image.resize((math.floor(width / beta),
                                      math.floor(height / beta)))
        return image

    def construct_prompt(self, messages: List[dict],
                         image_processor: Qwen2VLImageProcessor) -> Tuple[str, List[Image.Image], List[list]]:  # noqa
        """Build the chat prompt and preprocess every referenced image.

        Args:
            messages (List[dict]): Chat messages. Each has a ``'role'`` and a
                ``'content'`` list whose items are either
                ``{'type': 'text', 'text': str}`` or
                ``{'type': 'image', 'image': path or PIL image,
                'max_pixels': int (optional)}``.
            image_processor (Qwen2VLImageProcessor): Turns a PIL image into
                ``pixel_values`` and ``image_grid_thw``.

        Returns:
            Tuple[str, List[Image.Image], List[list]]: The prompt, the
            flattened pixel-patch rows of all images (in order), and the
            per-image ``image_grid_thw`` arrays.
        """
        images = []
        image_grid_thws = []
        reconstructed_messages = []
        for message in messages:
            content_parts = []
            for item in message['content']:
                if item['type'] == 'text':
                    content_parts.append(item['text'])
                elif item['type'] == 'image':
                    image = self._open_image(item['image'],
                                             item.get('max_pixels'))
                    image_data = image_processor(images=image)
                    image_grid_thw = image_data['image_grid_thw']
                    images.extend(image_data['pixel_values'])
                    image_grid_thws.append(image_grid_thw)
                    # One <|image_pad|> per merged patch (h * w / merge);
                    # integer division avoids float rounding on large grids.
                    seq_len = (int(image_grid_thw[0][1] * image_grid_thw[0][2])
                               // self._PATCHES_PER_IMAGE_TOKEN)
                    content_parts.append('<|vision_start|>' + '<|image_pad|>' * seq_len + '<|vision_end|>' + '\n')  # noqa
            reconstructed_messages.append({
                'role': message['role'],
                'content': ''.join(content_parts)
            })
        prompt = self.apply_chat_template(reconstructed_messages)
        return prompt, images, image_grid_thws

    def apply_chat_template(self, messages: List[dict]) -> str:
        """Render messages in ChatML (``<|im_start|>role ... <|im_end|>``).

        If the last message comes from the user (or ``messages`` is empty),
        an opening assistant turn is appended so generation continues there.

        Args:
            messages (List[dict]): Messages with string ``'role'`` (``user``,
                ``assistant`` or ``system``) and string ``'content'``.

        Returns:
            str: The prompt.

        Raises:
            KeyError: If a message has an unknown role.
        """
        role_prefix_mapping = {
            'user': '<|im_start|>user\n',
            'assistant': '<|im_start|>assistant\n',
            'system': '<|im_start|>system\n'
        }
        role = 'user'
        prompt = ''
        for message in messages:
            role = message['role']
            prompt += role_prefix_mapping[role] + message['content'] + '<|im_end|>\n'
        if role == 'user':
            prompt += '<|im_start|>assistant\n'
        return prompt

    @torch.no_grad()
    def chat(self,
             messages: List[dict],
             tokenizer: PreTrainedTokenizer,
             image_processor: object,
             generation_config: dict = None) -> str:
        """Generate a response to the input messages.

        Args:
            messages (List[dict]): The input messages (see
                ``construct_prompt``). Text-only conversations are allowed.
            tokenizer (PreTrainedTokenizer): The tokenizer to use.
            image_processor (object): The image processor to use.
            generation_config (dict, optional): Extra keyword arguments for
                generation. The dict is copied, not mutated; its
                ``eos_token_id`` is always set to the id of ``<|im_end|>``.
                Defaults to None (no extra arguments).

        Returns:
            str: The generated response.
        """
        prompt, images, image_grid_thws = self.construct_prompt(
            messages, image_processor
        )
        device = self.vision_encoder.device
        if images:
            pixel_values = torch.from_numpy(np.array(images)).to(
                device=device, dtype=self.vision_encoder.dtype)
            # Keep the grid on the encoder's device rather than the current
            # CUDA device, matching where the pixels were placed.
            grid_thws = torch.from_numpy(
                np.concatenate(image_grid_thws, axis=0)).to(device).long()
            image_features = self.process_images(pixel_values, grid_thws)
        else:
            # Text-only conversation: no images to encode or splice in.
            grid_thws = torch.zeros((0, 3), dtype=torch.long, device=device)
            image_features = torch.zeros(
                (0, self.config.llm_config.hidden_size), device=device)
        model_inputs = tokenizer(prompt, return_tensors='pt')
        input_ids = model_inputs['input_ids'].to(self.device)
        attention_mask = model_inputs['attention_mask'].to(self.device)
        # Copy so a None default works and the caller's dict stays untouched.
        generation_config = dict(generation_config or {})
        # Stop generation at the end of the assistant turn.
        generation_config['eos_token_id'] = tokenizer.convert_tokens_to_ids("<|im_end|>")
        outputs = self.generate(
            input_ids=input_ids,
            image_grid_thws=grid_thws,
            attention_mask=attention_mask,
            image_features=[image_features],
            image_token_id=tokenizer.convert_tokens_to_ids("<|image_pad|>"),
            **generation_config
        )
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

    def _split_input_ids(self, input_ids, special_token):
        """Locate contiguous runs of ``special_token`` in a 1-D id tensor.

        Args:
            input_ids (torch.Tensor): 1-D token ids.
            special_token (int): The token id whose runs are located.

        Returns:
            List[List[int]]: ``[start, end)`` index pairs, one per run, in
            order of appearance.
        """
        mask = (input_ids == special_token).long()
        zero = mask.new_zeros(1)
        # Zero padding on both sides guarantees every run has a rising (+1)
        # and a falling (-1) edge, even when it touches a sequence border.
        edges = torch.diff(torch.cat([zero, mask, zero]))
        starts = torch.nonzero(edges == 1).flatten().tolist()
        ends = torch.nonzero(edges == -1).flatten().tolist()
        return [[start, end] for start, end in zip(starts, ends)]

    def generate(self,
                 input_ids: torch.LongTensor,
                 image_grid_thws: torch.LongTensor,
                 attention_mask: torch.LongTensor,
                 image_features: List[torch.Tensor],
                 image_token_id: int,
                 generation_config: Optional[dict] = None,
                 output_hidden_states: Optional[bool] = None,
                 **generate_kwargs) -> torch.LongTensor:
        """Generate with image features spliced into the prompt embeddings.

        Every batch row is expected to contain one run of ``image_token_id``
        per entry of ``image_grid_thws``. The j-th run is overwritten with the
        rows of ``image_features[i]`` belonging to the j-th image.

        Args:
            input_ids (torch.LongTensor): ``(batch, seq)`` prompt token ids.
            image_grid_thws (torch.LongTensor): Per-image ``(t, h, w)`` grids.
            attention_mask (torch.LongTensor): Attention mask for the prompt.
            image_features (List[torch.Tensor]): One feature tensor per batch
                row, holding all of that row's images concatenated.
            image_token_id (int): Id of the ``<|image_pad|>`` placeholder.
            generation_config (dict, optional): Forwarded to ``llm.generate``.
            output_hidden_states (bool, optional): Forwarded to
                ``llm.generate``.
            **generate_kwargs: Forwarded to ``llm.generate``.

        Returns:
            torch.LongTensor: The token ids returned by ``self.llm.generate``.
        """
        input_embeddings = self.llm.model.embed_tokens(input_ids)
        batch_size = input_ids.shape[0]
        assert len(image_features) == batch_size, (
            f'expected {batch_size} image feature tensors, '
            f'got {len(image_features)}')
        # Row offset of each image inside a concatenated feature tensor. It
        # depends only on the grids, so it is computed once for all rows.
        offsets = [0]
        for image_grid_thw in image_grid_thws:
            offsets.append(offsets[-1]
                           + int(image_grid_thw[1] * image_grid_thw[2])
                           // self._PATCHES_PER_IMAGE_TOKEN)
        for i in range(batch_size):
            runs = self._split_input_ids(input_ids[i], image_token_id)
            assert len(runs) == len(image_grid_thws), (
                f'row {i}: found {len(runs)} image placeholder runs for '
                f'{len(image_grid_thws)} images')
            for j, (start, end) in enumerate(runs):
                input_embeddings[i, start:end] = \
                    image_features[i][offsets[j]:offsets[j + 1]]
        return self.llm.generate(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs
        )

    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
        """Encode images into embeddings that can be fed to the language model.

        Args:
            pixel_values (torch.FloatTensor): Flattened pixel patches of all
                images (the image processor's ``pixel_values``).
            image_grid_thw (torch.LongTensor of shape ``(num_images, 3)``):
                The temporal, height and width patch-grid size of each image.
                Required despite the ``None`` default.

        Returns:
            Tuple[torch.Tensor, ...]: One embedding tensor per image.

        Raises:
            ValueError: If ``image_grid_thw`` is None.
        """
        if image_grid_thw is None:
            raise ValueError('image_grid_thw is required to split the image features')
        # The encoder + projector pair plays the role of Qwen2-VL's `visual`.
        pixel_values = pixel_values.type(self.vision_encoder.dtype)
        image_embeds = self.process_images(pixel_values, image_grid_thw)
        split_sizes = (image_grid_thw.prod(-1) // self._PATCHES_PER_IMAGE_TOKEN).tolist()
        return torch.split(image_embeds, split_sizes)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Optional[Any],
    ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
        """Run the LLM on token embeddings with image features merged in.

        When ``inputs_embeds`` is not given, ``input_ids`` are embedded and,
        if ``pixel_values`` is given, every ``config.image_token_id`` position
        is replaced (in order) by the projected image features.
        ``pixel_values_videos``, ``video_grid_thw`` and ``rope_deltas`` are
        accepted for interface compatibility but are not used.

        Returns:
            Qwen2VLCausalLMOutputWithPast, or its ``to_tuple()`` form when
            ``return_dict`` resolves to False.

        Raises:
            ValueError: If the number of image tokens differs from the number
                of image feature rows.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            inputs_embeds = self.llm.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                image_embeds = self.process_images(pixel_values, image_grid_thw)
                image_token_mask = input_ids == self.config.image_token_id
                n_image_tokens = image_token_mask.sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    image_token_mask
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
        # Call the module (not .forward) so registered hooks run, and honour
        # the caller's use_cache instead of forcing the KV cache on.
        outputs = self.llm(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
            position_ids=position_ids,
            past_key_values=past_key_values,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )
        result = Qwen2VLCausalLMOutputWithPast(
            loss=outputs.loss,
            logits=outputs.logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions
        )
        return result if return_dict else result.to_tuple()