import base64
from io import BytesIO
from types import SimpleNamespace
from typing import Any, Dict

import torch
from PIL import Image
from transformers import TextStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
# configs is expected to define MODEL_PATH, MODEL_BASE, LOAD_8BIT, LOAD_4BIT,
# IMAGE_ASPECT_RATIO, TEMPERATURE and MAX_NEW_TOKENS
from configs import *


class EndpointHandler:
    def __init__(self, path=MODEL_PATH):
        disable_torch_init()
        self.model_path = path
        self.model_base = MODEL_BASE
        self.model_name = get_model_name_from_path(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_name=self.model_name,
            load_8bit=LOAD_8BIT,
            load_4bit=LOAD_4BIT,
            model_base=self.model_base,
            device=self.device,
        )
if "llama-2" in self.model_name.lower():
self.conv_mode = "llava_llama_2"
elif "v1" in self.model_name.lower():
self.conv_mode = "llava_v1"
elif "mpt" in self.model_name.lower():
self.conv_mode = "mpt"
else:
self.conv_mode = "llava_v0"

    def __call__(self, data: Dict[str, Any]) -> str:
        # start a fresh conversation per request so history does not leak
        # between calls
        self.conv = conv_templates[self.conv_mode].copy()
        if "mpt" in self.model_name.lower():
            self.roles = ("user", "assistant")
        else:
            self.roles = self.conv.roles
        # pull the base64-encoded image and the text prompt out of the payload
        image_encoded = data.pop("inputs", data)
        text = data["text"]
        # decode the base64 string to a PIL image and normalize it to RGB
        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")
        # process_images reads image_aspect_ratio via attribute access, so pass
        # a namespace rather than a plain dict (a dict key would be silently
        # ignored by getattr)
        model_config = SimpleNamespace(image_aspect_ratio=IMAGE_ASPECT_RATIO)
        image_tensor = process_images([image], self.image_processor, model_config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)
        # build the single-turn prompt with the image token in front of the
        # text (the interactive chat loop is collapsed here, since the handler
        # serves exactly one request per call)
        if self.model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + text
        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()
        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.model.device)
        )
        # stop generation at the template's separator string
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=MAX_NEW_TOKENS,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )
        # decode only the newly generated tokens and trim the stop string
        outputs = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)].strip()
        return outputs

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded string into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
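

# A minimal usage sketch, assuming the configs module points at a valid LLaVA
# checkpoint; "sample.jpg" is a hypothetical local image path, not part of the
# original file. The payload shape matches what __call__ expects:
# {"inputs": <base64-encoded image>, "text": <prompt>}.
if __name__ == "__main__":
    with open("sample.jpg", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    handler = EndpointHandler()
    print(handler({"inputs": encoded, "text": "Describe this image."}))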