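"""GPT2Vision: a minimal vision-language model built on GPT-2.

Image features from a VisionEncoder are projected by a small MLP into the
GPT-2 embedding space and spliced into the token-embedding sequence right
after a dedicated <image> placeholder token, so the language model can
condition its text generation on the image.
"""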
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoTokenizer
from .configuration_gpt2vision import GPT2VisionConfig
from .vision_encoder import VisionEncoder
from .modeling_gpt2 import GPT2LMHeadModel
IMAGE_TOKEN = "<image>"
ANSWER_EOS = "<|endoftext|>"


def resize_token_embeds(model_name="openai-community/gpt2"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    new_tokens = {
        "additional_special_tokens": [IMAGE_TOKEN]
    }
    tokenizer.add_special_tokens(new_tokens)
    return tokenizer


tokenizer = resize_token_embeds()


class MLP(nn.Module):
    def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate="tanh")
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x


class GPT2Vision(PreTrainedModel):
    config_class = GPT2VisionConfig

    def __init__(self, config):
        super().__init__(config)
        self.vision_encoder = VisionEncoder()
        # Project 768-dim vision features into the GPT-2 embedding space.
        self.mlp = MLP(in_features=768, hidden_features=768 * 4, out_features=768)
        self.language_model = GPT2LMHeadModel(config.gpt2_config)
        # Resize the embedding matrix to account for the added <image> special token.
        self.language_model.resize_token_embeddings(len(tokenizer))
        self.tokenizer = tokenizer
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

    @property
    def device(self):
        return next(self.language_model.parameters()).device

    def preprocess_inputs(self, batch):
        img_embs = batch['pixel_values']
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        img_embs = img_embs.to(self.device)
        tok_embs = self.language_model.get_input_embeddings()(input_ids)
        # Splice the projected image embeddings in directly after the leading <image> token.
        inputs_embeds = torch.cat((tok_embs[:, 0:1, :], img_embs, tok_embs[:, 1:, :]), dim=1)
        # Extend the attention mask to cover the inserted image positions.
        img_attention = torch.ones((img_embs.size(0), img_embs.size(1)), dtype=torch.long, device=self.device)
        attention_mask = torch.cat((attention_mask[:, 0:1], img_attention, attention_mask[:, 1:]), dim=1)
        return inputs_embeds, attention_mask, input_ids
    def generate(self, question, image, max_new_tokens=30, **kwargs):
        # Encode the image and project its features into the language model's embedding space.
        with torch.no_grad():
            img_features = self.vision_encoder(image, device=self.device)
            img_embs = self.mlp(img_features)
        # Tokenize the question, prefixed with the <image> placeholder token.
        prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
        encoded_input = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
        batch = {
            "pixel_values": img_embs,
            "input_ids": encoded_input.input_ids.to(self.device),
            "attention_mask": encoded_input.attention_mask.to(self.device)
        }
        inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
        output_sequences = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            **kwargs
        )
        output = self.tokenizer.decode(output_sequences[0], skip_special_tokens=True)
        return output