|
from typing import Optional, Tuple, Union, List |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from transformers.utils import ( |
|
logging, |
|
) |
|
from transformers.models.blip_2.configuration_blip_2 import Blip2Config |
|
from transformers.models.blip_2.modeling_blip_2 import Blip2ForConditionalGenerationModelOutput |
|
from transformers import ( |
|
Blip2PreTrainedModel, |
|
Blip2VisionModel, |
|
Blip2QFormerModel, |
|
PreTrainedTokenizer, |
|
PreTrainedModel, |
|
) |
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities (imported from
# transformers.utils); used by ZiyaBlip2ForCausalLM._preprocess_accelerate to warn
# about multi-GPU device maps.
logger = logging.get_logger(__name__)
|
|
|
|
|
class ZiyaBlip2ForCausalLM(Blip2PreTrainedModel):
    """BLIP-2 style vision-language model that feeds image features into a causal LM.

    Images are encoded by ``vision_model``, summarized into
    ``config.num_query_tokens`` learned query embeddings by ``qformer`` and
    mapped into the language model's embedding space by ``language_projection``.
    The projected query embeddings are spliced between the embeddings of a text
    prefix and of the text that follows the image, and the resulting sequence is
    passed to ``language_model`` via ``inputs_embeds``.

    ``language_model`` is passed to the constructor rather than built from
    ``config``. It is also listed in ``_keys_to_ignore_on_load_missing``,
    presumably so checkpoints without language-model weights load cleanly.
    """

    config_class = Blip2Config

    main_input_name = "pixel_values"

    _keys_to_ignore_on_load_missing = [
        r"language_model",
    ]

    def __init__(self, config: Blip2Config, language_model: Optional[PreTrainedModel] = None):
        """Build the vision encoder, Q-Former and projection; attach ``language_model``.

        Args:
            config: BLIP-2 configuration (vision, Q-Former and text sub-configs).
            language_model: the causal language model that consumes the projected
                image queries. May be ``None`` at construction time, but every
                embedding accessor and ``forward`` dereference it.
        """
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        # Learned query embeddings with a leading batch dim of 1; they are
        # expanded to the image batch size in _encode_image_queries().
        self.query_tokens = nn.Parameter(torch.zeros(
            1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = Blip2QFormerModel(config.qformer_config)

        # Maps Q-Former hidden size -> language model hidden size.
        self.language_projection = nn.Linear(
            config.qformer_config.hidden_size, config.text_config.hidden_size)

        self.language_model = language_model

        self.post_init()

    def get_input_embeddings(self):
        """Return the language model's input embedding module."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the language model's input embedding module."""
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        """Replace the language model's output embedding (LM head) module."""
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        """Return the language model's output embedding (LM head) module."""
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        """Delegate to ``language_model.get_encoder()``."""
        return self.language_model.get_encoder()

    def get_decoder(self):
        """Delegate to ``language_model.get_decoder()``."""
        return self.language_model.get_decoder()

    def _tie_weights(self):
        """For encoder-decoder language models, point encoder and decoder
        ``embed_tokens`` at the language model's ``shared`` embedding."""
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # A single concatenated message: the fragments must not be separated by
            # commas, or the tail is passed as a %-format argument to logger.warning.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            # Presumably makes accelerate's hook return outputs on the input device
            # (see the PR referenced above) — confirm against accelerate's hooks.
            self.language_model._hf_hook.io_same_device = True

    def _encode_image_queries(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run vision encoder -> Q-Former -> projection on ``pixel_values``.

        Shared by ``forward`` and ``prepare_inputs_for_chat``.

        Returns:
            ``(language_model_inputs, vision_outputs, query_outputs)``, where
            ``language_model_inputs`` are the projected query embeddings (batch
            dimension taken from the vision encoder output) and the other two are
            the raw sub-model outputs.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 works for both ModelOutput and tuple returns.
        image_embeds = vision_outputs[0]
        image_attention_mask = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        language_model_inputs = self.language_projection(query_output)
        return language_model_inputs, vision_outputs, query_outputs

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids_before_image: torch.LongTensor,
        input_ids_after_image: torch.LongTensor,
        labels_after_image: torch.LongTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        """Compute the language-model loss for ``prefix text + image queries + suffix text``.

        Args:
            pixel_values: preprocessed images.
            input_ids_before_image: token ids placed before the image embeddings.
                These positions are labelled -100 and so excluded from the loss.
            input_ids_after_image: token ids placed after the image embeddings;
                batch size must match the image batch size.
            labels_after_image: labels aligned with ``input_ids_after_image``
                (-100 is the ignore value used here for prefix/image positions).
            output_attentions / output_hidden_states: forwarded to every sub-model.
            return_dict: defaults to ``config.use_return_dict``.

        Returns:
            ``Blip2ForConditionalGenerationModelOutput`` when ``return_dict``; else
            ``(loss, logits, vision_outputs, query_outputs, outputs)``, with
            ``loss`` omitted when it is ``None``.

        Raises:
            NotImplementedError: if ``config.use_decoder_only_language_model`` is false.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Fail before doing any expensive work.
        if not self.config.use_decoder_only_language_model:
            raise NotImplementedError(
                "ZiyaBlip2ForCausalLM.forward only supports decoder-only language models")

        language_model_inputs, vision_outputs, query_outputs = self._encode_image_queries(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        target_device = language_model_inputs.device
        # (batch, number of query positions) occupied by the image in the sequence.
        image_positions = language_model_inputs.size()[:-1]

        assert language_model_inputs.shape[0] == input_ids_after_image.shape[0]

        embed_tokens = self.language_model.get_input_embeddings()
        inputs_embeds = torch.cat(
            [
                embed_tokens(input_ids_before_image).to(target_device),
                language_model_inputs,
                embed_tokens(input_ids_after_image).to(target_device),
            ],
            dim=1,
        )

        # Every position (text and image) is attended to.
        attention_mask = torch.cat(
            [
                torch.ones_like(input_ids_before_image).to(target_device),
                torch.ones(image_positions, dtype=torch.long, device=target_device),
                torch.ones_like(input_ids_after_image).to(target_device),
            ],
            dim=1,
        )

        # Prefix and image positions get -100; only the suffix carries real labels.
        # All pieces are moved to one device so the concatenation cannot fail.
        labels = torch.cat(
            [
                torch.full(input_ids_before_image.shape, -100, dtype=torch.long, device=target_device),
                torch.full(image_positions, -100, dtype=torch.long, device=target_device),
                labels_after_image.to(target_device),
            ],
            dim=1,
        )

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            labels=labels,
        )
        loss = outputs.loss if return_dict else outputs[0]
        logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def prepare_inputs_for_chat(
        self,
        tokenizer: PreTrainedTokenizer,
        query: str,
        pixel_values: torch.Tensor,
        previous_querys: List[str],
        previous_outputs: List[str],
        max_length: int,
    ):
        """Build ids and embeddings for one chat turn.

        Sequence layout: tokens of ``config.prompt_prefix`` (with special tokens),
        then the projected image queries, then each history turn as
        ``"{human_name}: {query}\\n"`` + ``"{assistant_name}: {output}\\n"``, then
        the current query and the assistant tag, so generation continues the
        assistant's reply. Reads ``config.prompt_prefix``, ``config.human_name``
        and ``config.assistant_name``.

        Args:
            tokenizer: tokenizer for ``language_model``.
            query: the current user query.
            pixel_values: preprocessed image(s).
            previous_querys: earlier user queries, in prompt order.
            previous_outputs: assistant replies paired index-wise with ``previous_querys``.
            max_length: keep only the last ``max_length`` positions.

        Returns:
            ``(input_ids, inputs_embeds)`` with equal sequence length.
            ``input_ids`` holds ``tokenizer.eos_token_id`` as a placeholder at
            each image position so it lines up with ``inputs_embeds``.
        """
        assert len(previous_querys) == len(previous_outputs)

        device = self.device
        prefix = self.config.prompt_prefix
        human_name = self.config.human_name
        assistant_name = self.config.assistant_name

        input_ids_before_image = tokenizer(
            prefix, return_tensors="pt").input_ids.to(device)

        after_image_ids: List[int] = []
        for previous_query, previous_output in zip(previous_querys, previous_outputs):
            after_image_ids += tokenizer(f"{human_name}: {previous_query}\n", add_special_tokens=False).input_ids
            after_image_ids += tokenizer(f"{assistant_name}: {previous_output}\n", add_special_tokens=False).input_ids
        after_image_ids += tokenizer(f"{human_name}: {query}\n", add_special_tokens=False).input_ids
        # NOTE(review): the generation prompt ends with "{assistant_name} :" (space
        # before the colon), unlike the "{assistant_name}: " used for history turns.
        # Kept as-is because the format the model was trained on isn't visible here — confirm.
        after_image_ids += tokenizer(f"{assistant_name} :", add_special_tokens=False).input_ids
        inputs_ids_after_image = torch.tensor([after_image_ids], dtype=torch.long, device=device)

        # Tensor.to is not in-place: the result must be rebound.
        pixel_values = pixel_values.to(device)
        language_model_inputs, _, _ = self._encode_image_queries(pixel_values, return_dict=True)
        target_device = language_model_inputs.device

        embed_tokens = self.get_input_embeddings()
        inputs_embeds = torch.cat([
            embed_tokens(input_ids_before_image).to(target_device),
            language_model_inputs,
            embed_tokens(inputs_ids_after_image).to(target_device),
        ], dim=1)

        image_placeholder_ids = torch.full(
            language_model_inputs.shape[:-1], tokenizer.eos_token_id,
            dtype=torch.long, device=target_device)
        input_ids = torch.cat([
            input_ids_before_image.to(target_device),
            image_placeholder_ids,
            inputs_ids_after_image.to(target_device),
        ], dim=1)

        # Keep the most recent positions. Truncate ids and embeddings together so
        # that a mask derived from input_ids (as chat() does) matches inputs_embeds.
        if inputs_embeds.shape[1] > max_length:
            inputs_embeds = inputs_embeds[:, -max_length:, :]
            input_ids = input_ids[:, -max_length:]

        return input_ids, inputs_embeds

    def chat(self,
             tokenizer,
             query: str,
             pixel_values: torch.Tensor,
             previous_querys: List[str],
             previous_outputs: List[str],
             *,
             max_input_length: int = 2048,
             **generate_kwargs,):
        """Generate a reply to ``query`` about ``pixel_values`` given the chat history.

        Args:
            tokenizer (PreTrainedTokenizer): tokenizer for ``language_model``
                (a LLaMA tokenizer per the original documentation).
            query (str): current user query.
            pixel_values (torch.Tensor): image after the image processor.
            previous_querys (List[str]): earlier user queries, in prompt order.
            previous_outputs (List[str]): assistant replies paired with ``previous_querys``.
            max_input_length (int): the prompt (prefix + image + history + query)
                is left-truncated to this many positions. Defaults to 2048.
            **generate_kwargs: forwarded to ``self.language_model.generate``.

        Returns:
            str: the first generated sequence decoded with special tokens skipped.
        """
        input_ids, inputs_embeds = self.prepare_inputs_for_chat(
            tokenizer, query, pixel_values, previous_querys, previous_outputs, max_input_length
        )
        response = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=torch.ones_like(input_ids),
            **generate_kwargs,
        )
        response = tokenizer.decode(response[0], skip_special_tokens=True)
        return response
|
|