# Ziya-BLIP2-14B-Visual-v1 / modeling_ziya_blip2.py
from typing import Optional, Tuple, Union, List
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import (
logging,
)
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenerationModelOutput
from transformers import (
Blip2PreTrainedModel,
Blip2VisionModel,
Blip2QFormerModel,
PreTrainedTokenizer,
PreTrainedModel,
)
logger = logging.get_logger(__name__)
class ZiyaBlip2ForCausalLM(Blip2PreTrainedModel):
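    """
    A BLIP-2 style vision-to-language bridge around an externally supplied causal language
    model (e.g. a LLaMA-family model): images are encoded by `Blip2VisionModel`, compressed
    into `config.num_query_tokens` learned query embeddings by the Q-Former, projected into
    the language model's embedding space by `language_projection`, and spliced between the
    text prompt embeddings. The language model is passed in at construction time, which is
    why `language_model.*` weights are ignored when loading this checkpoint.
    """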
config_class = Blip2Config
main_input_name = "pixel_values"
_keys_to_ignore_on_load_missing = [
r"language_model",
]
    def __init__(self, config: Blip2Config, language_model: Optional[PreTrainedModel] = None):
super().__init__(config)
self.vision_model = Blip2VisionModel(config.vision_config)
self.query_tokens = nn.Parameter(torch.zeros(
1, config.num_query_tokens, config.qformer_config.hidden_size))
self.qformer = Blip2QFormerModel(config.qformer_config)
self.language_projection = nn.Linear(
config.qformer_config.hidden_size, config.text_config.hidden_size)
self.language_model = language_model
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.language_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)
def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)
def get_output_embeddings(self) -> nn.Module:
return self.language_model.get_output_embeddings()
def get_encoder(self):
return self.language_model.get_encoder()
def get_decoder(self):
return self.language_model.get_decoder()
def _tie_weights(self):
if not self.config.use_decoder_only_language_model:
self.language_model.encoder.embed_tokens = self.language_model.shared
self.language_model.decoder.embed_tokens = self.language_model.shared
def _preprocess_accelerate(self):
r"""
Some pre-processing hacks to make the model `accelerate` compatible. Check
https://github.com/huggingface/transformers/pull/21707 for more details.
"""
hf_device_map = self.hf_device_map
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
# warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )
if hasattr(self.language_model, "_hf_hook"):
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
def forward(
self,
pixel_values: torch.FloatTensor,
        input_ids_before_image: torch.LongTensor,
        input_ids_after_image: torch.LongTensor,
        labels_after_image: torch.LongTensor,
        # Labels never appear before the image, so no labels_before_image is needed:
        # those positions are simply filled with -100, mirroring input_ids_before_image.
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
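        """
        Run one teacher-forcing step: encode `pixel_values` with the vision tower, compress
        the image features into `config.num_query_tokens` query embeddings with the Q-Former,
        project them into the language model's embedding space, and feed the concatenation
        `[input_ids_before_image | image query embeddings | input_ids_after_image]` to the
        language model. Only `labels_after_image` contributes to the loss; the prefix and
        image positions are filled with -100.
        """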
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# step 1: forward the images through the vision encoder,
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
vision_outputs = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
image_embeds = vision_outputs[0]
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
image_attention_mask = torch.ones(
image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
query_output = query_outputs[0]
        # step 2.5: project the Q-Former output into the language model's embedding space
language_model_inputs = self.language_projection(query_output)
language_model_attention_mask = torch.ones(
language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
)
        # sanity check: the projected image features and the text inputs must share a batch size
assert language_model_inputs.shape[0] == input_ids_after_image.shape[0]
inputs_embeds_before_image = self.language_model.get_input_embeddings()(input_ids_before_image)
inputs_embeds_after_image = self.language_model.get_input_embeddings()(input_ids_after_image)
inputs_embeds = torch.cat(
[
inputs_embeds_before_image.to(language_model_inputs.device),
language_model_inputs,
inputs_embeds_after_image.to(language_model_inputs.device)
], dim=1)
attention_mask_before = torch.ones_like(input_ids_before_image)
attention_mask_after = torch.ones_like(input_ids_after_image)
attention_mask = torch.cat(
[
attention_mask_before.to(language_model_attention_mask.device),
language_model_attention_mask,
attention_mask_after.to(language_model_attention_mask.device)
], dim=1
)
        # Give the labels the matching layout: pad the prefix and image-query positions with -100
        # so that only labels_after_image contributes to the loss.
        labels = torch.cat(
            [
                torch.tensor([-100]).expand_as(input_ids_before_image).to(language_model_inputs.device),
                torch.tensor([-100]).expand(query_tokens.shape[:-1]).to(language_model_inputs.device),
                labels_after_image,
            ], dim=1
        )
# step 3: use the language model
if self.config.use_decoder_only_language_model:
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
labels=labels,
)
loss = outputs.loss if return_dict else outputs[0]
logits = outputs.logits if return_dict else outputs[1]
else:
            raise NotImplementedError(
                "Only decoder-only language models are supported "
                "(config.use_decoder_only_language_model must be True)."
            )
if not return_dict:
output = (logits, vision_outputs, query_outputs, outputs)
return ((loss,) + output) if loss is not None else output
return Blip2ForConditionalGenerationModelOutput(
loss=loss,
logits=logits,
vision_outputs=vision_outputs,
qformer_outputs=query_outputs,
language_model_outputs=outputs,
)
def prepare_inputs_for_chat(
self,
tokenizer: PreTrainedTokenizer,
query: str,
pixel_values: torch.Tensor,
previous_querys: List[str],
previous_outputs: List[str],
max_length: int,
):
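        """
        Build the language-model inputs for one chat turn: embed the prompt prefix, the
        projected image query tokens, every previous (query, output) pair and the current
        `query`, then left-truncate to `max_length` positions. Returns `(input_ids,
        inputs_embeds)`, where `input_ids` holds `eos_token_id` placeholders at the image
        positions and is used by `chat` to build the attention mask.
        """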
# 1. process input_ids
assert len(previous_querys) == len(previous_outputs)
device = self.device
prefix = self.config.prompt_prefix
human_name = self.config.human_name
assistant_name = self.config.assistant_name
input_ids_before_image = tokenizer(
prefix, return_tensors="pt").input_ids.to(device)
inputs_ids_after_image = []
for (p, o) in zip(previous_querys, previous_outputs):
            # One conversation turn is rendered as "{human_name}: {prompt}\n{assistant_name}: {output}\n"
inputs_ids_after_image += tokenizer(f"{human_name}: {p}\n", add_special_tokens=False).input_ids + \
tokenizer(f"{assistant_name}: {o}\n", add_special_tokens=False).input_ids
inputs_ids_after_image += tokenizer(f"{human_name}: {query}\n",
add_special_tokens=False).input_ids + tokenizer(f"{assistant_name} :",
add_special_tokens=False).input_ids
inputs_ids_after_image = torch.IntTensor([inputs_ids_after_image]).to(device)
# 2. Prepare embeddings
        pixel_values = pixel_values.to(device)  # Tensor.to is not in-place; keep the moved tensor
image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
image_attention_mask = torch.ones(
image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_attention_mask,
return_dict=True,
)
query_output = query_outputs.last_hidden_state
language_model_inputs = self.language_projection(query_output)
# concatenate query embeddings with prompt embeddings
prefix_inputs_embeds = self.get_input_embeddings()(input_ids_before_image)
prompt_inputs_embeds = self.get_input_embeddings()(inputs_ids_after_image)
inputs_embeds = torch.cat([
prefix_inputs_embeds.to(language_model_inputs.device),
language_model_inputs,
prompt_inputs_embeds.to(language_model_inputs.device)], dim=1)
        input_ids = torch.cat([
            input_ids_before_image,
            torch.tensor([tokenizer.eos_token_id]).expand(
                query_tokens.shape[:-1]).to(language_model_inputs.device),
            inputs_ids_after_image,
        ], dim=1)
        # Left-truncate so the most recent turns are kept, and keep input_ids and inputs_embeds
        # the same length: chat() derives the attention mask from input_ids.
        if inputs_embeds.shape[1] > max_length:
            inputs_embeds = inputs_embeds[:, -max_length:, :]
            input_ids = input_ids[:, -max_length:]
return input_ids, inputs_embeds
    def chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        pixel_values: torch.Tensor,
        previous_querys: List[str],
        previous_outputs: List[str],
        **generate_kwargs,
    ):
"""
        Generate a reply in chat style, conditioned on an image and the chat history.
        Args:
            tokenizer (PreTrainedTokenizer): the language model's (LLaMA) tokenizer
            query (str): the current user query
            pixel_values (torch.Tensor): the image, already preprocessed by the image processor
            previous_querys (List[str]): previous user queries in the chat history
            previous_outputs (List[str]): previous model replies in the chat history
        Returns:
            str: the generated reply
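        Example (a minimal sketch of a follow-up turn, assuming `model`, `tokenizer` and
        `pixel_values` have already been prepared as in the `__main__` demo at the end of
        this file):
            >>> first = model.chat(tokenizer, "What is in this picture?", pixel_values, [], [])
            >>> follow_up = model.chat(
            ...     tokenizer,
            ...     "Describe it in more detail.",
            ...     pixel_values,
            ...     previous_querys=["What is in this picture?"],
            ...     previous_outputs=[first],
            ...     max_new_tokens=128,
            ... )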
"""
input_ids, inputs_embeds = self.prepare_inputs_for_chat(
tokenizer, query, pixel_values, previous_querys, previous_outputs, 2048
)
response = self.language_model.generate(
inputs_embeds=inputs_embeds,
attention_mask=torch.ones_like(input_ids),
**generate_kwargs,
)
response = tokenizer.decode(response[0], skip_special_tokens=True)
return response
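

# The snippet below is a minimal, hedged usage sketch rather than part of the model code:
# it shows how this class is typically wired to a LLaMA checkpoint and asked about an image.
# All paths ("path/to/...") and the image file name are placeholders, and the
# BlipImageProcessor settings (OpenAI CLIP mean/std) are assumptions about the expected
# preprocessing; adapt them to the actual checkpoints you use.
if __name__ == "__main__":
    from PIL import Image
    from transformers import BlipImageProcessor, LlamaForCausalLM, LlamaTokenizer

    # Load the language model and its tokenizer (placeholder path).
    lm_model = LlamaForCausalLM.from_pretrained("path/to/Ziya-LLaMA-13B")
    tokenizer = LlamaTokenizer.from_pretrained("path/to/Ziya-LLaMA-13B")

    # Load the visual bridge and hand it the language model (placeholder path).
    # from_pretrained forwards unused kwargs such as `language_model` to __init__.
    model = ZiyaBlip2ForCausalLM.from_pretrained(
        "path/to/Ziya-BLIP2-14B-Visual-v1", language_model=lm_model
    )
    model.eval()

    # Build an image processor matching the vision tower's input resolution.
    image_size = model.config.vision_config.image_size
    image_processor = BlipImageProcessor(
        size={"height": image_size, "width": image_size},
        image_mean=[0.48145466, 0.4578275, 0.40821073],  # assumed: OpenAI CLIP statistics
        image_std=[0.26862954, 0.26130258, 0.27577711],
    )

    pixel_values = image_processor(Image.open("demo.jpg"), return_tensors="pt").pixel_values
    reply = model.chat(
        tokenizer=tokenizer,
        query="What is shown in this picture?",
        pixel_values=pixel_values,
        previous_querys=[],
        previous_outputs=[],
        max_new_tokens=128,
    )
    print(reply)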