# Provenance: uploaded by austinmw ("Upload tool", commit 298d752).
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
from transformers.tools import PipelineTool
from transformers.tools.base import get_default_device
from transformers.utils import requires_backends
class InstructBLIPImageQuestionAnsweringTool(PipelineTool):
    """Pipeline tool that answers an English question about an image.

    The model is loaded through `model_class.from_pretrained` with 4-bit
    loading and float16 as defaults (see `setup`), and answers are generated
    with beam search (see `forward`).
    """

    # Other checkpoints previously listed here, kept for reference:
    #   "Salesforce/blip2-opt-2.7b"
    #   "Salesforce/instructblip-flan-t5-xl"
    #   "Salesforce/instructblip-vicuna-7b"
    default_checkpoint = "Salesforce/instructblip-vicuna-13b"
    description = (
        "This is a tool that answers a question about an image. It takes an input named `image` which should be the "
        "image containing the information, as well as a `question` which should be the question in English. It "
        "returns a text that is the answer to the question."
    )
    name = "image_qa"
    pre_processor_class = AutoProcessor
    model_class = AutoModelForVision2Seq

    inputs = ["image", "text"]
    outputs = ["text"]

    def __init__(self, *args, **kwargs):
        """Check that the vision backend is available, then defer to `PipelineTool`."""
        requires_backends(self, ["vision"])
        super().__init__(*args, **kwargs)

    def setup(self):
        """
        Instantiates the `pre_processor`, `model` and `post_processor` if necessary.

        Any of the three that is still a string is resolved with the matching
        class's `from_pretrained`. The model is loaded with `load_in_4bit=True`
        and `torch_dtype=torch.float16` unless `model_kwargs` / `hub_kwargs`
        provide their own values for those keys. If no device was given, it is
        taken from the model's `hf_device_map` when a `device_map` is set, and
        from `get_default_device()` otherwise.
        """
        if isinstance(self.pre_processor, str):
            self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)

        if isinstance(self.model, str):
            # Quantization settings are defaults, not forced keywords: passing
            # them literally next to **self.model_kwargs would raise a
            # "multiple values for keyword argument" TypeError as soon as a
            # caller supplied either key.
            load_kwargs = {"load_in_4bit": True, "torch_dtype": torch.float16}
            load_kwargs.update(self.model_kwargs)
            load_kwargs.update(self.hub_kwargs)
            self.model = self.model_class.from_pretrained(self.model, **load_kwargs)

        if self.post_processor is None:
            self.post_processor = self.pre_processor
        elif isinstance(self.post_processor, str):
            self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)

        if self.device is None:
            if self.device_map is not None:
                # Use the device of the first entry in the model's device map.
                self.device = list(self.model.hf_device_map.values())[0]
            else:
                self.device = get_default_device()

        # NOTE(review): unlike a plain (non-quantized) load, the model is never
        # moved with `.to(self.device)` here — presumably because 4-bit models
        # are placed at load time; confirm before adding such a call.
        self.is_initialized = True

    def encode(self, image, question: str):
        """Turn an image and a question into model inputs.

        Args:
            image: The image to ask about; passed to the pre-processor as `images`.
            question: The question text; passed to the pre-processor as `text`.

        Returns:
            The pre-processor output as PyTorch tensors, moved to `self.device`
            and cast with `dtype=torch.float16`.
        """
        features = self.pre_processor(images=image, text=question, return_tensors="pt")
        # Follow the device resolved in `setup` instead of assuming "cuda".
        # float16 matches the default `torch_dtype` used when loading the model.
        return features.to(device=self.device, dtype=torch.float16)

    def forward(self, inputs):
        """Run `self.model.generate` on the encoded inputs and return its output."""
        # NOTE(review): `top_p` and `temperature` are sampling parameters, but
        # `do_sample` is not set here — whether they take effect depends on the
        # model's generation config; confirm this is intended.
        outputs = self.model.generate(
            **inputs,
            num_beams=5,
            max_new_tokens=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=0.7,
        )
        return outputs

    def decode(self, outputs):
        """Decode the first generated sequence to text, dropping special tokens and outer whitespace."""
        return self.pre_processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()