File size: 1,772 Bytes
a48ebf4
6e61211
0a5203f
c5797b7
 
a48ebf4
24eb62a
 
80ba3ac
 
c5797b7
ad1c334
90d3a07
ad1c334
 
 
 
 
 
c5797b7
e727785
24eb62a
fb006ce
e727785
90d3a07
e727785
 
 
24eb62a
fb006ce
0306e1b
90d3a07
0306e1b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from transformers import AutoProcessor, BlipForQuestionAnswering
from transformers.utils import logging

class Inference:
    """Answer questions about an image with one of two fine-tuned BLIP models.

    A single processor (loaded from ``Salesforce/blip-vqa-base``) is shared by
    both question-answering checkpoints; callers pick the checkpoint per
    request through the ``selected`` label passed to :meth:`inference`.
    """

    def __init__(self):
        """Load the shared processor and both checkpoints, and set up logging."""
        self.blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-vqa-base")
        self.blip_model_saffal = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_saffal_fashion_finetuning")
        self.blip_model_control_net = BlipForQuestionAnswering.from_pretrained("wiusdy/blip_pretrained_control_net_fashion_finetuning")
        # NOTE: this changes the verbosity of the transformers logging
        # facility as a whole, not just of this instance's logger.
        logging.set_verbosity_info()
        self.logger = logging.get_logger("transformers")

    def inference(self, selected, image, text):
        """Run visual question answering with the model named by ``selected``.

        Args:
            selected: Model label; ``"Blip Saffal"`` and ``"Blip CN"`` are the
                only recognised values.
            image: Image handed unchanged to the processor (presumably a PIL
                image or array accepted by the BLIP processor -- confirm
                against the caller).
            text: Question about the image.

        Returns:
            The decoded answer, or ``None`` (after logging a warning) when
            ``selected`` matches neither label.
        """
        self.logger.info("selected model %s, question %s", selected, text)
        # Plain equality checks (rather than a dict lookup) so that an
        # unhashable or otherwise unexpected ``selected`` still falls through
        # to the warning instead of raising.
        if selected == "Blip Saffal":
            return self.__inference_saffal_blip(image, text)
        if selected == "Blip CN":
            return self.__inference_control_net_blip(image, text)
        self.logger.warning("Please select a model to make the inference..")
        return None

    def _generate_answer(self, model, image, text):
        """Encode the inputs, generate with ``model`` and decode the answer.

        Shared by both model-specific entry points so the pre/post-processing
        cannot drift apart between the two checkpoints.
        """
        encoding = self.blip_processor(image, text, return_tensors="pt")
        # max_new_tokens=100 caps the number of tokens generated per answer.
        out = model.generate(**encoding, max_new_tokens=100)
        # Only the first generated sequence is decoded.
        generated_text = self.blip_processor.decode(out[0], skip_special_tokens=True)
        self.logger.info("answer %s", generated_text)
        return generated_text

    def __inference_saffal_blip(self, image, text):
        """Answer ``text`` about ``image`` using the Saffal checkpoint."""
        return self._generate_answer(self.blip_model_saffal, image, text)

    def __inference_control_net_blip(self, image, text):
        """Answer ``text`` about ``image`` using the ControlNet checkpoint."""
        return self._generate_answer(self.blip_model_control_net, image, text)