Vidensogende's picture
added custom handler
6cc79d4
raw
history blame
846 Bytes
import io
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from transformers import BlipForConditionalGeneration, BlipProcessor
class EndpointHandler:
    """Custom inference handler that captions an image with a BLIP model.

    The class exposes the ``__init__(path)`` / ``__call__(data)`` shape used by
    Hugging Face custom endpoint handlers (presumably Inference Endpoints —
    confirm against the deployment).
    """

    # This checkpoint is a BLIP (v1) model (config model_type "blip"), so it must
    # be loaded with the Blip* classes; the Blip2* classes expect a "blip-2"
    # config and cannot load it correctly.
    MODEL_ID = "Salesforce/blip-image-captioning-large"

    def __init__(self, path=""):
        """Load the processor and model and move the model to GPU if available.

        Args:
            path: Accepted for the handler interface but not used; weights are
                always loaded from ``MODEL_ID``.
        """
        self.processor = BlipProcessor.from_pretrained(self.MODEL_ID)
        self.model = BlipForConditionalGeneration.from_pretrained(self.MODEL_ID)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    @staticmethod
    def _load_image(image):
        """Decode raw image bytes into a PIL image; return any other value unchanged."""
        if isinstance(image, (bytes, bytearray)):
            # The processor cannot consume encoded bytes directly.
            return Image.open(io.BytesIO(image))
        return image

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate a caption for the image in ``data``.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                (falling back to ``data`` itself when that key is absent, as
                before) and may be a PIL image or raw encoded bytes. An optional
                ``data["parameters"]`` mapping is forwarded to
                ``model.generate`` (e.g. ``{"max_new_tokens": 30}``).

        Returns:
            The decoded caption text for the first generated sequence.
        """
        # ``get`` instead of ``pop`` so the caller's payload is not mutated.
        image = self._load_image(data.get("inputs", data))
        generate_kwargs = data.get("parameters") or {}
        processed = self.processor(images=image, return_tensors="pt").to(self.device)
        out = self.model.generate(**processed, **generate_kwargs)
        return self.processor.decode(out[0], skip_special_tokens=True)