from __future__ import annotations

import base64
import json
import logging
import os
import uuid
from io import BytesIO

import requests
from PIL import Image

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
from .. import shared

# print('model loading')
# model = AutoModelForCausalLM.from_pretrained(
#     "/home/shaozw/labs/imp-v0",
#     torch_dtype=torch.float16,
#     device_map="auto",
#     trust_remote_code=True)
# tokenizer = AutoTokenizer.from_pretrained("/home/shaozw/labs/imp-v0", trust_remote_code=True)
# print('model loaded')


class XMChat(BaseLLMModel):
    def __init__(self, api_key, user_name="", common_model=None, common_tokenizer=None):
        super().__init__(model_name="xmchat", user=user_name)
        self.api_key = api_key
        self.image_flag = False
        self.session_id = None
        self.reset()
        self.image_bytes = None
        self.image_path = None
        self.xm_history = []
        self.url = "https://xmbot.net/web"
        self.last_conv_id = None
        self.max_generation_token = 100
        # [Edited by zhenwei - 2024-01-26 10:35]
        self.common_model = common_model
        self.common_tokenizer = common_tokenizer
        self.system_prompt = "A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a chatbot named Imp, developed by the MILVLG team. Imp gives helpful, detailed, and polite answers to the user's questions."
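        # Note: _get_imp_style_inputs() below builds the prompt from its own
        # hard-coded system message rather than from this attribute.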

    def reset(self, remain_system_prompt=False):
        logging.info("Reseting...")
        self.session_id = str(uuid.uuid4())
        self.last_conv_id = None
        self.image_bytes = None
        self.image_flag = False
        return super().reset()

    def image_to_base64(self, image_path):
        # Open and load the image
        img = Image.open(image_path)

        # Get the image width and height
        width, height = img.size

        # Compute the scaling ratio so that the longest side is at most 2048 pixels
        max_dimension = 2048
        scale_ratio = min(max_dimension / width, max_dimension / height)

        if scale_ratio < 1:
            # Downscale the image by the computed ratio
            new_width = int(width * scale_ratio)
            new_height = int(height * scale_ratio)
            img = img.resize((new_width, new_height), Image.LANCZOS)

        # Convert the image to JPEG-encoded binary data
        buffer = BytesIO()
        if img.mode == "RGBA":
            img = img.convert("RGB")
        img.save(buffer, format='JPEG')
        binary_image = buffer.getvalue()

        # Base64-encode the binary data
        base64_image = base64.b64encode(binary_image).decode('utf-8')

        return base64_image

    def try_read_image(self, filepath):
        def is_image_file(filepath):
            # Check whether the file is an image by its extension
            valid_image_extensions = [
                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
            file_extension = os.path.splitext(filepath)[1].lower()
            return file_extension in valid_image_extensions

        if is_image_file(filepath):
            logging.info(f"Reading image file: {filepath}")
            # Despite the name, this attribute holds a PIL Image object
            self.image_bytes = Image.open(filepath)
            self.image_path = filepath
            self.image_flag = True
        else:
            self.image_bytes = None
            self.image_path = None
            # self.image_flag = False

    def like(self):
        if self.last_conv_id is None:
            return "点赞失败,你还没发送过消息"
        data = {
            "uuid": self.last_conv_id,
            "appraise": "good"
        }
        requests.post(self.url, json=data)
        return "👍点赞成功,感谢反馈~"

    def dislike(self):
        if self.last_conv_id is None:
            return "点踩失败,你还没发送过消息"
        data = {
            "uuid": self.last_conv_id,
            "appraise": "bad"
        }
        requests.post(self.url, json=data)
        return "👎点踩成功,感谢反馈~"

    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
        fake_inputs = real_inputs
        display_append = ""
        limited_context = False
        return limited_context, fake_inputs, display_append, real_inputs, chatbot

    def handle_file_upload(self, files, chatbot, language):
        """if the model accepts multi modal input, implement this function"""
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
            # if self.image_bytes is not None:
            #     logging.info("使用图片作为输入")
            #     # XMChat的一轮对话中实际上只能处理一张图片
            #     self.reset()
            #     conv_id = str(uuid.uuid4())
            #     data = {
            #         "user_id": self.api_key,
            #         "session_id": self.session_id,
            #         "uuid": conv_id,
            #         "data_type": "imgbase64",
            #         "data": self.image_bytes
            #     }
            #     response = requests.post(self.url, json=data)
            #     response = json.loads(response.text)
            #     logging.info(f"图片回复: {response['data']}")
        return None, chatbot, None

    def _get_imp_style_inputs(self):
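        """Build a single prompt string from self.history in the
        ``USER: ... ASSISTANT: ...</s>`` format expected by the Imp model,
        prepending the ``<image>`` token to the latest user turn when an
        image has been uploaded and ending with ``ASSISTANT:`` to cue
        generation.
        """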
        context = """
A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a multimodal chatbot named as Imp, and developed by MILVLG team from Hangzhou Dianzi University. Imp gives helpful, detailed, and polite answers to the user's questions.
""".strip()
        for ii, i in enumerate(self.history):
            if i["role"] == "user":
                if self.image_flag and ii == len(self.history) - 1:
                    # Strip any earlier image token so only the latest user turn carries it
                    context = context.replace('<image>\n', '')
                    i["content"] = '<image>\n' + i["content"]
                    self.image_flag = False
                context += ' USER: ' + i["content"].strip()
            else:
                context += ' ASSISTANT: ' + i["content"].strip() + '</s>'
        context += ' ASSISTANT:'
        return context

    def get_answer_at_once(self):
        # question = self.history[-1]["content"].strip()
        # question = f"{self.system_prompt.strip()} USER: <image>\n{question} ASSISTANT:"
        prompt = self._get_imp_style_inputs()
        logging.info(prompt)
        # image_tok_cnt = prompt.count('<image>')
        # global model, tokenizer
        input_ids = shared.state.imp_tokenizer(prompt, return_tensors='pt').input_ids
        image_tensor = None
        if '<image>' in prompt:
            # logging.info("Preprocessing...")
            image_tensor = shared.state.imp_model.image_preprocess(self.image_bytes)
        output_ids = shared.state.imp_model.generate(
            input_ids,
            max_new_tokens=3000,
            images=image_tensor,
            # max_length=self.token_upper_limit,
            do_sample=True if self.temperature > 0 else False,
            # top_k=self.top_k,
            top_p=self.top_p,
            temperature=self.temperature,
            # repetition_penalty=self.repetition_penalty,
            num_return_sequences=1,
            use_cache=True)[0]
        response = shared.state.imp_tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True).strip()
        return response, len(response)
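

# A minimal sketch, assuming `shared.state` exposes writable `imp_model` and
# `imp_tokenizer` attributes, of how the Imp weights referenced above could be
# loaded once at startup. It mirrors the commented-out loading code at the top
# of this file; the checkpoint path is copied from there and is only an example:
#
#     shared.state.imp_model = AutoModelForCausalLM.from_pretrained(
#         "/home/shaozw/labs/imp-v0",
#         torch_dtype=torch.float16,
#         device_map="auto",
#         trust_remote_code=True)
#     shared.state.imp_tokenizer = AutoTokenizer.from_pretrained(
#         "/home/shaozw/labs/imp-v0", trust_remote_code=True)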