import base64
from io import BytesIO
from typing import Any, Dict

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor

class EndpointHandler:
    def __init__(self, path=""):
        # Load the BLIP-2 (Flan-T5-XXL) processor and model once at startup;
        # fp16 weights are sharded across available devices via device_map="auto".
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl",
            device_map="auto",
            torch_dtype=torch.float16,
            # load_in_8bit=True,
        )
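
    # Expected request payloads (a sketch inferred from the two branches of
    # __call__ below; "image" is a base64-encoded image in both modes):
    #   {"inputs": {"mode": "generate_text", "input_text": str, "image": str,
    #               "max_new_tokens": int, "stop": str, "temperature": float}}
    #   {"inputs": {"mode": "get_continuation_likelihood", "prompt": str,
    #               "continuation": str, "image": str}}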
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]

        if inputs["mode"] == 'generate_text':
            input_text: str = inputs['input_text']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs['image'])))
            max_new_tokens: int = inputs['max_new_tokens']
            stop: str = inputs['stop']
            temperature: float = inputs['temperature']

            model_inputs = self.processor(
                images=image, text=input_text, return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # `temperature` only takes effect when sampling is enabled; fall back
            # to greedy decoding when temperature == 0.
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=temperature > 0,
                temperature=temperature if temperature > 0 else 1.0,
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()
            # Truncate at the first occurrence of the stop string, if present.
            if stop in output_text:
                output_text = output_text[: output_text.find(stop)]
            return {'output_text': output_text}
        elif inputs["mode"] == 'get_continuation_likelihood':
            prompt: str = inputs['prompt']
            continuation: str = inputs['continuation']
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs['image'])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            model_inputs["labels"] = model_inputs["input_ids"]

            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]
            # Flan-T5 shifts the labels right internally, so logits[i] scores
            # labels[i], i.e. input_ids[i]. No gradients are needed for scoring.
            with torch.no_grad():
                logits = self.model(**model_inputs).logits[0]
            logprobs = F.log_softmax(logits, dim=-1)
            # .item() so the returned values are JSON-serializable floats.
            logprobs = [logprobs[i, input_ids[i]].item() for i in range(len(tokens))]
            return {
                'prompt': prompt,
                'continuation': continuation,
                'tokens': tokens,
                'logprobs': logprobs,
            }
        else:
            raise ValueError(f"Unsupported mode: {inputs['mode']!r}")
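
# A minimal local smoke test (an illustrative sketch, not part of the endpoint
# contract). The image path "example.jpg" is hypothetical; running this loads
# the full blip2-flan-t5-xxl weights, which needs substantial GPU memory.
if __name__ == "__main__":
    handler = EndpointHandler()
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Visual question answering via 'generate_text' (temperature 0.0 -> greedy).
    print(handler({"inputs": {
        "mode": "generate_text",
        "input_text": "Question: what is in this image? Answer:",
        "image": image_b64,
        "max_new_tokens": 32,
        "stop": "\n",
        "temperature": 0.0,
    }}))

    # Per-token log-probabilities of a continuation given a prompt.
    print(handler({"inputs": {
        "mode": "get_continuation_likelihood",
        "prompt": "A photo of",
        "continuation": " a dog",
        "image": image_b64,
    }}))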