Spaces:
Sleeping
Sleeping
using our new model
Browse files- inference.py +9 -8
inference.py
CHANGED
@@ -1,13 +1,15 @@
|
|
1 |
-
from transformers import ViltProcessor, ViltForQuestionAnswering,
|
2 |
from transformers.utils import logging
|
3 |
|
|
|
|
|
4 |
class Inference:
|
5 |
def __init__(self):
|
6 |
self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
7 |
self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
8 |
|
9 |
-
self.
|
10 |
-
self.
|
11 |
logging.set_verbosity_info()
|
12 |
self.logger = logging.get_logger("transformers")
|
13 |
|
@@ -17,8 +19,6 @@ class Inference:
|
|
17 |
return self.__inference_vilt(image, text)
|
18 |
elif selected == "Model 2":
|
19 |
return self.__inference_deplot(image, text)
|
20 |
-
elif selected == "Model 3":
|
21 |
-
return self.__inference_vilt(image, text)
|
22 |
else:
|
23 |
self.logger.warning("Please select a model to make the inference..")
|
24 |
|
@@ -30,6 +30,7 @@ class Inference:
|
|
30 |
return f"{self.vilt_model.config.id2label[idx]}"
|
31 |
|
32 |
def __inference_deplot(self, image, text):
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
|
1 |
+
from transformers import ViltProcessor, ViltForQuestionAnswering, BlipProcessor, BlipForQuestionAnswering
|
2 |
from transformers.utils import logging
|
3 |
|
4 |
+
import torch
|
5 |
+
|
6 |
class Inference:
|
7 |
def __init__(self):
    """Load the VQA models (ViLT and a fashion-finetuned BLIP) and set up logging.

    NOTE: ``from_pretrained`` downloads weights from the Hugging Face hub on
    first use, so construction requires network access (or a warm cache).
    """
    # ViLT pipeline for generic visual question answering.
    self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
    self.vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")

    # BLIP: processor comes from the base checkpoint, while the model weights
    # are a fashion-domain fine-tune of it.
    self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
    # Fix: the original unconditionally called .to("cuda"), which raises on
    # CPU-only hosts. Fall back to CPU when no GPU is available.
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.blip_model = BlipForQuestionAnswering.from_pretrained(
        "wiusdy/blip_pretrained_saffal_fashion_finetuning"
    ).to(self.device)

    logging.set_verbosity_info()
    self.logger = logging.get_logger("transformers")
|
15 |
|
|
|
19 |
return self.__inference_vilt(image, text)
|
20 |
elif selected == "Model 2":
|
21 |
return self.__inference_deplot(image, text)
|
|
|
|
|
22 |
else:
|
23 |
self.logger.warning("Please select a model to make the inference..")
|
24 |
|
|
|
30 |
return f"{self.vilt_model.config.id2label[idx]}"
|
31 |
|
32 |
def __inference_deplot(self, image, text):
    """Answer the question *text* about *image* with the fine-tuned BLIP model.

    Returns the decoded answer string (special tokens stripped).

    NOTE(review): despite the name, this method runs BLIP, not DePlot —
    presumably a leftover name from an earlier model; confirm with the UI
    labels before renaming (callers dispatch on this name).
    """
    # Keep the inputs on the same device as the model instead of hard-coding
    # "cuda:0", so this also works after a CPU fallback or multi-GPU placement.
    device = self.blip_model.device
    # Bug fix: the original cast the inputs to torch.float16 while the model
    # weights stayed float32, which raises a dtype-mismatch error inside
    # BLIP's generate(). Leave the inputs in the model's dtype.
    encoding = self.blip_processor(image, text, return_tensors="pt").to(device)
    out = self.blip_model.generate(**encoding)
    generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
    return f"{generated_text}"
|