damerajee committed on
Commit
ccf25db
·
verified ·
1 Parent(s): 0b8178d

Update modeling_gpt2vision.py

Browse files
Files changed (1) hide show
  1. modeling_gpt2vision.py +7 -4
modeling_gpt2vision.py CHANGED
@@ -68,8 +68,11 @@ class GPT2Vision(PreTrainedModel):
68
 
69
  def generate(self, question, image, max_new_tokens=30, **kwargs):
70
  # Process the image
71
- img_embs = self.vision_encoder(image, device=self.device)
72
- img_embs = self.mlp(img_embs)
 
 
 
73
 
74
  # Tokenize the question
75
  prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
@@ -77,8 +80,8 @@ class GPT2Vision(PreTrainedModel):
77
 
78
  batch = {
79
  "pixel_values": img_embs,
80
- "input_ids": encoded_input.input_ids,
81
- "attention_mask": encoded_input.attention_mask
82
  }
83
 
84
  inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)
 
68
 
69
  def generate(self, question, image, max_new_tokens=30, **kwargs):
70
  # Process the image
71
+ # Convert the image to a tensor and add a batch dimension
72
+ image_tensor = self.vision_encoder.image_transform(image).unsqueeze(0).to(self.device)
73
+ with torch.no_grad():
74
+ img_features = self.vision_model(image_tensor).last_hidden_state
75
+ img_embs = self.mlp(img_features)
76
 
77
  # Tokenize the question
78
  prompt = f"{IMAGE_TOKEN}Question: {question}\nAnswer:"
 
80
 
81
  batch = {
82
  "pixel_values": img_embs,
83
+ "input_ids": encoded_input.input_ids.to(self.device),
84
+ "attention_mask": encoded_input.attention_mask.to(self.device)
85
  }
86
 
87
  inputs_embeds, attention_mask, input_ids = self.preprocess_inputs(batch)