# FlashVL-2B-Static / modeling_FlashVLStatic.py
import os
import math
import copy
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from PIL import Image
from functools import partial
from typing import List, Optional, Tuple, Union, Dict
from dataclasses import dataclass
import transformers
from transformers.modeling_outputs import ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, Qwen2Config, SiglipVisionModel
from .adapters import AdapterSigLIP
from .mm_constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
from .processing_FlashVL import tokenizer_image_token_qwen
from .configuration_FlashVLStatic import FlashVLStaticConfig
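# Lightweight output container for FlashVLStatic.forward: only loss and logits are
# returned (no past_key_values are carried despite the *WithPast name).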
@dataclass
class FlashVLStaticOutputWithPast(ModelOutput):
loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
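# FlashVLStatic couples a SigLIP vision tower and an AdapterSigLIP projector with a
# Qwen2-style causal LM: each image is encoded into image_token_num embeddings that are
# spliced into the text-embedding stream before the LLM runs.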
class FlashVLStatic(PreTrainedModel):
config_class = FlashVLStaticConfig
def __init__(self, config):
super().__init__(config)
self.llm = AutoModelForCausalLM.from_config(config.llm_config, trust_remote_code=True)
self.vit = SiglipVisionModel(config.vision_config).vision_model
self.adp = AdapterSigLIP(config)
self.image_token_num = config.image_token_num
self.image_size = config.vision_config.image_size
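    # Splice projected image features into the token-embedding sequence. Each image is
    # anchored by an IMAGE_TOKEN_INDEX marker in input_ids and reserves image_token_num
    # consecutive positions, which are overwritten by that image's features.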
def merge_text_image_tokens(self, inputs):
input_ids, image_features, targets, attn_mask, loss_mask = inputs
micro_batch_size, tokens_len = input_ids.shape
device = input_ids.device
img_rows, img_cols = torch.where(input_ids == IMAGE_TOKEN_INDEX)
image_idxs = {i: [] for i in range(micro_batch_size)}
for row, col in zip(img_rows.tolist(), img_cols.tolist()):
image_idxs[row].append(col)
for row in range(micro_batch_size):
image_idxs[row] = sorted(image_idxs[row])
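        # Compute chunk sizes so the flattened embeddings can be split into alternating
        # text segments and image_token_num-sized image slots, row by row.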
split_sizes = []
for row in range(micro_batch_size):
image_num = len(image_idxs[row])
if image_num == 0:
split_sizes.append(tokens_len)
continue
if image_idxs[row][0] != 0:
split_sizes.append(image_idxs[row][0])
for idx in range(image_num - 1):
split_sizes.append(self.image_token_num)
if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
split_sizes.append(image_idxs[row][idx + 1] - (image_idxs[row][idx] + self.image_token_num))
if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
split_sizes.append(tokens_len - image_idxs[row][image_num - 1])
else:
split_sizes.append(self.image_token_num)
split_sizes.append(tokens_len - (image_idxs[row][image_num - 1] + self.image_token_num))
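        # IMAGE_TOKEN_INDEX placeholders are negative; map them to the Qwen2 <|endoftext|>
        # id (151643) so the embedding lookup stays in-vocabulary. Those placeholder
        # embeddings are dropped and replaced by image features during the re-interleaving below.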
input_ids_noim = torch.where(input_ids < 0, 151643, input_ids)
input_ids_noim = input_ids_noim.view(-1)
input_embeds = self.llm.model.embed_tokens(input_ids_noim)
input_embeds_split = torch.split(input_embeds, split_sizes, dim=0)
vl_embeds_list = []
cur_language_idx = 0
cur_image_idx = 0
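        # Re-interleave: walk each row, appending text-embedding chunks and projected image
        # features in order. Text-only rows append an empty image slice so cur_image_idx
        # stays aligned with image_features.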
for row in range(micro_batch_size):
image_num = len(image_idxs[row])
if image_num == 0:
vl_embeds_list.append(input_embeds_split[cur_language_idx])
cur_language_idx += 1
vl_embeds_list.append(image_features[cur_image_idx][0:0])
cur_image_idx += 1
continue
if image_idxs[row][0] != 0:
vl_embeds_list.append(input_embeds_split[cur_language_idx])
cur_language_idx += 1
for idx in range(image_num - 1):
vl_embeds_list.append(image_features[cur_image_idx])
cur_language_idx += 1
cur_image_idx += 1
if image_idxs[row][idx + 1] > image_idxs[row][idx] + self.image_token_num:
vl_embeds_list.append(input_embeds_split[cur_language_idx])
cur_language_idx += 1
if image_idxs[row][image_num - 1] + self.image_token_num >= tokens_len:
vl_embeds_list.append(image_features[cur_image_idx][0 : tokens_len - image_idxs[row][image_num - 1]])
cur_language_idx += 1
cur_image_idx += 1
else:
vl_embeds_list.append(image_features[cur_image_idx])
cur_language_idx += 1
cur_image_idx += 1
vl_embeds_list.append(input_embeds_split[cur_language_idx])
cur_language_idx += 1
vl_embeds = torch.cat(vl_embeds_list)
vl_embeds = vl_embeds.view(micro_batch_size, tokens_len, vl_embeds.shape[-1])
return (input_ids, vl_embeds, targets, attn_mask, loss_mask)
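    # Teacher-forcing forward pass: encode images, merge them with the text embeddings,
    # run the LLM, and (when targets are given) return a cross-entropy loss averaged over
    # the positions selected by loss_mask.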
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
local_pos_batch: Optional[torch.LongTensor] = None,
image_idx_batch: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
):
inputs = [input_ids, pixel_values, labels, attention_mask, loss_mask]
if isinstance(inputs[1], list):
pixel_values = [p.bfloat16() for p in inputs[1]]
else:
pixel_values = inputs[1].bfloat16()
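        # Encode pixels with the SigLIP tower, then let the adapter project the patch
        # features into image_token_num LLM-space embeddings per image.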
img_token = self.vit.forward(pixel_values)
if hasattr(img_token, 'last_hidden_state'):
img_token = img_token.last_hidden_state
inputs = self.adp(inputs[:1]+[img_token]+inputs[2:])
inputs = self.merge_text_image_tokens(inputs)
tokens, hidden_states, targets, attn_mask, loss_mask = inputs
outputs = self.llm.forward(
inputs_embeds = hidden_states,
attention_mask = attn_mask,
use_cache = use_cache)
lm_logits = outputs.logits
loss = None
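        # Next-token objective: shift logits/labels by one, take per-token cross-entropy,
        # and average over the positions where loss_mask is set.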
if targets is not None:
labels = targets.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(reduction='none')
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
batch_size = labels.size(0)
loss_mask = loss_mask[:, 1:].to(loss.dtype)
loss = (loss.view(batch_size, -1) * loss_mask).sum() / loss_mask.sum()
return FlashVLStaticOutputWithPast(
loss=loss,
logits=lm_logits
)
def get_input_embeddings(self):
return self.llm.get_input_embeddings()
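    # Inference entry point: encodes pixel_values, merges them with the text embeddings,
    # and delegates to the LLM's generate(). loss_mask/labels/attention_mask kwargs are
    # dropped since only inputs_embeds are passed; decoding is greedy by default
    # (do_sample=False) with up to 2048 new tokens.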
def generate(
self,
input_ids=None,
pixel_values=None,
attention_mask=None,
**kwargs
):
image = pixel_values
img_token = self.vit.forward(image.bfloat16())
if hasattr(img_token, 'last_hidden_state'):
img_token = img_token.last_hidden_state
inputs = self.adp((
input_ids.to(self.device),
img_token,
None, None, None))
inputs = self.merge_text_image_tokens(inputs)
tokens, hidden_states, targets, attn_mask, loss_mask = inputs
keys_to_pop = ['loss_mask', 'labels','attention_mask']
kwargs = {k: v for k, v in kwargs.items() if k not in keys_to_pop}
outputs = self.llm.generate(
inputs_embeds=hidden_states.bfloat16(),
max_new_tokens=2048,
do_sample=False,
**kwargs
)
return outputs
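    # Chat convenience API: converts a PIL image (or None) plus chat messages to the llava
    # format, tokenizes with preprocess_qwen, and generates a reply; ChatML delimiters are
    # stripped from the decoded text. Assumes self.tokenizer and self.im_trans (the image
    # preprocessor) have been attached to the model instance elsewhere.
    # Illustrative usage (file name and prompt are placeholders, not from this repo):
    #   reply = model.chat(Image.open("example.jpg"),
    #                      [{"role": "user", "content": "Describe the image."}])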
def chat(self, pil_image, messages, answer_prompt=None, do_sample=True, max_new_tokens=256):
        data = {}
data['img'] = pil_image
data['text_only'] = (pil_image is None)
data['messages'] = messages
sources = self.to_llava_format(data)
sources = [sources]
has_image = not sources[0]['text_only']
if has_image:
img_list = sources[0]['image']
if not isinstance(img_list, list):
img_list = [img_list]
image = torch.stack([torch.from_numpy(self.im_trans(i)['pixel_values'][0]) for i in img_list], dim=0)
sources = copy.deepcopy([e["conversations"] for e in sources])
data_dict = self.preprocess_qwen(
sources,
self.tokenizer,
has_image=has_image,
)
input_ids_data = data_dict["input_ids"][0]
data_dict["input_ids"] = [ input_ids_data, ]
if not has_image:
image = torch.zeros(1, 3, self.image_size, self.image_size)
data_dict = dict(tokens=data_dict["input_ids"][0],)
img_token = self.vit.forward(image.cuda().bfloat16())
if hasattr(img_token, 'last_hidden_state'):
img_token = img_token.last_hidden_state
inputs = self.adp((
data_dict['tokens'].unsqueeze(0).to(self.device),
img_token,
None, None, None))
inputs = self.merge_text_image_tokens(inputs)
tokens, hidden_states, targets, attn_mask, loss_mask = inputs
outputs = self.llm.generate(
inputs_embeds=hidden_states.bfloat16(),
return_dict_in_generate=False,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
pad_token_id=False,
)
decoded = self.tokenizer.decode(outputs[0])
stop_words_ids = [self.llm.generation_config.bos_token_id,
self.llm.generation_config.eos_token_id,
self.tokenizer.convert_tokens_to_ids('<|im_start|>')]
stop_words = [self.tokenizer.decode(w) for w in stop_words_ids]
for stop_word in stop_words:
decoded = decoded.replace(stop_word, "").strip()
return decoded
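    # Tokenizes llava-style conversations with a ChatML template. System and user turns are
    # masked to IGNORE_INDEX in the labels, assistant turns are supervised, and <image>
    # markers are processed via tokenizer_image_token_qwen and later rewritten to
    # IMAGE_TOKEN_INDEX.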
def preprocess_qwen(
self,
sources,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False,
max_len=2048,
system_message: str = "You are a helpful assistant.",) -> Dict:
roles = {"human": "user", "gpt": "assistant"}
tokenizer = copy.deepcopy(tokenizer)
tokenizer.add_tokens(["<image>"], special_tokens=True)
image_token_index = tokenizer.convert_tokens_to_ids("<image>")
im_start, im_end = tokenizer.additional_special_tokens_ids[:2]
unmask_tokens_idx = [198, im_start, im_end]
nl_tokens = tokenizer("\n").input_ids
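        # Plain ChatML template; the newline id (198) and the <|im_start|>/<|im_end|> ids
        # are kept unmasked in the targets further below.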
chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = chat_template
input_ids, targets = [], []
for i, source in enumerate(sources):
if roles[source[0]["from"]] != roles["human"]:
source = source[1:]
input_id, target = [], []
input_id += tokenizer.apply_chat_template([{"role" : "system", "content" : system_message}])
target += [IGNORE_INDEX] * len(input_id)
            i = 0
for conv in source:
try:
role = conv["role"]
content = conv["content"]
                except KeyError:
                    # llava-style records use "from"/"value" instead of "role"/"content"
                    role = conv["from"]
                    content = conv["value"]
role = roles.get(role, role)
                # Only the final turn gets the generation prompt ('<|im_start|>assistant\n') appended.
                if i == len(source) - 1:
                    conv = [{"role": role, "content": content}]
                    encode_id = tokenizer.apply_chat_template(conv, add_generation_prompt=True)
                else:
                    conv = [{"role": role, "content": content}]
                    encode_id = tokenizer.apply_chat_template(conv)
                i = i + 1
if image_token_index in encode_id:
encode_id = tokenizer_image_token_qwen(encode_id, tokenizer, image_token_index, image_token_num=self.image_token_num)
input_id += encode_id
if role in ["user", "system"]:
target += [IGNORE_INDEX] * len(encode_id)
else:
target += encode_id
assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
for idx, encode_id in enumerate(input_id):
if encode_id in unmask_tokens_idx:
target[idx] = encode_id
if encode_id == image_token_index:
input_id[idx] = IMAGE_TOKEN_INDEX
input_ids.append(input_id)
targets.append(target)
input_ids = torch.tensor(input_ids, dtype=torch.long)
targets = torch.tensor(targets, dtype=torch.long)
return dict(
input_ids=input_ids,
labels=targets,
)
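    # Converts the {'img', 'messages', 'text_only'} dict into llava-style 'conversations'
    # (human/gpt turns), prefixing the first user turn with '<image>\n' when an image is
    # present and appending an empty assistant turn so the model is prompted to reply.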
def to_llava_format(self, data):
img_pil = data['img']
messages = data['messages']
text_only = data['text_only']
        is_video = False
        if 'is_video' in data:
            is_video = data['is_video']
messages.append({'role': 'assistant', 'content': ''})
conversations = []
        for i, m in enumerate(messages):
if m['role'] == 'user':
value = str(m['content']).replace('<image>', '')
if i == 0 and not text_only:
value = '<image>\n' + value
conversations.append({'from': 'human', 'value': value})
elif m['role'] == 'assistant':
conversations.append({'from': 'gpt', 'value': str(m['content']).replace('<image>', '')})
else:
raise ValueError(f"Wrong role in conversation. {m['role']}")
        return {'image': img_pil,
                'text_only': text_only,
                'is_video': is_video,
                'conversations': conversations}