teowu committed
Commit 01730ee
1 Parent(s): b87e7e0

Update modeling_mplug_owl2.py

Files changed (1):
  modeling_mplug_owl2.py (+5 -0)
modeling_mplug_owl2.py CHANGED
@@ -270,6 +270,7 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
     def score(self, images,
               task_: str = "quality",
               input_: str = "image",
+              return_dict=False,
               ):
         if not hasattr(self, "weight_tensor"):
             self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(self.device)
@@ -281,6 +282,8 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
             image_tensor = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"].half().to(self.device)
             output_logits = self(input_ids.repeat(image_tensor.shape[0], 1),
                         images=image_tensor)["logits"][:,-1, self.preferential_ids_]
+            if return_dict:
+                return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
             return torch.softmax(output_logits, -1) @ self.weight_tensor
         else:
             video = [[expand2square(frame, tuple(int(x*255) for x in self.image_processor.image_mean)) for frame in vid] for vid in images]
@@ -289,6 +292,8 @@ class MPLUGOwl2LlamaForCausalLM(LlamaForCausalLM, MPLUGOwl2MetaForCausalLM):
             video_tensors = [self.image_processor.preprocess(vid, return_tensors="pt")["pixel_values"].half().to(self.model.device) for vid in video]
             output_logits = self(input_ids.repeat(len(video_tensors), 1),
                         images=video_tensors)["logits"][:,-1, self.preferential_ids_]
+            if return_dict:
+                return {"logits": output_logits, "scores": torch.softmax(output_logits, -1) @ self.weight_tensor}
             return torch.softmax(output_logits, -1) @ self.weight_tensor
 
     def forward(
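The change adds an optional return_dict flag to score(): when set, the method returns both the raw logits over the five rating levels and the softmax-weighted score, instead of the score alone; with the default return_dict=False the return value is unchanged, so existing callers are unaffected. A minimal usage sketch follows, assuming the custom modeling file is loaded via AutoModelForCausalLM with trust_remote_code=True; the repository id and image path below are placeholders, not taken from this commit.

# Minimal sketch, not the repository's documented API: the repo id and
# image path are placeholders, and loading options may differ in practice.
import torch
from PIL import Image
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "q-future/one-align",        # placeholder repo id
    trust_remote_code=True,      # picks up modeling_mplug_owl2.py
    torch_dtype=torch.float16,
    device_map="auto",
)

images = [Image.open("example.jpg")]  # placeholder image path

# Default behaviour (unchanged): one weighted score per input image.
scores = model.score(images, task_="quality", input_="image")

# New behaviour: also expose the logits over the five rating levels.
out = model.score(images, task_="quality", input_="image", return_dict=True)
print(out["logits"].shape, out["scores"])

Keeping the plain-tensor return as the default preserves backward compatibility while letting callers who need the full rating distribution read it from out["logits"].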