import base64
from io import BytesIO
from typing import Any, Dict

import torch
import torch.nn.functional as F
from accelerate import infer_auto_device_map, init_empty_weights
from PIL import Image
from transformers import Blip2Config, Blip2ForConditionalGeneration, Blip2Processor
class EndpointHandler:
    def __init__(self, path=""):
        self.processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")

        # Instantiate a weightless (meta-device) model so accelerate can plan a
        # device map without loading the full checkpoint into memory. T5 blocks
        # must not be split across devices.
        config = Blip2Config.from_pretrained("Salesforce/blip2-flan-t5-xxl")
        with init_empty_weights():
            model = Blip2ForConditionalGeneration(config)
        device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])
        # The LM head is tied to the encoder's input embeddings, so place both
        # on the same device to avoid a cross-device tied-weights error.
        device_map["language_model.lm_head"] = device_map["language_model.encoder.embed_tokens"]

        self.model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-flan-t5-xxl",
            device_map=device_map,
            # torch_dtype=torch.float16
            load_in_8bit=True,
        )
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]

        if inputs["mode"] == "generate_text":
            input_text: str = inputs["input_text"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
            max_new_tokens: int = inputs["max_new_tokens"]
            stop: str = inputs["stop"]
            temperature: float = inputs["temperature"]

            model_inputs = self.processor(
                images=image, text=input_text, return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # Note: temperature only has an effect when sampling is enabled
            # (do_sample=True); the default greedy decoding ignores it.
            output = self.model.generate(
                **model_inputs, max_new_tokens=max_new_tokens, temperature=temperature
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()
            # Truncate at the first occurrence of the stop string, if any.
            # (Guard against an empty stop string, which matches everywhere.)
            if stop and stop in output_text:
                output_text = output_text[: output_text.find(stop)]
            return {"output_text": output_text}
        elif inputs["mode"] == "get_continuation_likelihood":
            prompt: str = inputs["prompt"]
            continuation: str = inputs["continuation"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # Pass the full text as labels so the decoder scores every token.
            model_inputs["labels"] = model_inputs["input_ids"]
            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]

            with torch.no_grad():  # scoring only; no gradients needed
                logits = self.model(**model_inputs).logits[0]
            logprobs = F.log_softmax(logits, dim=-1)
            # T5 shifts labels right internally, so logits[i] scores labels[i].
            # Take the log-probability assigned to each actual token and convert
            # to plain floats so the response is JSON-serializable.
            logprobs = [logprobs[i, input_ids[i]].item() for i in range(len(tokens))]
            return {
                "prompt": prompt,
                "continuation": continuation,
                "tokens": tokens,
                "logprobs": logprobs,
            }
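

# --- Illustrative usage (not part of the deployed handler) ---
# A minimal local smoke-test sketch showing the request payloads each mode
# expects. The image path and parameter values below are placeholder
# assumptions for illustration only.
if __name__ == "__main__":
    handler = EndpointHandler()

    # Endpoints receive images as base64-encoded bytes.
    with open("test.jpg", "rb") as f:  # hypothetical local image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    # Mode 1: free-form generation.
    print(handler({
        "inputs": {
            "mode": "generate_text",
            "input_text": "Question: What is shown in this image? Answer:",
            "image": image_b64,
            "max_new_tokens": 32,
            "stop": "\n",
            "temperature": 1.0,
        }
    }))

    # Mode 2: per-token log-probabilities of a continuation given a prompt.
    print(handler({
        "inputs": {
            "mode": "get_continuation_likelihood",
            "prompt": "A photo of",
            "continuation": " a dog",
            "image": image_b64,
        }
    }))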