# NOTE(review): removed non-Python metadata header (file size / commit hash /
# line-number gutter) left over from a web scrape — it made the module unparseable.
from typing import Dict, Any
import torch
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from PIL import Image
from io import BytesIO
import base64
import torch.nn.functional as F
class EndpointHandler():
    """Hugging Face Inference Endpoints handler for BLIP-2 (Flan-T5-XXL).

    Dispatches on ``data["inputs"]["mode"]``:

    * ``'generate_text'`` — image-conditioned text generation, truncated at
      an optional stop string.
    * ``'get_continuation_likelihood'`` — per-token log-probabilities of
      ``prompt + continuation`` given the image.
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract but is
        # unused here: weights are always pulled from the Hub.
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl", device_map="auto",
            torch_dtype=torch.float16
            # load_in_8bit=True,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Args:
            data: request payload; ``data["inputs"]`` must contain a ``mode``
                key plus the mode-specific fields (see helper docstrings).

        Returns:
            A JSON-serializable dict (shape depends on ``mode``).

        Raises:
            KeyError: if a required field is missing from the payload.
            ValueError: if ``mode`` is not one of the supported values
                (previously this fell through and returned ``None`` silently).
        """
        request = data["inputs"]  # keep a distinct name; don't shadow with processor output
        mode = request["mode"]
        if mode == 'generate_text':
            return self._generate_text(request)
        if mode == 'get_continuation_likelihood':
            return self._continuation_likelihood(request)
        raise ValueError(f"unsupported mode: {mode!r}")

    def _generate_text(self, request: Dict[str, Any]) -> Dict[str, str]:
        """Generate text conditioned on the image + prompt; truncate at `stop`."""
        input_text: str = request['input_text']
        image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))
        max_new_tokens: int = request['max_new_tokens']
        stop: str = request['stop']
        temperature: float = request['temperature']
        model_inputs = self.processor(images=image, text=input_text, return_tensors="pt").to(
            self.model.device, self.model.dtype
        )
        # Inference only — no autograd graph needed; saves significant memory
        # on a model this large.
        with torch.no_grad():
            # NOTE(review): `temperature` only takes effect when sampling
            # (do_sample=True); with the default greedy decoding it is
            # ignored. Kept as-is to preserve the existing behaviour.
            output = self.model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature
            )[0]
        output_text = self.processor.decode(output, skip_special_tokens=True).strip()
        # Guard against an empty stop string: '' is a substring of everything
        # and find('') == 0, which would silently wipe the whole output.
        if stop and stop in output_text:
            output_text = output_text[: output_text.find(stop)]
        return {'output_text': output_text}

    def _continuation_likelihood(self, request: Dict[str, Any]) -> Dict[str, Any]:
        """Score prompt+continuation token-by-token under the model."""
        prompt: str = request['prompt']
        continuation: str = request['continuation']
        image: Image.Image = Image.open(BytesIO(base64.b64decode(request['image'])))
        model_inputs = self.processor(
            images=image, text=(prompt + continuation), return_tensors="pt"
        ).to(self.model.device, self.model.dtype)
        # T5-style seq2seq: the decoder input is the right-shifted labels, so
        # logits[i] scores labels[i] directly — no off-by-one needed here.
        model_inputs["labels"] = model_inputs["input_ids"]
        input_ids = model_inputs["input_ids"][0]
        tokens = [self.processor.decode([t]) for t in input_ids]
        with torch.no_grad():  # scoring only; don't build an autograd graph
            logits = self.model(**model_inputs).logits[0]
        # Upcast before log_softmax: fp16 log-probabilities are prone to
        # overflow/-inf; float32 keeps the scores finite and accurate.
        all_logprobs = F.log_softmax(logits.float(), dim=-1)
        # .item() — bare tensors are not JSON-serializable in the endpoint
        # response; return plain Python floats.
        logprobs = [all_logprobs[i, tok].item() for i, tok in enumerate(input_ids)]
        return {
            'prompt': prompt,
            'continuation': continuation,
            'tokens': tokens,
            'logprobs': logprobs
        }
|