Update modeling_gpt2vision.py
Browse files- modeling_gpt2vision.py +7 -4
modeling_gpt2vision.py
CHANGED
@@ -68,8 +68,11 @@ class GPT2Vision(PreTrainedModel):
|
|
68 |
|
69 |
def generate(self, question, image, max_new_tokens=30, **kwargs):
|
70 |
# Process the image
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
73 |
|
74 |
# Tokenize the question
|
75 |
prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
|
@@ -77,8 +80,8 @@ class GPT2Vision(PreTrainedModel):
|
|
77 |
|
78 |
batch = {
|
79 |
"pixel_values": img_embs,
|
80 |
-
"input_ids": encoded_input.input_ids,
|
81 |
-
"attention_mask": encoded_input.attention_mask
|
82 |
}
|
83 |
|
84 |
inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
|
|
|
68 |
|
69 |
def generate(self, question, image, max_new_tokens=30, **kwargs):
|
70 |
# Process the image
|
71 |
+
# Convert the image to a tensor and add a batch dimension
|
72 |
+
image_tensor = self.vision_encoder.image_transform(image).unsqueeze(0).to(self.device)
|
73 |
+
with torch.no_grad():
|
74 |
+
img_features = self.vision_model(image_tensor).last_hidden_state
|
75 |
+
img_embs = self.mlp(img_features)
|
76 |
|
77 |
# Tokenize the question
|
78 |
prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
|
|
|
80 |
|
81 |
batch = {
|
82 |
"pixel_values": img_embs,
|
83 |
+
"input_ids": encoded_input.input_ids.to(self.device),
|
84 |
+
"attention_mask": encoded_input.attention_mask.to(self.device)
|
85 |
}
|
86 |
|
87 |
inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
|