import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from pathlib import Path

from blip_vqa import blip_vqa

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384


class App:
    def __init__(self):
        # Load BLIP for visual question answering from a local checkpoint
        print("Loading BLIP for question answering")
        model_path = str(Path(__file__).parent / 'blip_vqa.pth')
        self.qa_model = blip_vqa(pretrained=model_path, image_size=image_size, vit='base')
        self.qa_model.eval()
        self.qa_model = self.qa_model.to(device)

        # Build the UI: image input on the left, question/answer tab on the right
        with gr.Blocks() as demo:
            with gr.Row():
                self.image_source = gr.Image(type="numpy", label="Image")
                with gr.Tabs():
                    with gr.Tab("Question/Answer"):
                        self.question = gr.Textbox(label="Custom question (if applicable)",
                                                   value="where is the right hand?")
                        self.answer = gr.Button("Ask")
                        self.lbl_answer = gr.Label(label="Answer")
                        self.answer.click(self.answer_question_image,
                                          [self.image_source, self.question],
                                          self.lbl_answer)

        # Launch the interface
        demo.launch()

    def answer_question_image(self, img, custom_question="Describe this image"):
        # Preprocess the image the way BLIP expects: bicubic resize to 384x384,
        # tensor conversion, and CLIP-style normalization
        preprocess = transforms.Compose([
            transforms.Resize((image_size, image_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        img = preprocess(Image.fromarray(img.astype('uint8'), 'RGB'))

        # Generate an answer; no gradients are needed at inference time
        with torch.no_grad():
            output = self.qa_model(img.unsqueeze(0).to(device), custom_question,
                                   train=False, inference='generate')

        # blip_vqa returns a list of generated answers; return the first one
        return output[0]


if __name__ == '__main__':
    app = App()
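# --- Optional headless smoke test (a sketch; 'test.jpg' is a hypothetical file) ---
# App.__init__ launches the blocking Gradio UI, so for a quick check without the
# browser, comment out demo.launch() above and run:
#
#   import numpy as np
#   app = App()
#   img = np.array(Image.open('test.jpg').convert('RGB'))
#   print(app.answer_question_image(img, "what is in the picture?"))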