"""
A model worker executes the model.
"""
import json
import uuid
import torch
import spaces

from peft import PeftModel

from llava.utils import (build_logger, server_error_msg)
from model_builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


class ModelWorker:
    def __init__(self, model_path, model_base, model_name, load_bf16, lora_path):
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
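        # Derive a readable model name from the checkpoint path unless one was passed explicitly.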
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, False, False, load_bf16=load_bf16)
        self.is_multimodal = 'llava' in self.model_name.lower()
        self.load_bf16 = load_bf16

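        # Optionally attach LoRA adapter weights (loaded on CPU) before moving the full model to the GPU.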
        if lora_path is not None:
            self.model = PeftModel.from_pretrained(
                self.model,
                lora_path,
                torch_device='cpu',
                device_map="cpu",
            )
        self.model.to('cuda')

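    # @spaces.GPU requests a GPU allocation from Hugging Face ZeroGPU for the duration of each call.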
    @spaces.GPU
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
        logger.info(f'Model device: {model.device}')

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
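        # Decode base64 images and expand the <image> placeholder into the form the vision tower expects.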
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of <image> tokens in prompt")

            images = [load_image_from_base64(image) for image in images]
            images = process_images(images, image_processor, model.config)
            logger.info(f'Images: {[image.shape for image in images] if type(images) is list else images.shape}')

            if type(images) is list:
                images = [image.to(model.device, dtype=torch.float16) for image in images]
            else:
                images = images.to(model.device, dtype=torch.float16)

            if self.load_bf16:
                images = images.to(dtype=torch.bfloat16)

            replace_token = DEFAULT_IMAGE_TOKEN
            if getattr(model.config, 'mm_use_im_start_end', False):
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

            num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches

            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = temperature > 0.001

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        keywords = [stop_str] if stop_str is not None else []
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=None)

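        # Leave room in the context window for the prompt and the expanded image patch tokens.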
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        if max_new_tokens < 1:
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
            return

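        # Run generation on a background thread; the streamer yields partial text as it is produced.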
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        ))
        thread.start()

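        # Stream the cumulative text back to the caller, trimming the stop string when it appears.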
        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if stop_str is not None and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode()

    def generate_stream_gate(self, params):
        try:
            yield from self.generate_stream(params)
        except Exception:
            logger.exception("Caught exception during generation")
            yield json.dumps({"text": server_error_msg, "error_code": 1}).encode()

def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()
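

# Illustrative usage sketch (not part of the original worker): how ModelWorker
# might be driven directly. The checkpoint id and sampling parameters below are
# placeholders; in the deployed app the worker sits behind a streaming endpoint
# that consumes generate_stream_gate.
if __name__ == "__main__":
    worker = ModelWorker(
        model_path="liuhaotian/llava-v1.5-7b",  # placeholder checkpoint
        model_base=None,
        model_name=None,
        load_bf16=False,
        lora_path=None,
    )
    params = {
        "prompt": "USER: Hello, who are you? ASSISTANT:",
        "images": [],  # base64-encoded images, one per <image> token in the prompt
        "temperature": 0.2,
        "top_p": 0.7,
        "max_new_tokens": 128,
        "stop": "</s>",
    }
    for chunk in worker.generate_stream_gate(params):
        print(json.loads(chunk)["text"])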