# tiny-random-minicpmv-2_6 / modeling_minicpmv.py
# Last change by katuni4ka: "Update modeling_minicpmv.py" (commit 80ae08b, verified).
import json
import math
from copy import deepcopy
from threading import Thread
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2ForCausalLM, Qwen2PreTrainedModel, TextIteratorStreamer
from .configuration_minicpm import MiniCPMVConfig
from .modeling_navit_siglip import SiglipVisionTransformer
from .resampler import Resampler
class MiniCPMVPreTrainedModel(Qwen2PreTrainedModel):
    """Qwen2 pretrained-model base class bound to ``MiniCPMVConfig``."""

    # Configuration class associated with this model family.
    config_class = MiniCPMVConfig
class MiniCPMV(MiniCPMVPreTrainedModel):
    """MiniCPM-V: a SigLIP-style vision encoder plus resampler feeding a Qwen2 causal LM.

    Image features from ``self.vpm`` are passed through ``self.resampler`` (built with
    ``num_queries=config.query_num``) and scattered into the LLM's token embeddings at
    the positions listed in ``image_bound``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.llm = Qwen2ForCausalLM(config)
        self.vpm = self.init_vision_module()
        self.vision_dim = self.vpm.embed_dim
        self.embed_dim = self.llm.config.hidden_size
        self.resampler = self.init_resampler(self.embed_dim, self.vision_dim)
        # Loaded lazily by ``chat`` when the caller does not pass a processor.
        self.processor = None
        # Tokens that end an assistant turn; converted to ids at decode time.
        self.terminators = ["<|im_end|>", "<|endoftext|>"]

    def init_vision_module(self):
        """Build the vision encoder from ``config.vision_config``.

        Only flash-attention-2 or eager attention is configured for the vision tower;
        any other requested implementation falls back to eager.
        """
        # same as HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit add tgt_sizes
        if self.config._attn_implementation == "flash_attention_2":
            self.config.vision_config._attn_implementation = "flash_attention_2"
        else:
            # SDPA is not supported by the vision tower, so fall back to eager attention.
            self.config.vision_config._attn_implementation = "eager"
        model = SiglipVisionTransformer(self.config.vision_config)
        if self.config.drop_vision_last_layer:
            model.encoder.layers = model.encoder.layers[:-1]

        # Expose these directly on the encoder so callers need not reach into ``embeddings``.
        setattr(model, "embed_dim", model.embeddings.embed_dim)
        setattr(model, "patch_size", model.embeddings.patch_size)
        return model

    def init_resampler(self, embed_dim, vision_dim):
        """Create the resampler mapping ``vision_dim`` features to the LLM width ``embed_dim``."""
        return Resampler(
            num_queries=self.config.query_num,
            embed_dim=embed_dim,
            num_heads=embed_dim // 128,  # one attention head per 128 channels of embed_dim
            kv_dim=vision_dim,
            adaptive=True,
        )

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value):
        # Delegate to the LLM so the new table is stored where ``get_input_embeddings``
        # and ``get_vllm_embedding`` (``self.llm.model.embed_tokens``) read it; assigning
        # ``self.llm.embed_tokens`` would leave the real embedding layer unchanged.
        self.llm.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.llm.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.llm.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.llm = decoder

    def get_decoder(self):
        return self.llm

    def _compute_vision_hidden_states(self, pixel_values_list, tgt_sizes):
        """Encode every image in the batch and regroup the resampled features per sample.

        Args:
            pixel_values_list: per-sample sequences of image tensors (possibly empty).
            tgt_sizes: per-sample entries; tensor entries are stacked into an
                ``(num_images, 2)`` table whose row product is that image's patch count.

        Returns:
            A list with one entry per sample: a tensor of that sample's image features,
            or ``[]`` for a sample without images. When no sample has images, every
            entry is a dummy feature during training and ``[]`` otherwise.
        """
        dtype = self.llm.model.embed_tokens.weight.dtype
        device = self.llm.model.embed_tokens.weight.device

        # Flatten all images of all samples into one list. Each tensor gets its first two
        # dims merged and is transposed so ``pad_sequence`` pads along its last axis.
        all_pixel_values = [
            img.flatten(end_dim=1).permute(1, 0) for pixel_values in pixel_values_list for img in pixel_values
        ]

        if not all_pixel_values:
            # No image anywhere in the batch.
            if self.training:
                # Push a dummy image through the vision stack so its parameters stay in the
                # graph; ``get_vllm_embedding`` adds it to the embeddings with weight 0.
                dummy_image = torch.zeros((1, 3, 224, 224), device=device, dtype=dtype)
                dummy_tgt_sizes = torch.Tensor(
                    [[(224 // self.config.patch_size), math.ceil(224 / self.config.patch_size)]]
                ).type(torch.int32)
                dummy_feature = self.resampler(self.vpm(dummy_image).last_hidden_state, dummy_tgt_sizes)
            else:
                dummy_feature = []
            return [dummy_feature for _ in pixel_values_list]

        tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
        tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])

        all_pixel_values = torch.nn.utils.rnn.pad_sequence(all_pixel_values, batch_first=True, padding_value=0.0)
        B, L, _ = all_pixel_values.shape
        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)

        # True for each image's real patches, False for the padding added above.
        patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
        for i in range(B):
            patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True

        vision_batch_size = self.config.vision_batch_size
        all_pixel_values = all_pixel_values.type(dtype)
        if B > vision_batch_size:
            # Encode in chunks of at most ``vision_batch_size`` images (presumably to cap
            # peak memory).
            hs = []
            for start_idx in range(0, B, vision_batch_size):
                end_idx = start_idx + vision_batch_size
                tmp_hs = self.vpm(
                    all_pixel_values[start_idx:end_idx],
                    patch_attention_mask=patch_attn_mask[start_idx:end_idx],
                    tgt_sizes=tgt_sizes[start_idx:end_idx],
                ).last_hidden_state
                hs.append(tmp_hs)
            vision_embedding = torch.cat(hs, dim=0)
        else:
            vision_embedding = self.vpm(
                all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
            ).last_hidden_state
        vision_embedding = self.resampler(vision_embedding, tgt_sizes)

        # Split the flat batch of image features back into per-sample groups.
        vision_hidden_states = []
        start = 0
        for pixel_values in pixel_values_list:
            img_cnt = len(pixel_values)
            if img_cnt > 0:
                vision_hidden_states.append(vision_embedding[start : start + img_cnt])
                start += img_cnt
            else:
                vision_hidden_states.append([])
        return vision_hidden_states

    def get_vllm_embedding(self, data):
        """Compute LLM input embeddings with image features written into them.

        Args:
            data: dict with ``input_ids`` and ``image_bound`` plus either
                ``vision_hidden_states`` (precomputed per-sample features) or
                ``pixel_values`` and ``tgt_sizes`` (encoded here).

        Returns:
            ``(vllm_embedding, vision_hidden_states)``. ``vllm_embedding`` holds the token
            embeddings, multiplied by ``llm.config.scale_emb`` when that attribute exists,
            with each sample's image features scattered over the ``image_bound`` ranges.
            ``vision_hidden_states`` holds the per-sample features that were used.
        """
        if "vision_hidden_states" not in data:
            vision_hidden_states = self._compute_vision_hidden_states(data["pixel_values"], data["tgt_sizes"])
        else:
            vision_hidden_states = data["vision_hidden_states"]

        if hasattr(self.llm.config, "scale_emb"):
            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"]) * self.llm.config.scale_emb
        else:
            vllm_embedding = self.llm.model.embed_tokens(data["input_ids"])

        vision_hidden_states = [
            i.type(vllm_embedding.dtype) if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
        ]

        bs = len(data["input_ids"])
        for i in range(bs):
            cur_vs_hs = vision_hidden_states[i]
            if len(cur_vs_hs) > 0:
                cur_vllm_emb = vllm_embedding[i]
                cur_image_bound = data["image_bound"][i]
                if len(cur_image_bound) > 0:
                    image_indices = torch.stack(
                        [torch.arange(r[0], r[1], dtype=torch.long) for r in cur_image_bound]
                    ).to(vllm_embedding.device)
                    # In-place: overwrite the embedding rows at the image positions.
                    cur_vllm_emb.scatter_(
                        0,
                        image_indices.view(-1, 1).repeat(1, cur_vllm_emb.shape[-1]),
                        cur_vs_hs.view(-1, cur_vs_hs.shape[-1]),
                    )
                elif self.training:
                    # Zero-weighted term keeps the vision features connected to the loss
                    # when the sample has no image slots.
                    cur_vllm_emb += cur_vs_hs[0].mean() * 0
        return vllm_embedding, vision_hidden_states

    def forward(self, data, **kwargs):
        """Run the LLM on ``data`` with image features merged into its input embeddings."""
        vllm_embedding, _ = self.get_vllm_embedding(data)
        position_ids = data["position_ids"]
        if position_ids.dtype != torch.int64:
            position_ids = position_ids.long()
        return self.llm(input_ids=None, position_ids=position_ids, inputs_embeds=vllm_embedding, **kwargs)

    def _decode(self, inputs_embeds, tokenizer, attention_mask, decode_text=False, **kwargs):
        """Generate synchronously; return token ids, or decoded strings if ``decode_text``."""
        terminators = None
        if tokenizer is not None:
            terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        # ``image_sizes`` may come from the processor output; it is not forwarded to ``generate``.
        kwargs.pop("image_sizes", None)
        output = self.llm.generate(
            inputs_embeds=inputs_embeds,
            eos_token_id=terminators,
            attention_mask=attention_mask,
            **kwargs,
        )
        if decode_text:
            return self._decode_text(output, tokenizer)
        return output

    def _decode_stream(self, inputs_embeds, tokenizer, attention_mask=None, **kwargs):
        """Start generation on a background thread and return a streamer over its text.

        ``attention_mask`` is forwarded when given, so padded batches stream correctly.
        """
        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        streamer = TextIteratorStreamer(tokenizer=tokenizer)
        # Keep the streaming path consistent with ``_decode``.
        kwargs.pop("image_sizes", None)
        generation_kwargs = {
            "inputs_embeds": inputs_embeds,
            "pad_token_id": 0,
            "eos_token_id": terminators,
            "streamer": streamer,
        }
        if attention_mask is not None:
            generation_kwargs["attention_mask"] = attention_mask
        generation_kwargs.update(kwargs)
        thread = Thread(target=self.llm.generate, kwargs=generation_kwargs)
        thread.start()
        return streamer

    def _decode_text(self, result_ids, tokenizer):
        """Decode generated id sequences into stripped strings.

        Token id 0 is removed, then a leading BOS and one trailing terminator are dropped
        when present. Empty sequences decode to ``""``.
        """
        terminators = [tokenizer.convert_tokens_to_ids(i) for i in self.terminators]
        # Prefer ``tokenizer.bos_id``; fall back to the standard ``bos_token_id``.
        bos_id = getattr(tokenizer, "bos_id", None)
        if bos_id is None:
            bos_id = getattr(tokenizer, "bos_token_id", None)
        result_text = []
        for result in result_ids:
            # NOTE(review): id 0 is treated as padding here; confirm it is never a real
            # token for the tokenizer in use.
            result = result[result != 0]
            if len(result) > 0 and bos_id is not None and result[0] == bos_id:
                result = result[1:]
            if len(result) > 0 and result[-1] in terminators:
                result = result[:-1]
            result_text.append(tokenizer.decode(result).strip())
        return result_text

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        tgt_sizes=None,
        image_bound=None,
        attention_mask=None,
        tokenizer=None,
        vision_hidden_states=None,
        return_vision_hidden_states=False,
        stream=False,
        decode_text=False,
        **kwargs,
    ):
        """Embed the inputs (with image features) and generate.

        Either ``pixel_values`` + ``tgt_sizes`` or precomputed ``vision_hidden_states``
        must be supplied. Returns a streamer when ``stream`` is True, otherwise the ids
        (or decoded strings when ``decode_text``). If ``return_vision_hidden_states``,
        returns ``(result, vision_hidden_states)``.
        """
        assert input_ids is not None
        # ``pixel_values`` is optional when precomputed ``vision_hidden_states`` are given.
        if pixel_values is not None or vision_hidden_states is None:
            assert len(input_ids) == len(pixel_values)
        model_inputs = {
            "input_ids": input_ids,
            "image_bound": image_bound,
        }
        if vision_hidden_states is None:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["tgt_sizes"] = tgt_sizes
        else:
            model_inputs["vision_hidden_states"] = vision_hidden_states

        with torch.inference_mode():
            (
                model_inputs["inputs_embeds"],
                vision_hidden_states,
            ) = self.get_vllm_embedding(model_inputs)
            if stream:
                result = self._decode_stream(
                    model_inputs["inputs_embeds"], tokenizer, attention_mask=attention_mask, **kwargs
                )
            else:
                result = self._decode(
                    model_inputs["inputs_embeds"], tokenizer, attention_mask, decode_text=decode_text, **kwargs
                )

        if return_vision_hidden_states:
            return result, vision_hidden_states
        return result

    def chat(
        self,
        image,
        msgs,
        tokenizer,
        processor=None,
        vision_hidden_states=None,
        max_new_tokens=2048,
        min_new_tokens=0,
        sampling=True,
        max_inp_length=8192,
        system_prompt="",
        stream=False,
        max_slice_nums=None,
        use_image_id=None,
        **kwargs,
    ):
        """Run one conversation (or a batch) and return the assistant reply.

        Args:
            image: an image prepended to the first message's string content, or ``None``.
                Must be ``None`` for batched input (put images inside the messages).
            msgs: a list of ``{"role", "content"}`` dicts (or a JSON string of one), or a
                list of such lists for batch inference. ``content`` is a string or a list
                of strings and ``PIL.Image.Image`` objects.
            tokenizer: used to resolve terminators and decode the output.
            processor: processor to use; loaded via ``AutoProcessor`` if omitted.
            vision_hidden_states: optional precomputed image features for ``generate``.
            max_new_tokens, min_new_tokens: generation length limits.
            sampling: sample (top-p/top-k) when True, otherwise 3-beam search.
            max_inp_length: forwarded to the processor as ``max_length``.
            system_prompt: optional system message prepended to each conversation.
            stream: return a generator of text chunks instead of final string(s).
            max_slice_nums, use_image_id: forwarded to the processor.
            **kwargs: override keys already in the default generation config; any other
                keys are ignored.

        Returns:
            A string for a single conversation, a list of strings for a batch, or a text
            generator when ``stream`` is True.
        """
        batched = isinstance(msgs[0], list)
        msgs_list = msgs
        images_list = image

        if not batched:
            images_list, msgs_list = [images_list], [msgs_list]
        else:
            assert images_list is None, "Please integrate image to msgs when using batch inference."
            images_list = [None] * len(msgs_list)
        assert len(images_list) == len(msgs_list), "The batch dim of images_list and msgs_list should be the same."

        if processor is None:
            if self.processor is None:
                self.processor = AutoProcessor.from_pretrained(self.config._name_or_path, trust_remote_code=True)
            processor = self.processor

        assert (
            self.config.query_num == processor.image_processor.image_feature_size
        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert (
            self.config.patch_size == processor.image_processor.patch_size
        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert (
            self.config.use_image_id == processor.image_processor.use_image_id
        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert (
            self.config.slice_config.max_slice_nums == processor.image_processor.max_slice_nums
        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."
        assert (
            self.config.slice_mode == processor.image_processor.slice_mode
        ), "These two values should be the same. Check `config.json` and `preprocessor_config.json`."

        prompts_lists = []
        input_images_lists = []
        for image, msgs in zip(images_list, msgs_list):
            if isinstance(msgs, str):
                msgs = json.loads(msgs)
            copy_msgs = deepcopy(msgs)

            assert len(msgs) > 0, "msgs is empty"
            assert sampling or not stream, "if use stream mode, make sure sampling=True"

            if image is not None and isinstance(copy_msgs[0]["content"], str):
                copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]]

            images = []
            for i, msg in enumerate(copy_msgs):
                role = msg["role"]
                content = msg["content"]
                assert role in ["user", "assistant"]
                if i == 0:
                    assert role == "user", "The role of first msg should be user"
                if isinstance(content, str):
                    content = [content]
                cur_msgs = []
                for c in content:
                    if isinstance(c, Image.Image):
                        # Images are collected separately and replaced by a placeholder.
                        images.append(c)
                        cur_msgs.append("(<image>./</image>)")
                    elif isinstance(c, str):
                        cur_msgs.append(c)
                msg["content"] = "\n".join(cur_msgs)

            if system_prompt:
                sys_msg = {"role": "system", "content": system_prompt}
                copy_msgs = [sys_msg] + copy_msgs

            prompts_lists.append(
                processor.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True)
            )
            input_images_lists.append(images)

        inputs = processor(
            prompts_lists,
            input_images_lists,
            max_slice_nums=max_slice_nums,
            use_image_id=use_image_id,
            return_tensors="pt",
            max_length=max_inp_length,
        ).to(self.device)

        if sampling:
            generation_config = {
                "top_p": 0.8,
                "top_k": 100,
                "temperature": 0.7,
                "do_sample": True,
                "repetition_penalty": 1.05,
            }
        else:
            generation_config = {
                "num_beams": 3,
                "repetition_penalty": 1.2,
            }

        if min_new_tokens > 0:
            generation_config["min_new_tokens"] = min_new_tokens

        # Only keys already present in the defaults can be overridden through ``kwargs``.
        generation_config.update((k, kwargs[k]) for k in generation_config.keys() & kwargs.keys())

        inputs.pop("image_sizes", None)
        with torch.inference_mode():
            res = self.generate(
                **inputs,
                tokenizer=tokenizer,
                max_new_tokens=max_new_tokens,
                vision_hidden_states=vision_hidden_states,
                stream=stream,
                decode_text=True,
                **generation_config,
            )

        if stream:

            def stream_gen():
                for text in res:
                    for term in self.terminators:
                        text = text.replace(term, "")
                    yield text

            return stream_gen()

        if batched:
            return res
        return res[0]