# GPT-Vision / modeling_gpt2vision.py
# Author: damerajee
# Revision 39f38b0 ("Update modeling_gpt2vision.py"), 3.66 kB
import re

import torch
from torch import nn
from transformers import AutoTokenizer, GPT2Config, PreTrainedModel

from .configuration_gpt2vision import GPT2VisionConfig
from .modeling_gpt2 import GPT2LMHeadModel
from .vision_encoder import VisionEncoder
# Placeholder token marking where image embeddings are spliced into a prompt;
# it is registered as an additional special token on the tokenizer.
IMAGE_TOKEN = "<image>"
# GPT-2's end-of-text token string, used as the stop text for answers.
ANSWER_EOS = "<|endoftext|>"
def resize_token_embeds(model_name="openai-community/gpt2"):
    """Load a tokenizer and register ``IMAGE_TOKEN`` as a special token.

    Despite its name, this function does not touch any model embeddings: it
    only extends the tokenizer's vocabulary. Callers must resize the model's
    embedding matrix to ``len(tokenizer)`` themselves.

    Args:
        model_name: Hub id or local path passed to
            ``AutoTokenizer.from_pretrained``.

    Returns:
        The loaded tokenizer with ``IMAGE_TOKEN`` added to its
        ``additional_special_tokens``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    return tokenizer
# Module-level tokenizer, shared by every GPT2Vision instance (each one stores
# a reference to this same object). It is built once at import time, so
# importing this module loads — and may download — the tokenizer files.
tokenizer = resize_token_embeds()
class GPT2Vision(PreTrainedModel):
    """GPT-2 language model that can be conditioned on image embeddings.

    Image embeddings (from ``encode_image``) are spliced into the sequence of
    GPT-2 input embeddings right after the ``<image>`` placeholder in the
    prompt, and text is then generated from that combined sequence.
    """

    config_class = GPT2VisionConfig

    def __init__(self, config):
        """Build the vision encoder and the GPT-2 text model.

        Args:
            config: A ``GPT2VisionConfig`` whose ``gpt2_config`` attribute is
                either a ``GPT2Config`` or a plain dict of its fields.
        """
        super().__init__(config)
        self.vision_encoder = VisionEncoder()

        # Accept both a ready config object and a raw dict of its fields.
        # NOTE(review): presumably the dict form shows up when the config is
        # loaded from JSON — confirm.
        if isinstance(config.gpt2_config, dict):
            gpt2_config = GPT2Config(**config.gpt2_config)
        else:
            gpt2_config = config.gpt2_config
        self.text_model = GPT2LMHeadModel(gpt2_config)

        self.tokenizer = tokenizer
        # GPT-2 has no dedicated pad token; reuse EOS for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # The tokenizer gained the ``<image>`` special token, so the text
        # model's embedding matrix must grow to cover the new id. This can
        # only be done once ``text_model`` has been constructed.
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    @property
    def device(self):
        """Device of the text model; all new tensors are created here."""
        return self.text_model.device

    def encode_image(self, image, device=None):
        """Encode ``image`` with the vision encoder.

        Args:
            image: Input accepted by ``VisionEncoder`` (its expected format is
                defined in ``vision_encoder``, not here).
            device: Device forwarded to the encoder; defaults to
                ``self.device`` when omitted.

        Returns:
            Whatever ``VisionEncoder`` returns — the image embeddings later
            passed to ``generate`` / ``answer_question``.
        """
        if device is None:
            device = self.device
        return self.vision_encoder(image, device=device)

    def input_embeds(self, prompt, image_embeds, tokenizer):
        """Build the input-embedding sequence for ``prompt``.

        The sequence is the BOS embedding followed by the embedded prompt. If
        the prompt contains the ``<image>`` placeholder, ``image_embeds`` is
        inserted right after it and the remainder of the prompt is prefixed
        with ``</image>``.

        Args:
            prompt: Text prompt containing at most one ``<image>``.
            image_embeds: Image embeddings, concatenated along dim 1.
                NOTE(review): assumed to be ``(1, n_tokens, hidden)`` —
                confirm against ``VisionEncoder``.
            tokenizer: Tokenizer used to encode the text pieces.

        Returns:
            Tensor of all embeddings concatenated along dim 1.

        Raises:
            AssertionError: If the prompt contains more than one ``<image>``
                (only while assertions are enabled).
        """
        text_emb = self.text_model.get_input_embeddings()

        def _embed_text(txt):
            # BOS is added by hand below, so the tokenizer must not add any
            # special tokens of its own.
            ids = tokenizer(
                txt, return_tensors="pt", add_special_tokens=False
            ).input_ids
            return text_emb(ids.to(self.device))

        bos_ids = torch.tensor([[tokenizer.bos_token_id]], device=self.device)
        embeds = [text_emb(bos_ids)]

        if IMAGE_TOKEN not in prompt:
            embeds.append(_embed_text(prompt))
        else:
            assert prompt.count(IMAGE_TOKEN) == 1, (
                "prompt must contain exactly one <image> placeholder"
            )
            before, after = prompt.split(IMAGE_TOKEN)
            embeds.append(_embed_text(f"{before}{IMAGE_TOKEN}"))
            embeds.append(image_embeds.to(self.device))
            embeds.append(_embed_text(f"</image>{after}"))
        return torch.cat(embeds, dim=1)

    def generate(
        self,
        image_embeds,
        prompt,
        tokenizer,
        eos_text=ANSWER_EOS,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate text for ``prompt`` conditioned on ``image_embeds``.

        Args:
            image_embeds: Embeddings spliced in at the ``<image>`` placeholder.
            prompt: Text prompt (see ``input_embeds``).
            tokenizer: Tokenizer used for both encoding and decoding.
            eos_text: Text whose token id(s) stop generation.
            max_new_tokens: Maximum number of tokens to generate.
            **kwargs: Extra options forwarded to ``text_model.generate``; they
                override the defaults assembled here.

        Returns:
            List of decoded strings (special tokens skipped), one per output
            sequence.
        """
        eos_tokens = tokenizer(eos_text, add_special_tokens=False)["input_ids"]
        generate_config = {
            "eos_token_id": eos_tokens,
            "bos_token_id": tokenizer.bos_token_id,
            "pad_token_id": tokenizer.eos_token_id,
            "max_new_tokens": max_new_tokens,
            **kwargs,
        }
        with torch.no_grad():
            inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
            output_ids = self.text_model.generate(
                inputs_embeds=inputs_embeds, **generate_config
            )
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer,
        chat_history="",
        result_queue=None,
        **kwargs,
    ):
        """Answer ``question`` about the image described by ``image_embeds``.

        Args:
            image_embeds: Embeddings of the image being asked about.
            question: The question text.
            tokenizer: Tokenizer used for encoding and decoding.
            chat_history: Prior conversation text, inserted verbatim before
                the question.
            result_queue: Optional object with a ``put`` method; when given,
                the answer is also put onto it.
            **kwargs: Forwarded to ``generate``. They may override
                ``eos_text`` (default ``ANSWER_EOS``) and ``max_new_tokens``
                (default 256).

        Returns:
            The decoded answer string.
        """
        prompt = f"<image>\n\n{chat_history}Question: {question}\n\nAnswer: "
        # setdefault (rather than explicit keywords) lets callers override
        # these without a "got multiple values for argument" TypeError.
        kwargs.setdefault("eos_text", ANSWER_EOS)
        kwargs.setdefault("max_new_tokens", 256)
        answer = self.generate(image_embeds, prompt, tokenizer, **kwargs)[0]
        if result_queue is not None:
            result_queue.put(answer)
        return answer