import copy
import os
import warnings
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence

from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    Blip2VisionModel,
    Blip2QFormerModel,
    Blip2Model,
    Blip2PreTrainedModel,
    Blip2ForConditionalGeneration,
    GenerationConfig,
)
from transformers.models.blip_2.modeling_blip_2 import (
    Blip2ForConditionalGenerationModelOutput,
)
from transformers.utils import logging
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList

from .modeling_chatglm import (
    ChatGLMForConditionalGeneration,
    InvalidScoreLogitsProcessor,
)
from .configuration_blip2chatglm import Blip2ChatGLMConfig


logger = logging.get_logger(__name__)


class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
    """BLIP-2 with a ChatGLM decoder as the language model."""

    config_class = Blip2ChatGLMConfig

    def __init__(self, config: Blip2ChatGLMConfig):
        # Skip Blip2ForConditionalGeneration.__init__ so the language model can
        # be swapped for ChatGLM instead of the default OPT/T5 backbone.
        Blip2PreTrainedModel.__init__(self, config)

        # BLIP-2 pipeline: ViT encoder -> Q-Former -> linear projection into
        # the ChatGLM embedding space.
        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size
        )
        self.language_model = ChatGLMForConditionalGeneration(config.text_config)

    def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
        """Cast the vision encoder and the language model to their working dtypes."""
        if vision_encoder_dtype == "fp32":
            self.vision_model = self.vision_model.float()
        elif vision_encoder_dtype == "fp16":
            self.vision_model = self.vision_model.half()
        else:
            raise NotImplementedError(
                f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
            )

        if lm_dtype == "fp32":
            self.language_model = self.language_model.float()
        elif lm_dtype == "fp16":
            self.language_model = self.language_model.half()
        elif lm_dtype == "int4":
            # ChatGLM ships its own quantize() helper for 4/8-bit weights.
            self.language_model = self.language_model.half().quantize(4)
        elif lm_dtype == "int8":
            self.language_model = self.language_model.half().quantize(8)
        else:
            raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
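
    # A minimal usage sketch (the checkpoint path and device are placeholders,
    # not defined in this file):
    #
    #   model = Blip2ChatGLMForConditionalGeneration.from_pretrained("path/to/blip2chatglm")
    #   model.setup_dtype(vision_encoder_dtype="fp16", lm_dtype="int8")
    #   model = model.to("cuda").eval()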

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        image_slot_offset: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Run the vision encoder, Q-Former and ChatGLM decoder end to end.

        Args:
            pixel_values (torch.FloatTensor): preprocessed images of shape (batch_size, 3, height, width).
            input_ids (torch.LongTensor): prompt token ids; the slots that receive the visual tokens
                (input_ids[:, :num_query_tokens] by default) should be filled with tokenizer.unk_token_id.
            image_slot_offset (Optional[torch.LongTensor], optional): per-sample start index of the visual
                token slot. If not set, all visual tokens are placed as a prefix
                (equivalent to image_slot_offset = torch.zeros(bsz)). Defaults to None.
            attention_mask (Optional[torch.LongTensor], optional): attention mask for the language model. Defaults to None.
            output_attentions (Optional[bool], optional): whether to return attention weights. Defaults to None.
            output_hidden_states (Optional[bool], optional): whether to return hidden states. Defaults to None.
            labels (Optional[torch.LongTensor], optional): target ids for the language-modeling loss. Defaults to None.
            return_dict (Optional[bool], optional): whether to return a model output object instead of a tuple. Defaults to None.

        Returns:
            Union[Tuple, Blip2ForConditionalGenerationModelOutput]: the loss (if labels are given), the logits,
            and the intermediate vision, Q-Former and language-model outputs.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # 1. Encode the images with the ViT encoder.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device
        )

        # 2. Compress the image features into num_query_tokens visual tokens.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # 3. Project the visual tokens into the ChatGLM embedding space and
        #    write them into the unk-token placeholder slots of the prompt.
        language_model_inputs = self.language_projection(query_output)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if image_slot_offset is None:
            # All visual tokens are placed as a prefix.
            inputs_embeds.data[
                :, : self.config.num_query_tokens, :
            ] = language_model_inputs
        else:
            for i, offset in enumerate(image_slot_offset):
                inputs_embeds.data[
                    i, offset : offset + self.config.num_query_tokens, :
                ] = language_model_inputs[i]

        # 4. Run the ChatGLM decoder on the merged text + visual embeddings.
        outputs = self.language_model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None

        if labels is not None:
            logits = logits[:, -labels.size(1) :, :]

            # Shift so that tokens at position t predict the token at t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            loss_fct = CrossEntropyLoss(reduction="mean")
            loss = loss_fct(
                shift_logits.view(-1, self.config.text_config.vocab_size),
                shift_labels.view(-1),
            )

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
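
    # Illustrative sketch of building `forward` inputs with the visual-token
    # slots prefilled with `unk_token_id` (the tokenizer, pixel_values and the
    # prompt text are assumptions, not part of this file):
    #
    #   nq = model.config.num_query_tokens
    #   text_ids = tokenizer("Describe the image.", add_special_tokens=True).input_ids
    #   input_ids = torch.tensor([[tokenizer.unk_token_id] * nq + text_ids])
    #   out = model(pixel_values=pixel_values, input_ids=input_ids.to(model.device))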

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int,
        user_role: str = "问",
        bot_role: str = "答",
    ):
        """Build padded input_ids and inputs_embeds for a batch of multimodal chats.

        Each sample in batch_messages is a list of (role, text, images) turns,
        where images is a list of (pixel_values, char_offset) pairs marking
        where in text each image should be inserted. Returns (input_ids,
        inputs_embeds), both left-padded to the longest sequence in the batch.
        """
        device = self.device
        nvtokens = self.config.num_query_tokens

        all_images = []
        all_image_slots = []
        all_input_ids = []
        for messages in batch_messages:
            images = []
            image_slots = []
            input_ids = []

            # Follow the ChatGLM chat format: turns are grouped into rounds and
            # a new "[Round N]" header is opened whenever a role repeats.
            round_roles = [set()]
            for role, qtext, qimgs in messages:
                if role in round_roles[-1]:
                    # The same role speaks again: start a new round.
                    input_ids += tokenizer(
                        f"\n[Round {len(round_roles)}]\n{role}:",
                        add_special_tokens=False,
                    ).input_ids
                    round_roles.append({role})
                else:
                    round_roles[-1].add(role)
                    input_ids += tokenizer(
                        f"\n{role}:" if len(input_ids) != 0 else f"{role}:",
                        add_special_tokens=False,
                    ).input_ids

                # Interleave the turn's text with image placeholders at the
                # requested character offsets.
                cur_index = 0
                for qimg, img_idx in qimgs:
                    if img_idx > cur_index:
                        input_ids += tokenizer(
                            qtext[cur_index:img_idx], add_special_tokens=False
                        ).input_ids
                        cur_index = img_idx

                    # Reserve nvtokens unk tokens; they are overwritten with the
                    # projected visual tokens below.
                    image_slots.append(len(input_ids))
                    input_ids += [tokenizer.unk_token_id] * nvtokens
                    images.append(qimg)
                input_ids += tokenizer(
                    qtext[cur_index:], add_special_tokens=False
                ).input_ids

            if len(round_roles) == 1:
                if len(round_roles[0]) == 1 and user_role in round_roles[0]:
                    # Single user turn: just close with the tokenizer's default
                    # special tokens.
                    input_ids += tokenizer("").input_ids
                else:
                    input_ids += tokenizer(f"\n{bot_role}:").input_ids
            else:
                # Multi-round dialogue: prepend the "[Round 0]" header and close
                # with the bot role prompt.
                input_ids = (
                    tokenizer("[Round 0]\n", add_special_tokens=False).input_ids
                    + input_ids
                )
                input_ids += tokenizer(f"\n{bot_role}:").input_ids

            # Left-truncate to max_length. An image slot that would be cut in
            # half is dropped entirely by pushing the truncation point past it.
            if len(input_ids) >= max_length:
                image_slots_after_truncate = []
                images_after_truncate = []
                truncate_index = len(input_ids) - max_length
                for image_slot, image in zip(image_slots, images):
                    if len(input_ids) - image_slot < max_length:
                        # The slot starts inside the kept window: keep the image.
                        image_slots_after_truncate.append(image_slot)
                        images_after_truncate.append(image)
                    elif len(input_ids) - (image_slot + nvtokens) < max_length:
                        # The slot straddles the cut point: drop it and truncate
                        # past its end.
                        truncate_index = max(truncate_index, image_slot + nvtokens)
                for i, image_slot in enumerate(image_slots_after_truncate):
                    image_slots_after_truncate[i] = image_slot - truncate_index
                input_ids = input_ids[truncate_index:]
                image_slots = image_slots_after_truncate
                images = images_after_truncate

            all_images.extend(images)
            all_image_slots.append(image_slots)
            all_input_ids.append(input_ids)

        if len(all_images) != 0:
            # Encode every image in the batch in one pass, then split the
            # features back into per-sample chunks.
            vision_outputs = self.vision_model(torch.cat(all_images, dim=0))
            all_image_embeds = vision_outputs[0]
            indices_or_sections = [len(slots) for slots in all_image_slots]
            # Split points are the cumulative image counts without the final
            # total, so the number of chunks equals the batch size.
            indices_or_sections = np.cumsum(indices_or_sections)[:-1].tolist()
            all_vtokens = []

            for image_embeds in torch.tensor_split(
                all_image_embeds, tuple(indices_or_sections)
            ):
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                    device
                )

                query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_outputs = self.qformer(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                )
                query_output = query_outputs[0]

                all_vtokens.append(self.language_projection(query_output))
        else:
            all_vtokens = None

        # Left-pad the token ids to the longest sequence in the batch.
        input_ids = (
            torch.ones(
                (len(all_input_ids), max(len(ids) for ids in all_input_ids)),
                dtype=torch.long,
            )
            * tokenizer.pad_token_id
        )
        for i, ids in enumerate(all_input_ids):
            input_ids[i][-len(ids) :] = torch.as_tensor(ids, dtype=torch.long)
        input_ids = input_ids.to(device)

        # Look up the text embeddings, then overwrite each reserved unk-token
        # slot with its projected visual tokens. Slots were recorded on the
        # unpadded sequences, so shift them by the left-padding length.
        inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
        if all_vtokens is not None:
            for i, (image_slots, vtokens) in enumerate(
                zip(all_image_slots, all_vtokens)
            ):
                pad_len = input_ids.shape[1] - len(all_input_ids[i])
                for slot, vimg in zip(image_slots, vtokens):
                    inputs_embeds[i][
                        pad_len + slot : pad_len + slot + nvtokens, :
                    ] = vimg

        return input_ids, inputs_embeds
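
    # Sketch of the per-turn message format consumed above (values are
    # illustrative): each turn is (role, text, images), and each image is a
    # (pixel_values, char_offset) pair spliced into `text` at `char_offset`
    # as num_query_tokens placeholder tokens.
    #
    #   messages = [
    #       ("问", "Here is a photo: what breed is the dog?", [(pixel_values, 17)]),
    #       ("答", "It looks like a corgi.", []),
    #       ("问", "Is it an adult?", []),
    #   ]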

    @torch.no_grad()
    def batch_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_messages: List[List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]]],
        max_length: int = 2048,
        num_beams=1,
        do_sample=True,
        top_p=0.7,
        temperature=0.95,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        """Generate one response per conversation in batch_messages."""
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=batch_messages,
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }

        outputs = self.language_model.generate(
            input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
        )
        responses = []
        for i, output in enumerate(outputs.tolist()):
            # Strip the (padded) prompt, decode, and run ChatGLM's response post-processing.
            output = output[len(input_ids[i]) :]
            response = tokenizer.decode(output)
            responses.append(self.language_model.process_response(response))
        return responses
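
    # Hedged batch usage sketch (the pixel tensors are assumed to be
    # preprocessed to the vision encoder's expected shape elsewhere):
    #
    #   replies = model.batch_chat(
    #       tokenizer,
    #       batch_messages=[
    #           [("问", "What is in this image?", [(pixels_a, 0)])],
    #           [("问", "Describe the scene.", [(pixels_b, 0)])],
    #       ],
    #       max_length=2048,
    #   )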

    @torch.no_grad()
    def stream_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        messages: List[Tuple[str, str, List[Tuple[torch.Tensor, int]]]],
        num_beams=5,
        max_length=512,
        top_p=0.9,
        do_sample=True,
        temperature=1,
        user_role: str = "问",
        bot_role: str = "答",
        **kwargs,
    ):
        """Stream a response for a single conversation, yielding the partial text so far."""
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer=tokenizer,
            batch_messages=[messages],
            max_length=max_length,
            user_role=user_role,
            bot_role=bot_role,
        )

        logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
            "top_p": top_p,
            "temperature": temperature,
            "logits_processor": logits_processor,
            **kwargs,
        }

        for outputs in self.language_model.stream_generate(
            input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
        ):
            # Strip the prompt from each partial output, decode, and yield the
            # post-processed text generated so far.
            outputs = outputs.tolist()[0][len(input_ids[0]) :]
            response = tokenizer.decode(outputs)
            response = self.language_model.process_response(response)
            yield response
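

# A minimal streaming example (sketch only; the checkpoint path, tokenizer and
# image preprocessing below are assumptions, not defined in this file):
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/chatglm", trust_remote_code=True)
#   messages = [("问", "What is in this image?", [(pixel_values, 0)])]
#   for partial in model.stream_chat(tokenizer, messages, max_length=2048):
#       print(partial, end="\r", flush=True)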
|