blip_captioning / pipeline.py
philschmid's picture
philschmid HF staff
Update pipeline.py
8085a1b
raw history blame
No virus
2.14 kB
import base64
import os
from io import BytesIO
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from blip import blip_decoder
# Pick the compute device once at import time: the GPU when torch reports
# CUDA as available, the CPU otherwise. The choice is echoed to stdout.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
class PreTrainedPipeline():
    """BLIP image-captioning pipeline: base64-encoded image in, caption out.

    The model and the preprocessing transform are built once in
    ``__init__``; every call decodes a single image, preprocesses it and
    runs ``self.model.generate`` on the resulting one-image batch.
    """

    def __init__(self, path=""):
        """Build the captioning model and the image preprocessing transform.

        Args:
            path: Directory that contains ``configs/med_config.json``. It is
                joined with that relative path and handed to
                ``blip_decoder`` as ``med_config``.
        """
        # One value feeds both the model constructor and the resize step so
        # the two cannot drift apart.
        image_size = 384

        # Checkpoint location handed to ``blip_decoder`` as ``pretrained``
        # (presumably fetched by blip on construction -- confirm in blip.py).
        self.model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
        self.model = blip_decoder(
            pretrained=self.model_url,
            image_size=image_size,
            vit='large',
            med_config=os.path.join(path, 'configs/med_config.json'),
        )
        # Inference only: switch to eval mode, then move to the module-level
        # ``device``.
        self.model.eval()
        self.model = self.model.to(device)

        # Resize to a square model input, convert to a tensor, then
        # normalise with per-channel (mean, std). Three values each, so the
        # input image must have exactly three channels.
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def __call__(self, data: Any) -> Any:
        """Generate a caption for one base64-encoded image.

        Args:
            data: Request payload. The image is read from
                ``data["inputs"]["image"]``; when ``data`` has no
                ``"inputs"`` key, ``data`` itself is used as the inputs
                mapping. An optional ``"parameters"`` entry is accepted but
                currently ignored. Both keys are popped, so ``data`` is
                mutated.

        Returns:
            Whatever ``self.model.generate`` returns for the single-image
            batch (presumably a list holding one caption string -- confirm
            against ``blip.blip_decoder``).

        Raises:
            KeyError: If the inputs mapping has no ``"image"`` entry.
        """
        inputs = data.pop("inputs", data)
        # Generation parameters are not configurable yet; the key is still
        # removed from the payload, as before.
        data.pop("parameters", None)

        # Look the field up first so a malformed request fails fast, before
        # any decoding or model work is attempted.
        encoded_image = inputs['image']

        # decode base64 image to PIL
        image = Image.open(BytesIO(base64.b64decode(encoded_image)))
        # Uploads may be RGBA, palette or grayscale; the 3-value Normalize in
        # ``self.transform`` needs exactly three channels.
        image = image.convert("RGB")

        # Add the batch dimension and move the tensor next to the model.
        batch = self.transform(image).unsqueeze(0).to(device)

        with torch.no_grad():
            caption = self.model.generate(batch, sample=True, top_p=0.9, max_length=20, min_length=5)
        return caption