wiusdy committed on
Commit
0306e1b
1 Parent(s): 629822e

Use our new fine-tuned BLIP model

Browse files
Files changed (1) hide show
  1. inference.py +9 -8
inference.py CHANGED
@@ -1,13 +1,15 @@
1
- from transformers import ViltProcessor, ViltForQuestionAnswering, Pix2StructProcessor, Pix2StructForConditionalGeneration
2
  from transformers.utils import logging
3
 
 
 
4
  class Inference:
5
  def __init__(self):
6
  self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
7
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
8
 
9
- self.deplot_processor = Pix2StructProcessor.from_pretrained('google/deplot')
10
- self.deplot_model = Pix2StructForConditionalGeneration.from_pretrained('google/deplot')
11
  logging.set_verbosity_info()
12
  self.logger = logging.get_logger("transformers")
13
 
@@ -17,8 +19,6 @@ class Inference:
17
  return self.__inference_vilt(image, text)
18
  elif selected == "Model 2":
19
  return self.__inference_deplot(image, text)
20
- elif selected == "Model 3":
21
- return self.__inference_vilt(image, text)
22
  else:
23
  self.logger.warning("Please select a model to make the inference..")
24
 
@@ -30,6 +30,7 @@ class Inference:
30
  return f"{self.vilt_model.config.id2label[idx]}"
31
 
32
  def __inference_deplot(self, image, text):
33
- inputs = self.deplot_processor(images=image, text=text, return_tensors="pt")
34
- predictions = self.deplot_model.generate(**inputs, max_new_tokens=512)
35
- return f"{self.deplot_processor.decode(predictions[0], skip_special_tokens=True)}"
 
 
1
+ from transformers import ViltProcessor, ViltForQuestionAnswering, BlipProcessor, BlipForQuestionAnswering
2
  from transformers.utils import logging
3
 
4
+ import torch
5
+
6
  class Inference:
7
  def __init__(self):
8
  self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
9
  self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
10
 
11
+ self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
12
+ self.blip_model = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning").to("cuda")
13
  logging.set_verbosity_info()
14
  self.logger = logging.get_logger("transformers")
15
 
 
19
  return self.__inference_vilt(image, text)
20
  elif selected == "Model 2":
21
  return self.__inference_deplot(image, text)
 
 
22
  else:
23
  self.logger.warning("Please select a model to make the inference..")
24
 
 
30
  return f"{self.vilt_model.config.id2label[idx]}"
31
 
32
  def __inference_deplot(self, image, text):
33
+ encoding = self.blip_processor(image, text, return_tensors="pt").to("cuda:0", torch.float16)
34
+ out = self.blip_model.generate(**encoding)
35
+ generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
36
+ return f"{generated_text}"