File size: 4,921 Bytes
92aaa61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import torch
from typing import Dict, List, Any

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from configs import *

from PIL import Image

import requests
import base64
from PIL import Image
from io import BytesIO
from transformers import TextStreamer



class EndpointHandler:
    """Answer a text prompt about a base64-encoded image with a LLaVA model.

    The tokenizer, model and image processor are loaded once in ``__init__``.
    Each ``__call__`` builds its own single-turn conversation from the selected
    template, so no conversation state is carried between requests.
    """

    def __init__(self, path = MODEL_PATH):
        """Load the LLaVA checkpoint and pick a conversation template.

        Args:
            path: Checkpoint path or hub id to load. Defaults to ``MODEL_PATH``
                from ``configs``.
        """
        disable_torch_init()
        self.model_path = path
        self.model_base = MODEL_BASE
        self.model_name = get_model_name_from_path(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        (
            self.tokenizer,
            self.model,
            self.image_processor,
            self.context_len,
        ) = load_pretrained_model(
            self.model_path,
            self.model_base,
            self.model_name,
            LOAD_8BIT,
            LOAD_4BIT,
            device=self.device,
        )

        # The prompt template is chosen from substrings of the model name.
        name = self.model_name.lower()
        if "llama-2" in name:
            self.conv_mode = "llava_llama_2"
        elif "v1" in name:
            self.conv_mode = "llava_v1"
        elif "mpt" in name:
            self.conv_mode = "mpt"
        else:
            self.conv_mode = "llava_v0"

    def __call__(self, data: Dict[str, Any]) -> str:
        """Generate an answer for a single image + prompt request.

        Args:
            data: Request payload with ``"inputs"`` (base64-encoded image) and
                ``"text"`` (the user prompt). The dict is not modified.

        Returns:
            The decoded model output with surrounding whitespace and a trailing
            conversation stop string removed.

        Raises:
            KeyError: If ``"inputs"`` or ``"text"`` is missing from ``data``.
        """
        from types import SimpleNamespace

        image_encoded = data["inputs"]
        text = data["text"]

        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Pass an attribute-style config: upstream LLaVA's process_images reads
        # `image_aspect_ratio` with getattr(), so a plain dict would be ignored.
        model_config = SimpleNamespace(image_aspect_ratio=IMAGE_ASPECT_RATIO)
        image_tensor = process_images([image], self.image_processor, model_config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)

        # A fresh copy of the template per request, kept local so repeated calls
        # always start from an empty conversation.
        conv = conv_templates[self.conv_mode].copy()
        if self.model.config.mm_use_im_start_end:
            image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_prefix = DEFAULT_IMAGE_TOKEN
        conv.append_message(conv.roles[0], image_prefix + '\n' + text)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Follow the model's device instead of hard-coding .cuda(), so a
        # CPU-only host (self.device == "cpu") also works.
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        ).unsqueeze(0).to(self.model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        # transformers' sampling path requires a strictly positive float
        # temperature; treat TEMPERATURE <= 0 as a request for greedy decoding.
        do_sample = TEMPERATURE > 0
        generate_kwargs = {"do_sample": do_sample}
        if do_sample:
            generate_kwargs["temperature"] = float(TEMPERATURE)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                max_new_tokens=MAX_NEW_TOKENS,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
                **generate_kwargs,
            )

        # NOTE(review): assumes generate() returns prompt + new tokens, so the
        # prompt is sliced off here — TODO confirm for the installed LLaVA version.
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        # Special tokens are not skipped during decoding, so the stop string
        # can appear at the end of the text; drop it.
        if stop_str and outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)]
        return outputs.strip()

    def decode_base64_image(self, image_string):
        """Decode base64 image data into a PIL image.

        Args:
            image_string: Base64 text or bytes. A ``data:<mime>;base64,`` URI
                prefix, if present on a string, is removed before decoding.

        Returns:
            The opened ``PIL.Image.Image`` (backed by an in-memory buffer).
        """
        if isinstance(image_string, str) and image_string.startswith("data:"):
            image_string = image_string.partition(",")[2]
        buffer = BytesIO(base64.b64decode(image_string))
        return Image.open(buffer)