from typing import Dict, Any

import base64
from io import BytesIO

import torch
import torch.nn.functional as F
from accelerate import init_empty_weights, infer_auto_device_map
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    Blip2Config,
    Blip2ForConditionalGeneration,
    Blip2Processor,
)

MODEL_ID = "Salesforce/blip2-flan-t5-xxl"


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.processor = Blip2Processor.from_pretrained(MODEL_ID)

        # Build an empty (meta-device) model first so the device map can be
        # computed without materializing the full checkpoint in memory.
        config = Blip2Config.from_pretrained(MODEL_ID)
        with init_empty_weights():
            model = Blip2ForConditionalGeneration(config)
        device_map = infer_auto_device_map(model, no_split_module_classes=["T5Block"])

        # The lm_head is tied to the encoder embedding weights; pin both to the
        # same device so the shared parameters are not split across devices.
        device_map["language_model.lm_head"] = device_map[
            "language_model.encoder.embed_tokens"
        ]

        self.model = Blip2ForConditionalGeneration.from_pretrained(
            MODEL_ID,
            device_map=device_map,
            torch_dtype=torch.float16,
            quantization_config=BitsAndBytesConfig(
                load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
            ),
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data["inputs"]

        if inputs["mode"] == "generate_text":
            input_text: str = inputs["input_text"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))
            max_new_tokens: int = inputs["max_new_tokens"]
            stop: str = inputs["stop"]
            temperature: float = inputs["temperature"]

            model_inputs = self.processor(
                images=image, text=input_text, return_tensors="pt"
            ).to(self.model.device, self.model.dtype)

            # Note: `temperature` only affects decoding when sampling is
            # enabled; generate() defaults to greedy decoding.
            output = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
            )[0]
            output_text = self.processor.decode(output, skip_special_tokens=True).strip()

            # Truncate at the first occurrence of the stop sequence, if any.
            if stop in output_text:
                output_text = output_text[: output_text.find(stop)]
            return {"output_text": output_text}

        elif inputs["mode"] == "get_continuation_likelihood":
            prompt: str = inputs["prompt"]
            continuation: str = inputs["continuation"]
            image: Image.Image = Image.open(BytesIO(base64.b64decode(inputs["image"])))

            model_inputs = self.processor(
                images=image, text=(prompt + continuation), return_tensors="pt"
            ).to(self.model.device, self.model.dtype)
            # Teacher-force the decoder on the full prompt + continuation. For
            # T5 the labels are shifted right internally, so logits[i] scores
            # input_ids[i].
            model_inputs["labels"] = model_inputs["input_ids"]

            input_ids = model_inputs["input_ids"][0]
            tokens = [self.processor.decode([t]) for t in input_ids]

            with torch.no_grad():
                logits = self.model(**model_inputs).logits[0]
            # Compute log-softmax in fp32 for numerical stability, then pull
            # out each token's log-probability as a plain float so the
            # response is JSON-serializable.
            logprobs = F.log_softmax(logits.float(), dim=-1)
            logprobs = [logprobs[i, input_ids[i]].item() for i in range(len(tokens))]

            return {
                "prompt": prompt,
                "continuation": continuation,
                "tokens": tokens,
                "logprobs": logprobs,
            }
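
# Illustrative local smoke test, not part of the Inference Endpoints contract:
# a minimal sketch of the request payloads the handler above expects. The file
# name `example.jpg`, the prompts, and the parameter values are hypothetical;
# the field names mirror the handler's two modes.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()

    # Mode 1: free-form generation, truncated at a stop sequence.
    print(handler({
        "inputs": {
            "mode": "generate_text",
            "input_text": "Question: What is shown in the image? Answer:",
            "image": image_b64,
            "max_new_tokens": 32,
            "stop": "\n",
            "temperature": 1.0,
        }
    }))

    # Mode 2: per-token log-probabilities of a fixed continuation.
    print(handler({
        "inputs": {
            "mode": "get_continuation_likelihood",
            "prompt": "Question: What is shown in the image? Answer:",
            "continuation": " a dog",
            "image": image_b64,
        }
    }))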