|
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration |
|
import torch |
|
from PIL import Image |
|
|
|
class EndpointHandler():
    """Custom inference handler serving a LLaVA-NeXT vision-language model.

    Loads the processor and a 4-bit quantized model once at construction,
    then answers image+prompt generation requests via ``__call__``.
    """

    def __init__(self, path=""):
        """Load processor and model from *path* (a local model directory).

        Bug fix: the original called ``disable_torch_init()``, which is never
        imported or defined in this file and raised ``NameError`` at startup;
        the call has been removed.
        """
        # Processor handles both image preprocessing and tokenization.
        self.processor = LlavaNextProcessor.from_pretrained(path, use_fast=False)
        # Bug fix: with load_in_4bit=True, bitsandbytes/accelerate owns weight
        # placement, and calling .to("cuda:0") on a 4-bit model raises
        # ValueError in transformers. Use device_map for placement instead.
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            load_in_4bit=True,
            device_map="cuda:0",
        )

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: request payload dict. ``data["inputs"]`` is expected to be a
                base64-encoded image (if the key is absent, the whole dict is
                passed through — NOTE(review): that fallback looks dubious,
                confirm against the actual caller payload) and ``data["text"]``
                is the prompt.

        Returns:
            str: the generated text, special tokens stripped.
        """
        image_encoded = data.pop("inputs", data)
        prompt = data["text"]

        image = self.decode_base64_image(image_encoded)
        # The model expects 3-channel RGB input.
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = self.processor(prompt, image, return_tensors="pt").to("cuda:0")

        output = self.model.generate(**inputs, max_new_tokens=500)

        # Bug fix: original said `processor.decode(...)` — an undefined
        # module-level name (NameError on every request). Use the instance.
        return self.processor.decode(output[0], skip_special_tokens=True)

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image payload into a PIL Image.

        Bug fix: this helper was called from ``__call__`` but never defined,
        so every request failed with AttributeError.
        """
        image_bytes = base64.b64decode(image_string)
        return Image.open(BytesIO(image_bytes))
|
|