# blip2zh-chatglm-6b / modeling_blip2chatglm.py
import copy
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
import warnings
from torch import Tensor, nn
from transformers import (
PreTrainedModel,
PreTrainedTokenizer,
Blip2VisionModel,
Blip2QFormerModel,
Blip2Model,
Blip2PreTrainedModel,
Blip2ForConditionalGeneration,
GenerationConfig,
)
from transformers.models.blip_2.modeling_blip_2 import (
Blip2ForConditionalGenerationModelOutput,
)
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from .modeling_chatglm import (
ChatGLMForConditionalGeneration,
InvalidScoreLogitsProcessor,
)
from .configuration_blip2chatglm import Blip2ChatGLMConfig
logger = logging.get_logger(__name__)
class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
config_class = Blip2ChatGLMConfig
def __init__(self, config: Blip2ChatGLMConfig):
Blip2PreTrainedModel.__init__(self, config)
        # NOTE: we only run Blip2PreTrainedModel.__init__ here; calling super().__init__()
        # directly would fail because ChatGLM cannot be resolved by AutoModel
self.vision_model = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
)
self.qformer = Blip2QFormerModel(config.qformer_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size, config.text_config.hidden_size
)
self.language_model = ChatGLMForConditionalGeneration(config.text_config)
# Initialize weights and apply final processing
# self.post_init()
def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
if vision_encoder_dtype == "fp32":
self.vision_model = self.vision_model.float()
elif vision_encoder_dtype == "fp16":
self.vision_model = self.vision_model.half()
else:
raise NotImplementedError(
f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
)
if lm_dtype == "fp32":
self.language_model = self.language_model.float()
elif lm_dtype == "fp16":
self.language_model = self.language_model.half()
elif lm_dtype == "int4":
self.language_model = self.language_model.half().quantize(4)
elif lm_dtype == "int8":
self.language_model = self.language_model.half().quantize(8)
else:
raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
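    # Usage sketch (comments only, not executed): how the two dtype knobs above are typically
    # combined after loading. The checkpoint path below is a placeholder, not a real repo id.
    #
    #   model = Blip2ChatGLMForConditionalGeneration.from_pretrained("path/to/blip2zh-chatglm-6b")
    #   model.setup_dtype(vision_encoder_dtype="fp16", lm_dtype="int8")
    #   model = model.to("cuda").eval()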
def forward(
self,
pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
image_slot_offset: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Forward pass through the vision encoder, the Q-Former, and the ChatGLM language model.

        Args:
            pixel_values (torch.FloatTensor): Pixel values of the input images.
            input_ids (torch.LongTensor): Token ids. The image slot (the first
                ``num_query_tokens`` positions, or the ``num_query_tokens`` positions starting at
                ``image_slot_offset``) should be filled with ``tokenizer.unk_token_id``.
            image_slot_offset (Optional[torch.LongTensor], optional): Per-sample offset of the
                image slot. If not set, all visual tokens are placed as a prefix, i.e.
                ``image_slot_offset = torch.zeros(bsz)``. Defaults to None.
            attention_mask (Optional[torch.LongTensor], optional): Attention mask for the
                language model. Defaults to None.
            output_attentions (Optional[bool], optional): Whether to return attention weights.
                Defaults to None.
            output_hidden_states (Optional[bool], optional): Whether to return hidden states.
                Defaults to None.
            labels (Optional[torch.LongTensor], optional): Target token ids for the
                language-modeling loss. Defaults to None.
            return_dict (Optional[bool], optional): Whether to return a model output object
                instead of a plain tuple. Defaults to None.

        Returns:
            Union[Tuple, Blip2ForConditionalGenerationModelOutput]: Language model logits together
            with the vision, Q-Former, and language model outputs, plus the loss when ``labels``
            is given.
        """
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(
image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
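        # the learned query tokens (1, num_query_tokens, hidden) are shared; expand a view per image in the batch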
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]
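        # query_output: (batch_size, num_query_tokens, qformer_hidden_size)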
# step 3: use the language model, conditioned on the query outputs and the prompt
language_model_inputs = self.language_projection(query_output)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_slot_offset is None:
# image as prefix
            # write through .data to avoid an in-place operation on a leaf Variable
inputs_embeds.data[
:, : self.config.num_query_tokens, :
] = language_model_inputs
else:
for i, offset in enumerate(image_slot_offset):
inputs_embeds.data[
i, offset : offset + self.config.num_query_tokens, :
] = language_model_inputs[i]
outputs = self.language_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = outputs.logits if return_dict else outputs[0]
loss = None
# we compute the loss here since we need to take into account the sequence length of the query embeds
if labels is not None:
logits = logits[:, -labels.size(1) :, :]
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous().to(logits.device)
# Flatten the tokens
loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(
shift_logits.view(-1, self.config.text_config.vocab_size),
shift_labels.view(-1),
)
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return Blip2ForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
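    # Sketch (comments only, not executed) of building `forward` inputs with the image as a prefix.
    # `tokenizer`, `pixel_values`, and `text` are assumed to exist; the essential part is filling
    # the first `num_query_tokens` positions with the unk token id so the placeholder embeddings
    # can be overwritten by the projected image embeddings.
    #
    #   nvtokens = model.config.num_query_tokens
    #   text_ids = tokenizer(text).input_ids
    #   input_ids = torch.tensor([[tokenizer.unk_token_id] * nvtokens + text_ids])
    #   outputs = model(pixel_values=pixel_values, input_ids=input_ids)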
def prepare_inputs_for_chat(
self,
tokenizer: PreTrainedTokenizer,
batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
max_length: int,
user_role: str = "问",
bot_role: str = "答",
):
device = self.device
nvtokens = self.config.num_query_tokens
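        # each image occupies `nvtokens` consecutive placeholder positions in the token sequence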
# 1. Prepare token ids
all_images = []
all_image_slots = []
all_input_ids = []
for messages in batch_messages:
images = []
image_slots = []
input_ids = []
round_roles = [set()]
for role, qtext, qimgs in messages:
if role in round_roles[-1]:
# a new round (not the first round)
input_ids += tokenizer(
f"\n[Round {len(round_roles)}]\n{role}:",
add_special_tokens=False,
).input_ids
round_roles.append({role})
else:
round_roles[-1].add(role)
                    input_ids += tokenizer(
                        # no leading newline before the very first role tag
                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:",
                        add_special_tokens=False,
                    ).input_ids
cur_index = 0
for qimg, img_idx in qimgs:
if img_idx > cur_index:
input_ids += tokenizer(
qtext[cur_index:img_idx], add_special_tokens=False
).input_ids
cur_index = img_idx
                    # reserve an image slot; the placeholder embeddings here are later
                    # replaced by the projected image embeddings
image_slots.append(len(input_ids))
input_ids += [tokenizer.unk_token_id] * nvtokens
images.append(qimg)
input_ids += tokenizer(
qtext[cur_index:], add_special_tokens=False
).input_ids
if len(round_roles) == 1:
# only 1 round
if len(round_roles[0]) == 1 and user_role in round_roles[0]:
# only user role
input_ids += tokenizer("").input_ids
else:
input_ids += tokenizer(f"\n{bot_role}:").input_ids
else:
# add tag for round 0
input_ids = (
tokenizer(f"[Round 0]\n", add_special_tokens=False).input_ids
+ input_ids
)
input_ids += tokenizer(f"\n{bot_role}:").input_ids
if len(input_ids) >= max_length:
image_slots_after_truncate = []
images_after_truncate = []
truncate_index = len(input_ids) - max_length
for image_slot, image in zip(image_slots, images):
# truncate from left
if len(input_ids) - image_slot < max_length:
image_slots_after_truncate.append(image_slot)
images_after_truncate.append(image)
                    elif len(input_ids) - (image_slot + nvtokens) < max_length:
                        # never cut through an image slot: if truncation would split it,
                        # move the truncation point past the whole slot instead
                        truncate_index = max(truncate_index, image_slot + nvtokens)
for i, image_slot in enumerate(image_slots_after_truncate):
image_slots_after_truncate[i] = image_slot - truncate_index
input_ids = input_ids[truncate_index:]
image_slots = image_slots_after_truncate
images = images_after_truncate
# print(tokenizer.convert_ids_to_tokens(input_ids))
all_images.extend(images)
all_image_slots.append(image_slots)
all_input_ids.append(input_ids)
# 2. Prepare image embeddings
if len(all_images) != 0:
vision_outputs = self.vision_model.forward(torch.cat(all_images, dim=0))
all_image_embeds = vision_outputs[0]
indices_or_sections = [len(chunk) for chunk in all_image_slots]
indices_or_sections = np.cumsum(indices_or_sections)
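            # cumulative image counts mark the split boundaries between samples in the stacked image batch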
all_vtokens = []
            # TODO: the Q-Former is run once per sample here rather than on the whole batch
for image_embeds in torch.tensor_split(
all_image_embeds, tuple(indices_or_sections)
):
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
device
)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer.forward(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
)
query_output = query_outputs[0]
all_vtokens.append(self.language_projection(query_output))
else:
all_vtokens = None
# 3. Place image embeddings into slots
input_ids = (
torch.ones(
(len(all_input_ids), max(len(ids) for ids in all_input_ids)),
dtype=torch.long,
)
* tokenizer.pad_token_id
)
for i, ids in enumerate(all_input_ids):
# pad left
input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
input_ids = input_ids.to(device)
inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
if all_vtokens is not None:
for i, (image_slots, vtokens) in enumerate(
zip(all_image_slots, all_vtokens)
):
for slot, vimg in zip(image_slots, vtokens):
inputs_embeds[i][slot : slot + nvtokens, :] = vimg
return input_ids, inputs_embeds
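    # Shape of `batch_messages` (comments only): one conversation per batch element, each message a
    # `(role, text, images)` tuple where `images` is a list of `(pixel_values, char_index)` pairs and
    # `char_index` is the character offset in `text` at which the image is inserted. The tensors and
    # texts below are placeholders for illustration.
    #
    #   batch_messages = [
    #       [
    #           ("问", "这张图片里有什么？", [(pixel_values_1, 0)]),  # user turn, one image at offset 0
    #           ("答", "一只猫。", []),                               # bot turn, text only
    #           ("问", "它是什么颜色的？", []),
    #       ],
    #   ]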
@torch.no_grad()
def batch_chat(
self,
tokenizer: PreTrainedTokenizer,
batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
max_length: int = 2048,
num_beams=1,
do_sample=True,
top_p=0.7,
temperature=0.95,
user_role: str = "问",
bot_role: str = "答",
**kwargs,
):
input_ids, inputs_embeds = self.prepare_inputs_for_chat(
tokenizer=tokenizer,
batch_messages=batch_messages,
max_length=max_length,
user_role=user_role,
bot_role=bot_role,
)
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
outputs = self.language_model.generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
)
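        # `generate` returns the left-padded prompt followed by the newly generated tokens;
        # strip the prompt length before decoding each sample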
responses = []
for i, output in enumerate(outputs.tolist()):
output = output[len(input_ids[i]) :]
response = tokenizer.decode(output)
responses.append(self.language_model.process_response(response))
return responses
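    # Usage sketch (comments only, not executed). `pixel_values` is assumed to be a preprocessed
    # (1, 3, H, W) float tensor on the model device.
    #
    #   responses = model.batch_chat(
    #       tokenizer,
    #       batch_messages=[[("问", "描述这张图片。", [(pixel_values, 0)])]],
    #       max_length=2048,
    #       do_sample=True,
    #       top_p=0.7,
    #       temperature=0.95,
    #   )
    #   print(responses[0])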
@torch.no_grad()
def stream_chat(
self,
tokenizer: PreTrainedTokenizer,
messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
num_beams=5,
max_length=512,
top_p=0.9,
do_sample=True,
temperature=1,
user_role: str = "问",
bot_role: str = "答",
**kwargs,
):
input_ids, inputs_embeds = self.prepare_inputs_for_chat(
tokenizer=tokenizer,
batch_messages=[messages],
max_length=max_length,
user_role=user_role,
bot_role=bot_role,
)
logits_processor = LogitsProcessorList()
logits_processor.append(InvalidScoreLogitsProcessor())
gen_kwargs = {
"max_length": max_length,
"num_beams": num_beams,
"do_sample": do_sample,
"top_p": top_p,
"temperature": temperature,
"logits_processor": logits_processor,
**kwargs,
}
for outputs in self.language_model.stream_generate(
input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
):
outputs = outputs.tolist()[0][len(input_ids[0]) :]
response = tokenizer.decode(outputs)
response = self.language_model.process_response(response)
yield response
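

# Minimal streaming demo, a sketch only: the checkpoint path is a placeholder and the random tensor
# stands in for properly preprocessed pixel values (normally produced by the BLIP-2 image processor),
# so the generated text is meaningless; it merely exercises the code path above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_path = "path/to/blip2zh-chatglm-6b"  # placeholder path
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = Blip2ChatGLMForConditionalGeneration.from_pretrained(model_path)
    model.setup_dtype(vision_encoder_dtype="fp16", lm_dtype="fp16")
    model = model.to("cuda").eval()

    # stand-in for one preprocessed image: (1, 3, image_size, image_size)
    image_size = model.config.vision_config.image_size
    pixel_values = torch.randn(
        1, 3, image_size, image_size, dtype=torch.float16, device=model.device
    )
    messages = [("问", "描述这张图片。", [(pixel_values, 0)])]
    for response in model.stream_chat(tokenizer, messages, max_length=512):
        print(response)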