# llava-1-5/handler.py
import base64
from io import BytesIO
from types import SimpleNamespace
from typing import Any, Dict

import torch
from PIL import Image
from transformers import TextStreamer

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from configs import *
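
# `configs` is expected to provide the constants used below: MODEL_PATH,
# MODEL_BASE, LOAD_8BIT, LOAD_4BIT, IMAGE_ASPECT_RATIO, TEMPERATURE and
# MAX_NEW_TOKENS. A minimal sketch of such a module (values are assumptions,
# not the original settings) might look like:
#
#     MODEL_PATH = "liuhaotian/llava-v1.5-7b"  # assumed checkpoint
#     MODEL_BASE = None
#     LOAD_8BIT = False
#     LOAD_4BIT = False
#     IMAGE_ASPECT_RATIO = "pad"
#     TEMPERATURE = 0.2
#     MAX_NEW_TOKENS = 512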

class EndpointHandler:
    def __init__(self, path: str = MODEL_PATH):
        disable_torch_init()
        self.model_path = path
        self.model_base = MODEL_BASE
        self.model_name = get_model_name_from_path(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # load_pretrained_model also requires the base model path (None for a
        # full checkpoint); omitting it raises a TypeError.
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_base=self.model_base,
            model_name=self.model_name,
            load_8bit=LOAD_8BIT,
            load_4bit=LOAD_4BIT,
            device=self.device,
        )
if "llama-2" in self.model_name.lower():
self.conv_mode = "llava_llama_2"
elif "v1" in self.model_name.lower():
self.conv_mode = "llava_v1"
elif "mpt" in self.model_name.lower():
self.conv_mode = "mpt"
else:
self.conv_mode = "llava_v0"
    def __call__(self, data: Dict[str, Any]) -> str:
        # Start a fresh conversation for every request.
        self.conv = conv_templates[self.conv_mode].copy()
        if "mpt" in self.model_name.lower():
            self.roles = ("user", "assistant")
        else:
            self.roles = self.conv.roles

        # Base64-encoded image and the accompanying user prompt.
        image_encoded = data.pop("inputs", data)
        text = data["text"]

        # Decode the base64 payload and normalize the image to RGB.
        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # process_images reads image_aspect_ratio via attribute access,
        # so pass a namespace rather than a plain dict.
        model_config = SimpleNamespace(image_aspect_ratio=IMAGE_ASPECT_RATIO)
        image_tensor = process_images([image], self.image_processor, model_config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
        # Build the prompt: the image tokens precede the first user message.
        if self.model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + text
        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.model.device)
        )
        # Stop generation at the conversation separator for this template.
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=MAX_NEW_TOKENS,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
        # Decode only the newly generated tokens (everything after the prompt).
        outputs = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        return outputs
    def decode_base64_image(self, image_string: str) -> Image.Image:
        # Decode a base64 string into an in-memory PIL image.
        image_bytes = base64.b64decode(image_string)
        return Image.open(BytesIO(image_bytes))
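

# A minimal local smoke test, not part of the original handler: it assumes a
# hypothetical image file "sample.jpg" next to this script and the `configs`
# constants sketched above.
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler()
    # `inputs` carries the base64 image; `text` carries the user prompt.
    result = handler({"inputs": encoded, "text": "Describe this image."})
    print(result)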