# Ristretto-3B / modeling_ristretto.py
# maojialiang
# upload model
# b61b9f8
# --------------------------------------------------------
# Ristretto
# Copyright (c) 2025 LiAutoAD
# Licensed under The MIT License
# --------------------------------------------------------
import copy
from typing import Any, List, Optional, Tuple, Union
import torch.distributed as dist
import torch.utils.checkpoint
import transformers
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import (GenerationConfig, LlamaConfig,
LlamaForCausalLM, PretrainedConfig,
Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig,
SiglipVisionModel)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import logging
from .conversation import get_conv_template
from .projector import TokenAdaptiveProjector
# Alias of `LabelSmoother.ignore_index`. It is not referenced elsewhere in the
# visible part of this file; presumably it is the label value that loss
# computation skips -- NOTE(review): confirm against the data pipeline.
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
# Module-level logger obtained through the transformers logging utilities.
logger = logging.get_logger(__name__)
# Let INFO records through, e.g. the values logged by RistrettoConfig.__init__.
logger.setLevel(logging.INFO)
def version_cmp(v1, v2, op='eq'):
    """Compare two version strings using a comparison named in ``operator``.

    Args:
        v1: Left-hand version string.
        v2: Right-hand version string.
        op: Name of a function in the standard ``operator`` module
            (for example ``'eq'``, ``'ge'`` or ``'lt'``).

    Returns:
        The result of ``operator.<op>(parse(v1), parse(v2))``.
    """
    import operator

    from packaging import version

    # Resolve the operator first so an unknown name fails before any parsing.
    compare = getattr(operator, op)
    lhs = version.parse(v1)
    rhs = version.parse(v2)
    return compare(lhs, rhs)
class RistrettoConfig(PretrainedConfig):
    """Composite configuration holding a vision-encoder config and an LLM config.

    Args:
        vision_config: Mapping of keyword arguments for the vision config. Its
            ``'model_type'`` entry must be ``'siglip_vision_model'``. ``None``
            selects ``dict(model_type='siglip_vision_model')``.
        llm_config: Mapping of keyword arguments for the language-model
            config. The first entry of its ``'architectures'`` list must be
            ``'LlamaForCausalLM'`` or ``'Qwen2ForCausalLM'``. ``None`` selects
            ``dict(architectures=['Qwen2ForCausalLM'])``.
        pad2square, select_layer, force_image_size, num_image_token, template,
        dynamic_image_size, use_thumbnail, min_dynamic_patch,
        max_dynamic_patch: Stored unchanged as attributes of the same name.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If the vision model type or the LLM architecture is not
            one of the supported values listed above.
    """

    model_type = 'ristretto'
    is_composition = True

    def __init__(
            self,
            vision_config=None,
            llm_config=None,
            pad2square=False,
            select_layer=-1,
            force_image_size=None,
            num_image_token=256,
            template=None,
            dynamic_image_size=False,
            use_thumbnail=False,
            min_dynamic_patch=1,
            max_dynamic_patch=6,
            **kwargs):
        super().__init__(**kwargs)

        # Build the fallback dicts per call instead of sharing mutable default
        # arguments across every instantiation; the values are the same ones
        # the signature previously carried.
        if vision_config is None:
            vision_config = dict(model_type='siglip_vision_model')
        if llm_config is None:
            llm_config = dict(architectures=['Qwen2ForCausalLM'])

        if vision_config["model_type"] == "siglip_vision_model":
            self.vision_config = SiglipVisionConfig(**vision_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(vision_config['model_type']))

        llm_arch = llm_config['architectures'][0]
        if llm_arch == 'LlamaForCausalLM':
            self.llm_config = LlamaConfig(**llm_config)
        elif llm_arch == 'Qwen2ForCausalLM':
            self.llm_config = Qwen2Config(**llm_config)
        else:
            raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))

        self.pad2square = pad2square
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.num_image_token = num_image_token
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch

        logger.info(f'vision_select_layer: {self.select_layer}')
        logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
        logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
            with the two nested configs expanded through their own `to_dict()`.
        """
        output = copy.deepcopy(self.__dict__)
        # Nested config objects are replaced by their dictionary form.
        output['vision_config'] = self.vision_config.to_dict()
        output['llm_config'] = self.llm_config.to_dict()
        # `model_type` is a class attribute, so it is absent from `__dict__`.
        output['model_type'] = self.__class__.model_type
        output['pad2square'] = self.pad2square
        output['select_layer'] = self.select_layer
        output['force_image_size'] = self.force_image_size
        output['num_image_token'] = self.num_image_token
        output['template'] = self.template
        output['dynamic_image_size'] = self.dynamic_image_size
        output['use_thumbnail'] = self.use_thumbnail
        output['min_dynamic_patch'] = self.min_dynamic_patch
        output['max_dynamic_patch'] = self.max_dynamic_patch
        return output
class RistrettoModel(PreTrainedModel):
    """Vision-language model: a vision encoder, a projector and a causal LLM.

    Image features from the vision model are passed through a
    `TokenAdaptiveProjector` and written into the language model's input
    embeddings at every position whose token id equals
    `self.img_context_token_id`.

    `img_context_token_id` starts as `None`. `chat` and `batch_chat` set it
    from the tokenizer; callers of `forward` or `generate` must set it
    themselves beforehand.
    """

    config_class = RistrettoConfig
    main_input_name = 'pixel_values'
    _no_split_modules = ['SiglipVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
    _supports_flash_attn_2 = True
    _keys_to_ignore_on_save = []

    def __init__(self, config: RistrettoConfig, vision_model=None, language_model=None):
        """Build the sub-models, or adopt the ones passed in.

        Args:
            config: Model configuration.
            vision_model: Optional pre-built vision model; when `None` a
                `SiglipVisionModel` is created from `config.vision_config`.
            language_model: Optional pre-built language model; when `None` a
                `LlamaForCausalLM` or `Qwen2ForCausalLM` is created from
                `config.llm_config`.

        Raises:
            NotImplementedError: If the configured vision model type or LLM
                architecture is not supported.
        """
        super().__init__(config)

        assert version_cmp(transformers.__version__, '4.37.0', 'ge')
        image_size = config.force_image_size or config.vision_config.image_size
        patch_size = config.vision_config.patch_size
        self.image_size = image_size
        self.patch_size = patch_size
        self.select_layer = config.select_layer
        self.template = config.template
        self.num_image_token = config.num_image_token
        self.llm_arch_name = config.llm_config.architectures[0]
        self.vision_model_type = config.vision_config.model_type

        if vision_model is not None:
            self.vision_model = vision_model
        else:
            if config.vision_config.model_type == 'siglip_vision_model':
                self.vision_model = SiglipVisionModel(config.vision_config)
            else:
                raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')

        if language_model is not None:
            self.language_model = language_model
        else:
            if config.llm_config.architectures[0] == 'LlamaForCausalLM':
                self.language_model = LlamaForCausalLM(config.llm_config)
            elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
                self.language_model = Qwen2ForCausalLM(config.llm_config)
            else:
                raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')

        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.llm_config.hidden_size

        self.projector = TokenAdaptiveProjector(
            vit_hidden_size=vit_hidden_size,
            llm_hidden_size=llm_hidden_size,
            num_image_token=self.num_image_token,
        )

        self.img_context_token_id = None
        self.conv_template = get_conv_template(self.template)
        self.system_message = self.conv_template.system_message
        self.num_samples = 0

    def forward(
            self,
            pixel_values: torch.FloatTensor,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            image_flags: Optional[torch.LongTensor] = None,
            num_image_tokens: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the language model on text embeddings with image features spliced in.

        Args:
            pixel_values: Image batch passed to the vision model.
            input_ids: Token ids of shape (B, N). Positions equal to
                `self.img_context_token_id` receive image embeddings.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded to the
                language model unchanged.
            image_flags: Optional per-image flags; only images flagged `1`
                contribute embeddings. `None` keeps every image.
            num_image_tokens: Optional tensor that must hold one unique value;
                that value is forwarded to the projector.
            labels: Optional target ids of shape (B, N). When given, a
                token-weighted cross-entropy loss is returned.
            return_dict: Whether to return a `CausalLMOutputWithPast`;
                defaults to `self.config.use_return_dict`.

        Returns:
            A `CausalLMOutputWithPast`, or a tuple when `return_dict` is false
            (with the loss first when `labels` is given).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        num_image_token = None
        if num_image_tokens is not None:
            assert num_image_tokens.unique().shape[0] == 1, 'num_image_tokens must be the same for all samples in a batch'
            num_image_token = num_image_tokens[0].item()

        input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()

        vit_embeds = self.extract_feature(pixel_values, num_image_token)
        # `image_flags` is optional in the signature, so only filter when it
        # is supplied instead of failing on `None.squeeze`.
        if image_flags is not None:
            image_flags = image_flags.squeeze(-1)
            vit_embeds = vit_embeds[image_flags == 1]
        vit_batch_size = pixel_values.shape[0]

        B, N, C = input_embeds.shape
        input_embeds = input_embeds.reshape(B * N, C)

        if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
            print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')

        input_ids = input_ids.reshape(B * N)
        selected = (input_ids == self.img_context_token_id)
        try:
            # Written as `x * 0.0 + v` rather than a plain assignment;
            # presumably this keeps the text-embedding lookup in the autograd
            # graph -- NOTE(review): confirm before simplifying.
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
            ignore_flag = False
        except Exception as e:
            # Best-effort fallback for a placeholder/feature count mismatch:
            # use the leading features and zero the loss for this batch below.
            vit_embeds = vit_embeds.reshape(-1, C)
            print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
                  f'vit_embeds.shape={vit_embeds.shape}')
            n_token = selected.sum()
            input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
            ignore_flag = True

        input_embeds = input_embeds.reshape(B, N, C)

        outputs = self.language_model(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(reduction='none')
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            # Enable model parallelism: move the labels to the logits' device
            # before deriving the mask and weights, so every tensor in the
            # loss expression lives on the same device.
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            # Calc loss weight: each sample's tokens are weighted by
            # 1 / sqrt(number of supervised tokens in that sample).
            loss_token_mask = shift_labels != loss_fct.ignore_index
            loss_token_num = loss_token_mask.sum(dim=1, keepdim=True).float()
            loss_token_weight = 1. / (loss_token_num.expand_as(shift_labels) ** 0.5 + 1e-6)
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_token_weight = loss_token_weight.view(-1)
            loss_token_mask = loss_token_mask.view(-1)

            loss = loss_fct(shift_logits, shift_labels)

            # Collective ops require an initialised process group; without one
            # (single-process training or evaluation) normalise locally.
            is_distributed = dist.is_available() and dist.is_initialized()

            all_token_weight = (loss_token_weight * loss_token_mask.float()).sum()
            if is_distributed:
                dist.all_reduce(all_token_weight, op=dist.ReduceOp.SUM)
            loss = (loss * loss_token_weight * loss_token_mask.float()).sum() / (all_token_weight + 1e-6)
            if is_distributed:
                # Hack for DDP training, since the loss is reduced in the forward function
                loss = loss * dist.get_world_size()
            if ignore_flag:
                loss = loss * 0.0

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def extract_feature(self, pixel_values, num_image_token=None):
        """Encode images with the vision model and project them for the LLM.

        Args:
            pixel_values: Image batch passed to the vision model.
            num_image_token: Optional value forwarded to the projector.

        Returns:
            The projector output for the selected vision hidden state: the
            last hidden state when `self.select_layer == -1`, otherwise
            `hidden_states[self.select_layer]`.
        """
        if self.select_layer == -1:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=False,
                return_dict=True).last_hidden_state
        else:
            vit_embeds = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
                return_dict=True).hidden_states[self.select_layer]

        vit_embeds = self.projector(vit_embeds, num_image_token=num_image_token)
        return vit_embeds

    def batch_chat(self, tokenizer, pixel_values, questions, generation_config, num_patches_list=None,
                   history=None, return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
                   IMG_CONTEXT_TOKEN='<IMG_CONTEXT>', verbose=False, image_counts=None):
        """Answer a batch of single-turn questions.

        Args:
            tokenizer: Tokenizer used to encode prompts and decode outputs.
            pixel_values: Image batch, or `None`.
            questions: One question string per sample.
            generation_config: Mapping of keyword arguments for `generate`.
                It is copied, so the caller's mapping is not modified.
            num_patches_list: For each sample, an iterable of patch counts,
                one per `<image>` placeholder in that sample's prompt.
            history, return_history: Not supported here.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Marker strings
                used to expand each `<image>` placeholder.
            verbose: Print the image batch size.
            image_counts: Deprecated alias of `num_patches_list`.

        Returns:
            A list with one response string per question.

        Raises:
            NotImplementedError: If `history` is given or `return_history`
                is true.
        """
        if history is not None or return_history:
            print('Now multi-turn chat is not supported in batch_chat.')
            raise NotImplementedError

        if image_counts is not None:
            num_patches_list = image_counts
            print('Warning: `image_counts` is deprecated. Please use `num_patches_list` instead.')

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        queries = []
        for idx, _num_patches_list in enumerate(num_patches_list):
            question = questions[idx]
            if pixel_values is not None and '<image>' not in question:
                question = '<image>\n' + question
            template = get_conv_template(self.template)
            template.system_message = self.system_message
            template.append_message(template.roles[0], question)
            template.append_message(template.roles[1], None)
            query = template.get_prompt()

            # Expand one `<image>` placeholder per entry, left to right.
            for num_patches in _num_patches_list:
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN
                query = query.replace('<image>', image_tokens, 1)
            queries.append(query)

        tokenizer.padding_side = 'left'
        model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
        # NOTE(review): inputs are placed on the current CUDA device, so this
        # path requires a GPU.
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
        # Work on a copy so the caller's dict is not mutated.
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = eos_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config
        )
        responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
        responses = [response.split(template.sep)[0].strip() for response in responses]
        return responses

    def chat(self, tokenizer, pixel_values, question, generation_config, num_image_token=None, history=None, return_history=False,
             num_patches_list=None, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
             verbose=False):
        """Answer one question, optionally continuing a multi-turn history.

        Args:
            tokenizer: Tokenizer used to encode the prompt and decode output.
            pixel_values: Image batch, or `None` for text-only chat.
            question: The user's question. `<image>\\n` is prepended when
                there is no history, images are given and the question has no
                `<image>` placeholder.
            generation_config: Mapping of keyword arguments for `generate`.
                It is copied, so the caller's mapping is not modified.
            num_image_token: Context tokens per patch; defaults to
                `self.num_image_token`.
            history: Optional list of `(question, answer)` pairs. The new
                turn is appended to this list in place.
            return_history: Also return the updated history.
            num_patches_list: Patch count per `<image>` placeholder; defaults
                to a single entry covering all of `pixel_values`.
            IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN: Marker strings
                used to expand each `<image>` placeholder.
            verbose: Print the image batch size and the prompt/response.

        Returns:
            The response string, or `(response, history)` when
            `return_history` is true.
        """
        if history is None and pixel_values is not None and '<image>' not in question:
            question = '<image>\n' + question

        if num_patches_list is None:
            num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else []
        assert pixel_values is None or len(pixel_values) == sum(num_patches_list)

        img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_context_token_id = img_context_token_id

        template = get_conv_template(self.template)
        template.system_message = self.system_message

        history = [] if history is None else history
        for (old_question, old_answer) in history:
            template.append_message(template.roles[0], old_question)
            template.append_message(template.roles[1], old_answer)
        template.append_message(template.roles[0], question)
        template.append_message(template.roles[1], None)
        query = template.get_prompt()

        if verbose and pixel_values is not None:
            image_bs = pixel_values.shape[0]
            print(f'dynamic ViT batch size: {image_bs}')

        if num_image_token is None:
            num_image_token = self.num_image_token

        # Expand one `<image>` placeholder per entry, left to right.
        for num_patches in num_patches_list:
            image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
            query = query.replace('<image>', image_tokens, 1)

        model_inputs = tokenizer(query, return_tensors='pt')
        # NOTE(review): inputs are placed on the current CUDA device, so this
        # path requires a GPU.
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()
        # Work on a copy so the caller's dict is not mutated.
        generation_config = dict(generation_config)
        generation_config['eos_token_id'] = tokenizer.eos_token_id
        generation_config['pad_token_id'] = tokenizer.pad_token_id
        generation_output = self.generate(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_image_token=num_image_token,
            **generation_config
        )
        response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
        response = response.split(template.sep)[0].strip()
        history.append((question, response))
        if return_history:
            return response, history

        if verbose:
            # Collapse the expanded image tokens back to `<image>` for display.
            query_to_print = query.replace(IMG_CONTEXT_TOKEN, '')
            query_to_print = query_to_print.replace(f'{IMG_START_TOKEN}{IMG_END_TOKEN}', '<image>')
            print(query_to_print, response)
        return response

    @torch.no_grad()
    def generate(
            self,
            pixel_values: Optional[torch.FloatTensor] = None,
            input_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.LongTensor] = None,
            visual_features: Optional[torch.FloatTensor] = None,
            num_image_token: Optional[int] = None,
            generation_config: Optional[GenerationConfig] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate token ids from text plus optional image input.

        Args:
            pixel_values: Image batch. When `None`, generation is text-only
                and `visual_features` is not consulted.
            input_ids: Prompt token ids of shape (B, N).
            attention_mask: Attention mask forwarded to the language model.
            visual_features: Optional precomputed image embeddings used in
                place of running `extract_feature` on `pixel_values`.
            num_image_token: Optional value forwarded to `extract_feature`.
            generation_config: Forwarded to the language model's `generate`.
            output_hidden_states: Forwarded to the language model's `generate`.
            return_dict: Accepted for interface compatibility; not used.
            **generate_kwargs: Extra arguments for the language model's
                `generate`.

        Returns:
            Whatever the language model's `generate` returns.
        """
        assert self.img_context_token_id is not None
        if pixel_values is not None:
            if visual_features is not None:
                vit_embeds = visual_features
            else:
                vit_embeds = self.extract_feature(pixel_values, num_image_token)
            input_embeds = self.language_model.get_input_embeddings()(input_ids)
            B, N, C = input_embeds.shape
            input_embeds = input_embeds.reshape(B * N, C)

            input_ids = input_ids.reshape(B * N)
            selected = (input_ids == self.img_context_token_id)
            # Images were supplied, so the prompt must hold placeholders.
            assert selected.sum() != 0
            input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)

            input_embeds = input_embeds.reshape(B, N, C)
        else:
            input_embeds = self.language_model.get_input_embeddings()(input_ids)

        outputs = self.language_model.generate(
            inputs_embeds=input_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            output_hidden_states=output_hidden_states,
            use_cache=True,
            **generate_kwargs,
        )

        return outputs