import os
import math
import copy
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from PIL import Image
from functools import partial
from typing import List, Optional, Tuple, Union, Dict
from dataclasses import dataclass

import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, Qwen2Config, SiglipVisionModel

from .adapters import AdapterSigLIP
from .mm_constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
from .processing_FlashVL import tokenizer_image_token_qwen
from .configuration_FlashVLStatic import FlashVLStaticConfig


@dataclass
class FlashVLStaticOutputWithPast(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None


class FlashVLStatic(PreTrainedModel):
    config_class = FlashVLStaticConfig

    def __init__(self, config):
        super().__init__(config)
        self.llm = AutoModelForCausalLM.from_config(config.llm_config, trust_remote_code=True)
        self.vit = SiglipVisionModel(config.vision_config).vision_model
        self.adp = AdapterSigLIP(config)

        self.image_token_num = config.image_token_num
        self.image_size = config.vision_config.image_size

    def merge_text_image_tokens(self, inputs):
        input_ids, image_features, targets, attn_mask, loss_mask = inputs
        micro_batch_size, tokens_len = input_ids.shape
        device = input_ids.device

        # Locate every image placeholder token in each row of the batch.
        img_rows, img_cols = torch.where(input_ids == IMAGE_TOKEN_INDEX)
        image_idxs = {i: [] for i in range(micro_batch_size)}
        for row, col in zip(img_rows.tolist(), img_cols.tolist()):
            image_idxs[row].append(col)
        for row in range(micro_batch_size):
            image_idxs[row] = sorted(image_idxs[row])

        # Split each row into alternating text / image spans.
        split_sizes = []
        for row in range(micro_batch_size):
            image_num = len(image_idxs[row])
            if image_num == 0:
                split_sizes.append(tokens_len)
                continue

            if image_idxs[row][0] != 0:
                split_sizes.append(image_idxs[row][0])

            for idx in range(image_num - 1):
                split_sizes.append(self.image_token_num)
                if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
                    split_sizes.append(image_idxs[row][idx + 1] - (image_idxs[row][idx] + self.image_token_num))

            if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
                split_sizes.append(tokens_len - image_idxs[row][image_num - 1])
            else:
                split_sizes.append(self.image_token_num)
                split_sizes.append(tokens_len - (image_idxs[row][image_num - 1] + self.image_token_num))

        # Replace the (negative) image placeholder ids with a valid vocabulary id before embedding.
        input_ids_noim = torch.where(input_ids < 0, 151643, input_ids)
        input_ids_noim = input_ids_noim.view(-1)
        input_embeds = self.llm.model.embed_tokens(input_ids_noim)
        input_embeds_split = torch.split(input_embeds, split_sizes, dim=0)

        # Re-assemble each sequence, substituting projected image features for placeholder spans.
        vl_embeds_list = []
        cur_language_idx = 0
        cur_image_idx = 0
        for row in range(micro_batch_size):
            image_num = len(image_idxs[row])
            if image_num == 0:
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1
                vl_embeds_list.append(image_features[cur_image_idx][0:0])
                cur_image_idx += 1
                continue

            if image_idxs[row][0] != 0:
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1

            for idx in range(image_num - 1):
                vl_embeds_list.append(image_features[cur_image_idx])
                cur_language_idx += 1
                cur_image_idx += 1
                if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
                    vl_embeds_list.append(input_embeds_split[cur_language_idx])
                    cur_language_idx += 1

            if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
                # Last image span is truncated by the sequence length.
                vl_embeds_list.append(image_features[cur_image_idx][0 : tokens_len - image_idxs[row][image_num - 1]])
                cur_language_idx += 1
                cur_image_idx += 1
            else:
                vl_embeds_list.append(image_features[cur_image_idx])
                cur_language_idx += 1
                cur_image_idx += 1
                vl_embeds_list.append(input_embeds_split[cur_language_idx])
                cur_language_idx += 1

        vl_embeds = torch.cat(vl_embeds_list)
        vl_embeds = vl_embeds.view(micro_batch_size, tokens_len, vl_embeds.shape[-1])
        return (input_ids, vl_embeds, targets, attn_mask, loss_mask)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        local_pos_batch: Optional[torch.LongTensor] = None,
        image_idx_batch: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
    ):
        inputs = [input_ids, pixel_values, labels, attention_mask, loss_mask]

        if isinstance(inputs[1], list):
            pixel_values = [p.bfloat16() for p in inputs[1]]
        else:
            pixel_values = inputs[1].bfloat16()

        # Encode images with the SigLIP vision tower.
        img_token = self.vit.forward(pixel_values)
        if hasattr(img_token, 'last_hidden_state'):
            img_token = img_token.last_hidden_state

        # Project vision features into the LLM embedding space, then splice them into the text sequence.
        inputs = self.adp(inputs[:1] + [img_token] + inputs[2:])
        inputs = self.merge_text_image_tokens(inputs)
        tokens, hidden_states, targets, attn_mask, loss_mask = inputs

        outputs = self.llm.forward(
            inputs_embeds=hidden_states,
            attention_mask=attn_mask,
            use_cache=use_cache)
        lm_logits = outputs.logits

        loss = None
        if targets is not None:
            labels = targets.to(lm_logits.device)
            # Shift so that tokens < n predict token n, then average over unmasked positions.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            batch_size = labels.size(0)
            loss_mask = loss_mask[:, 1:].to(loss.dtype)
            loss = (loss.view(batch_size, -1) * loss_mask).sum() / loss_mask.sum()

        return FlashVLStaticOutputWithPast(
            loss=loss,
            logits=lm_logits
        )

    def get_input_embeddings(self):
        return self.llm.get_input_embeddings()

    def generate(
        self,
        input_ids=None,
        pixel_values=None,
        attention_mask=None,
        **kwargs
    ):
        image = pixel_values
        img_token = self.vit.forward(image.bfloat16())
        if hasattr(img_token, 'last_hidden_state'):
            img_token = img_token.last_hidden_state

        inputs = self.adp((
            input_ids.to(self.device),
            img_token,
            None, None, None))
        inputs = self.merge_text_image_tokens(inputs)
        tokens, hidden_states, targets, attn_mask, loss_mask = inputs

        keys_to_pop = ['loss_mask', 'labels', 'attention_mask']
        kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_pop}

        outputs = self.llm.generate(
            inputs_embeds=hidden_states.bfloat16(),
            max_new_tokens=2048,
            do_sample=False,
            **kwargs
        )
        return outputs

    def chat(self, pil_image, messages, answer_prompt=None, do_sample=True, max_new_tokens=256):
        data = {}
        data['img'] = pil_image
        data['text_only'] = (pil_image is None)
        data['messages'] = messages
        sources = self.to_llava_format(data)
        sources = [sources]

        has_image = not sources[0]['text_only']
        if has_image:
            img_list = sources[0]['image']
            if not isinstance(img_list, list):
                img_list = [img_list]
            # self.im_trans (the image processor) is expected to be attached to the model after loading.
            image = torch.stack([torch.from_numpy(self.im_trans(i)['pixel_values'][0]) for i in img_list], dim=0)

        sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = self.preprocess_qwen(
            sources,
            self.tokenizer,
            has_image=has_image,
        )
        input_ids_data = data_dict["input_ids"][0]
data_dict["input_ids"] = [ input_ids_data, ] if not has_image: image = torch.zeros(1, 3, self.image_size, self.image_size) data_dict = dict(tokens=data_dict["input_ids"][0],) img_token = self.vit.forward(image.cuda().bfloat16()) if hasattr(img_token, 'last_hidden_state'): img_token = img_token.last_hidden_state inputs = self.adp(( data_dict['tokens'].unsqueeze(0).to(self.device), img_token, None, None, None)) inputs = self.merge_text_image_tokens(inputs) tokens, hidden_states, targets, attn_mask, loss_mask = inputs outputs = self.llm.generate( inputs_embeds=hidden_states.bfloat16(), return_dict_in_generate=False, max_new_tokens=max_new_tokens, do_sample=do_sample, pad_token_id=False, ) decoded = self.tokenizer.decode(outputs[0]) stop_words_ids = [self.llm.generation_config.bos_token_id, self.llm.generation_config.eos_token_id, self.tokenizer.convert_tokens_to_ids('<|im_start|>')] stop_words = [self.tokenizer.decode(w) for w in stop_words_ids] for stop_word in stop_words: decoded = decoded.replace(stop_word, "").strip() return decoded def preprocess_qwen( self, sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False, max_len=2048, system_message: str = "You are a helpful assistant.",) -> Dict: roles = {"human": "user", "gpt": "assistant"} tokenizer = copy.deepcopy(tokenizer) tokenizer.add_tokens([""], special_tokens=True) image_token_index = tokenizer.convert_tokens_to_ids("") im_start, im_end = tokenizer.additional_special_tokens_ids[:2] unmask_tokens_idx = [198, im_start, im_end] nl_tokens = tokenizer("\n").input_ids chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" tokenizer.chat_template = chat_template input_ids, targets = [], [] for i, source in enumerate(sources): if roles[source[0]["from"]] != roles["human"]: source = source[1:] input_id, target = [], [] input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}]) target += [IGNORE_INDEX] * len(input_id) i=0 for conv in source: try: role = conv["role"] content = conv["content"] except: role = conv["from"] content = conv["value"] role = roles.get(role, role) if i==len(source)-1: conv = [{"role" : role, "content" : content}] encode_id = tokenizer.apply_chat_template(conv,add_generation_prompt=True) else: conv = [{"role" : role, "content" : content}] encode_id = tokenizer.apply_chat_template(conv) i=i+1 if image_token_index in encode_id: encode_id = tokenizer_image_token_qwen(encode_id, tokenizer, image_token_index, image_token_num=self.image_token_num) input_id += encode_id if role in ["user", "system"]: target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" for idx, encode_id in enumerate(input_id): if encode_id in unmask_tokens_idx: target[idx] = encode_id if encode_id == image_token_index: input_id[idx] = IMAGE_TOKEN_INDEX input_ids.append(input_id) targets.append(target) input_ids = torch.tensor(input_ids, dtype=torch.long) targets = torch.tensor(targets, dtype=torch.long) return dict( input_ids=input_ids, labels=targets, ) def to_llava_format(self, data): img_pil = data['img'] messages = data['messages'] text_only = data['text_only'] is_video=False if 'is_video' in data: is_video=data['is_video'] messages.append({'role': 'assistant', 'content': ''}) conversations = [] for i,m in enumerate(messages): if m['role'] == 'user': 
value = str(m['content']).replace('', '') if i == 0 and not text_only: value = '\n' + value conversations.append({'from': 'human', 'value': value}) elif m['role'] == 'assistant': conversations.append({'from': 'gpt', 'value': str(m['content']).replace('', '')}) else: raise ValueError(f"Wrong role in conversation. {m['role']}") return {'image': img_pil, 'text_only': text_only, 'is_video':is_video, 'conversations': conversations}
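

# --- Usage sketch (illustrative, not executed as part of this module) ---
# A minimal example of driving FlashVLStatic.chat, assuming the checkpoint is published as a
# trust_remote_code repo ("<model_path>" and "demo.jpg" are placeholders) and that the tokenizer
# and SigLIP image processor are attached manually, which is what chat expects via
# self.tokenizer / self.im_trans:
#
#     import torch
#     from PIL import Image
#     from transformers import AutoModel, AutoTokenizer, SiglipImageProcessor
#
#     model = AutoModel.from_pretrained("<model_path>", trust_remote_code=True,
#                                       torch_dtype=torch.bfloat16).cuda().eval()
#     model.tokenizer = AutoTokenizer.from_pretrained("<model_path>", trust_remote_code=True)
#     model.im_trans = SiglipImageProcessor.from_pretrained("<model_path>")
#
#     image = Image.open("demo.jpg").convert("RGB")
#     messages = [{"role": "user", "content": "Describe this image."}]
#     print(model.chat(image, messages, do_sample=False, max_new_tokens=128))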