Spaces:
Sleeping
Sleeping
making the model comparison
Browse files- app.py +4 -5
- inference.py +24 -17
app.py
CHANGED
@@ -7,17 +7,16 @@ inference = Inference()
|
|
7 |
|
8 |
|
9 |
with gr.Blocks() as block:
|
10 |
-
options = gr.Dropdown(choices=["ViLT", "Blip Saffal", "Blip CN"], label="Models", info="Select the model to use..", )
|
11 |
-
# need to improve this one...
|
12 |
-
|
13 |
txt = gr.Textbox(label="Insert a question..", lines=2)
|
14 |
-
|
|
|
|
|
15 |
btn = gr.Button(value="Submit")
|
16 |
|
17 |
dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
|
18 |
image = gr.Image(type="pil", value=dogs)
|
19 |
|
20 |
-
btn.click(inference.inference, inputs=[
|
21 |
|
22 |
if __name__ == "__main__":
|
23 |
block.launch()
|
|
|
7 |
|
8 |
|
9 |
# Comparison UI: one question box, four answer boxes (one per model),
# a submit button and a sample image that is preloaded from disk.
with gr.Blocks() as block:
    txt = gr.Textbox(label="Insert a question..", lines=2)

    # `gr.outputs.*` is the legacy Interface-era namespace (deprecated in
    # Gradio 3.x, removed in 4.x); plain `gr.Textbox` components serve as
    # outputs inside Blocks.
    # The order here must match the tuple returned by `inference.inference`:
    # BLIP saffal, BLIP control net, ViLT saffal, ViLT control net.
    outputs = [
        gr.Textbox(label="Answer from BLIP saffal model"),
        gr.Textbox(label="Answer from BLIP control net"),
        gr.Textbox(label="Answer from ViLT saffal model"),
        gr.Textbox(label="Answer from ViLT control net"),
    ]

    btn = gr.Button(value="Submit")

    # Default image shipped next to this script.
    dogs = os.path.join(os.path.dirname(__file__), "617.jpg")
    image = gr.Image(type="pil", value=dogs)

    # A single click runs every model and fills all four answer boxes.
    btn.click(inference.inference, inputs=[image, txt], outputs=outputs)

if __name__ == "__main__":
    block.launch()
|
inference.py
CHANGED
@@ -6,31 +6,38 @@ import torch
|
|
6 |
class Inference:
|
7 |
def __init__(self):
|
8 |
self.vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
9 |
-
self.
|
|
|
10 |
|
11 |
self.blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
12 |
self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
|
13 |
self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
|
|
|
14 |
logging.set_verbosity_info()
|
15 |
self.logger = logging.get_logger("transformers")
|
16 |
|
17 |
-
def inference(self,
|
18 |
-
self.logger.info(f"
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
encoding = self.vilt_processor(image, text, return_tensors="pt")
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
return f"{self.vilt_model.config.id2label[idx]}"
|
34 |
|
35 |
def __inference_saffal_blip(self, image, text):
|
36 |
encoding = self.blip_processor(image, text, return_tensors="pt")
|
|
|
6 |
class Inference:
|
7 |
def __init__(self):
    """Load the processors and fine-tuned checkpoints, then set up logging.

    Everything is fetched with `from_pretrained`, in the same order as the
    attributes are listed below.
    """
    load_blip_qa = BlipForQuestionAnswering.from_pretrained

    # "ViLT" group.
    # NOTE(review): both "vilt" checkpoints are loaded through
    # BlipForQuestionAnswering while their inputs come from ViltProcessor --
    # confirm this pairing is intended.
    self.vilt_processor = ViltProcessor.from_pretrained(
        "dandelin/vilt-b32-finetuned-vqa"
    )
    self.vilt_model_saffal = load_blip_qa("wiusdy/vilt_saffal_model")
    self.vilt_model_control_net = load_blip_qa("wiusdy/vilt_control_net")

    # BLIP group.
    self.blip_processor = BlipProcessor.from_pretrained(
        "Salesforce/blip-vqa-base"
    )
    self.blip_model_saffal = load_blip_qa(
        "wiusdy/blip_pretrained_saffal_fashion_finetuning"
    )
    self.blip_model_control_net = load_blip_qa(
        "wiusdy/blip_pretrained_control_net_fashion_finetuning"
    )

    # Raise verbosity to INFO and keep a handle on the "transformers" logger.
    logging.set_verbosity_info()
    self.logger = logging.get_logger("transformers")
|
18 |
|
19 |
+
def inference(self, image, text):
    """Ask every model the same question about the same image.

    Args:
        image: Image passed unchanged to each per-model helper.
        text: Question passed unchanged to each per-model helper.

    Returns:
        A 4-tuple of answers ordered as (BLIP saffal, BLIP control net,
        ViLT saffal, ViLT control net). This differs from the execution
        order, which runs the two ViLT helpers first.
    """
    # Plain string literals: none of these messages interpolate a value.
    self.logger.info("Running inference for model ViLT Saffal")
    vilt_saffal_answer = self.__inference_vilt_saffal(image, text)

    self.logger.info("Running inference for model ViLT Control Net")
    vilt_control_net_answer = self.__inference_vilt_control_net(image, text)

    self.logger.info("Running inference for model BLIP Saffal")
    blip_saffal_answer = self.__inference_saffal_blip(image, text)

    self.logger.info("Running inference for model BLIP Control Net")
    blip_control_net_answer = self.__inference_control_net_blip(image, text)

    return (
        blip_saffal_answer,
        blip_control_net_answer,
        vilt_saffal_answer,
        vilt_control_net_answer,
    )
|
29 |
+
|
30 |
+
def __inference_vilt_saffal(self, image, text):
|
31 |
+
encoding = self.vilt_processor(image, text, return_tensors="pt")
|
32 |
+
out = self.vilt_model_saffal.generate(**encoding)
|
33 |
+
generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
|
34 |
+
return f"{generated_text}"
|
35 |
+
|
36 |
+
def __inference_vilt_control_net(self, image, text):
|
37 |
encoding = self.vilt_processor(image, text, return_tensors="pt")
|
38 |
+
out = self.vilt_model_control_net.generate(**encoding)
|
39 |
+
generated_text = self.vilt_processor.decode(out[0], skip_special_tokens=True)
|
40 |
+
return f"{generated_text}"
|
|
|
41 |
|
42 |
def __inference_saffal_blip(self, image, text):
|
43 |
encoding = self.blip_processor(image, text, return_tensors="pt")
|