# blip2-endpoint/handler.py
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
from io import BytesIO
from transformers import pipeline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class EndpointHandler():
    def __init__(self, path="Salesforce/blip2-opt-6.7b-coco"):
        # Build an image-to-text pipeline directly from the checkpoint path.
        # BLIP-2 is not a plain encoder-decoder checkpoint, so
        # AutoModelForSeq2SeqLM cannot load it; letting the pipeline resolve
        # the model, tokenizer, and image processor avoids that mismatch and
        # handles image preprocessing internally.
        self.image_to_text_pipeline = pipeline('image-to-text', model=path, device=device)
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): {"image": <base64-encoded image string>}
            parameters (:obj:`dict`, optional): generation kwargs forwarded to the pipeline
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned
        """
        # Extract inputs and optional generation parameters from the request
        inputs = data["inputs"]
        parameters = data.pop("parameters", None)
        # Decode the base64 payload into a PIL image; the pipeline performs
        # its own resizing and normalization, so no manual torchvision
        # transform is needed (and passing a tensor here would fail)
        image = Image.open(BytesIO(base64.b64decode(inputs["image"]))).convert("RGB")
        # Run the model for prediction
        if parameters is not None:
            predictions = self.image_to_text_pipeline(image, **parameters)
        else:
            predictions = self.image_to_text_pipeline(image)
        return predictions
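

# A minimal local smoke test for the handler above. This is an illustrative
# sketch, not part of the deployed endpoint: it assumes a sample image at the
# hypothetical path "test.jpg" and mirrors the JSON payload shape that a
# Hugging Face Inference Endpoint would send to __call__.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("test.jpg", "rb") as f:  # hypothetical sample image
        encoded = base64.b64encode(f.read()).decode("utf-8")
    payload = {
        "inputs": {"image": encoded},
        "parameters": {"max_new_tokens": 30},  # forwarded to the pipeline
    }
    # Expected output shape: [{"generated_text": "..."}]
    print(handler(payload))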