blip2_endpoint / handler.py
gdetari
blip2
095cf65
raw
history blame contribute delete
No virus
820 Bytes
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from typing import Dict, List, Any
import torch
class EndpointHandler:
    """Inference handler that captions an image with a BLIP-2 model.

    The processor and model are loaded once at construction time. Each call
    turns the payload's ``"inputs"`` image into a single caption string.
    """

    # Checkpoint loaded when the caller does not pick one explicitly.
    DEFAULT_MODEL_ID = "Salesforce/blip2-opt-2.7b"

    def __init__(self, path="", model_id=DEFAULT_MODEL_ID):
        """Load the BLIP-2 processor/model pair and move the model to a device.

        Args:
            path: Accepted so the constructor signature stays compatible with
                callers that pass it, but it is not used. Weights are always
                loaded from ``model_id``.
            model_id: Identifier or local directory handed to
                ``from_pretrained`` for both the processor and the model.
                Defaults to ``"Salesforce/blip2-opt-2.7b"``.
        """
        self.processor = Blip2Processor.from_pretrained(model_id)
        self.model = Blip2ForConditionalGeneration.from_pretrained(model_id)
        # Use the GPU when torch can see one; otherwise run on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    @staticmethod
    def _parse_payload(data):
        """Split a request payload into ``(inputs, generate_kwargs)``.

        ``data["parameters"]`` (optional) supplies extra keyword arguments
        for ``generate``. A missing key or ``None`` yields an empty dict.
        ``data["inputs"]`` is the image. If that key is absent, the remaining
        payload dict itself is returned as the input.

        Both keys are popped, so ``data`` is mutated in place.

        Raises:
            ValueError: If ``"parameters"`` is present but is not a dict.
        """
        # Pop parameters first so they never leak into the fallback input.
        parameters = data.pop("parameters", None)
        if parameters is None:
            parameters = {}
        elif not isinstance(parameters, dict):
            raise ValueError(
                '"parameters" must be a dict of generation keyword arguments, '
                f"got {type(parameters).__name__}"
            )
        inputs = data.pop("inputs", data)
        return inputs, parameters

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate a caption for the image in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` holds the image, presumably a
                PIL image or anything else the BLIP-2 processor accepts
                (TODO confirm against the serving runtime). An optional
                ``"parameters"`` dict is forwarded verbatim to
                ``model.generate``, e.g. ``{"max_new_tokens": 30}``.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens removed.

        Raises:
            ValueError: If ``"parameters"`` is present but is not a dict.
        """
        image, generate_kwargs = self._parse_payload(data)
        processed = self.processor(images=image, return_tensors="pt").to(self.device)
        out = self.model.generate(**processed, **generate_kwargs)
        return self.processor.decode(out[0], skip_special_tokens=True)