File size: 8,046 Bytes
0bae6cd 2524499 0bae6cd 2524499 0bae6cd 2524499 0bae6cd 2524499 0bae6cd 2524499 0bae6cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
from __future__ import annotations
import base64
import json
import logging
import os
import uuid
from io import BytesIO
import requests
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ..index_func import *
from ..presets import *
from ..utils import *
from .base_model import BaseLLMModel
from .. import shared
# Hugging Face repository id of the Imp checkpoint served by XMChat.
IMP_MODEL_ID = "MILVLG/imp-v1-3b"

# Model and tokenizer are loaded once, at import time, as module-level globals;
# XMChat.get_answer_at_once reads them for every request. trust_remote_code lets
# the checkpoint's own modelling code run (presumably what provides the
# image_preprocess method XMChat calls — confirm against the model repo).
imp_model = AutoModelForCausalLM.from_pretrained(
    IMP_MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
imp_tokenizer = AutoTokenizer.from_pretrained(
    IMP_MODEL_ID,
    trust_remote_code=True,
)
class XMChat(BaseLLMModel):
    """Chat model served by the locally loaded Imp multimodal LLM.

    The class keeps the historical ``xmchat`` model name and the XMChat
    feedback endpoint used by ``like``/``dislike``, but replies are generated
    by the module-level ``imp_model``/``imp_tokenizer`` pair in
    ``get_answer_at_once``. At most one image is active per conversation: the
    latest upload is kept in ``self.image_bytes`` and is referenced by a single
    image placeholder in the prompt.
    """

    # Placeholder prepended to a user turn to mark where the image belongs.
    _IMAGE_TOKEN = "<image>\n"
    # Seconds to wait for the feedback endpoint before giving up.
    _FEEDBACK_TIMEOUT = 10

    def __init__(self, api_key, user_name="", common_model=None, common_tokenizer=None):
        super().__init__(model_name="xmchat", user=user_name)
        self.api_key = api_key
        self.image_flag = False  # True while an uploaded image still needs a prompt token
        self.session_id = None
        self.reset()
        self.image_bytes = None  # PIL image of the active upload, or None
        self.image_path = None  # file path of the active upload, or None
        self.xm_history = []
        self.url = "https://xmbot.net/web"  # feedback endpoint for like/dislike
        self.last_conv_id = None
        # NOTE(review): get_answer_at_once does not read this; it requests up
        # to 3000 new tokens.
        self.max_generation_token = 100
        # [Edited by zhenwei - 2024-01-26 10:35]
        # NOTE(review): stored but not read in this class; generation uses the
        # module-level imp_model / imp_tokenizer instead.
        self.common_model = common_model
        self.common_tokenizer = common_tokenizer
        # NOTE(review): _get_imp_style_inputs builds its own hard-coded system
        # text rather than using this attribute.
        self.system_prompt = "A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a chatbot named as Imp, and developed by MILVLG team. Imp gives helpful, detailed, and polite answers to the user's questions."

    def reset(self, remain_system_prompt=False):
        """Start a fresh session and drop any attached image.

        Args:
            remain_system_prompt: forwarded to ``BaseLLMModel.reset`` (it was
                previously accepted but silently ignored).

        Returns:
            The value returned by ``BaseLLMModel.reset``.
        """
        logging.info("Reseting...")
        self.session_id = str(uuid.uuid4())
        self.last_conv_id = None
        self.image_bytes = None
        self.image_path = None
        self.image_flag = False
        # NOTE(review): assumes BaseLLMModel.reset takes remain_system_prompt,
        # matching this override's signature.
        return super().reset(remain_system_prompt=remain_system_prompt)

    def image_to_base64(self, image_path, max_dimension=2048):
        """Return the image at ``image_path`` as a base64-encoded JPEG string.

        The image is shrunk proportionally (LANCZOS) so that neither side
        exceeds ``max_dimension`` pixels. Any mode other than L or RGB is
        converted to RGB, because JPEG cannot store alpha or palette modes
        such as RGBA, P (e.g. GIF) or LA.

        Args:
            image_path: path of the image file to encode.
            max_dimension: upper bound, in pixels, for both width and height.

        Returns:
            The JPEG bytes encoded as a base64 ``str``.
        """
        # The context manager closes the file handle opened by Image.open.
        with Image.open(image_path) as img:
            width, height = img.size
            scale_ratio = min(max_dimension / width, max_dimension / height)
            if scale_ratio < 1:
                new_size = (int(width * scale_ratio), int(height * scale_ratio))
                img = img.resize(new_size, Image.LANCZOS)
            if img.mode not in ("L", "RGB"):
                img = img.convert("RGB")
            buffer = BytesIO()
            img.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def try_read_image(self, filepath):
        """Attach ``filepath`` as the conversation image if it looks like one.

        Files whose extension is a known image type are decoded fully into
        memory (so the underlying file handle is released right away), stored
        in ``image_bytes``/``image_path``, and ``image_flag`` is raised so the
        next user turn receives an image token. Any other file clears the
        stored image; ``image_flag`` is left as it was.
        """
        valid_image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff")
        if os.path.splitext(filepath)[1].lower() in valid_image_extensions:
            logging.info(f"读取图片文件: {filepath}")
            with Image.open(filepath) as img:
                img.load()
                # copy() detaches the pixel data from the file being closed.
                self.image_bytes = img.copy()
            self.image_path = filepath
            self.image_flag = True
        else:
            self.image_bytes = None
            self.image_path = None

    def _post_appraisal(self, appraise):
        """POST an ``appraise`` rating for ``self.last_conv_id`` to ``self.url``."""
        data = {"uuid": self.last_conv_id, "appraise": appraise}
        # Bound the wait so an unresponsive server cannot block the caller forever.
        requests.post(self.url, json=data, timeout=self._FEEDBACK_TIMEOUT)

    def like(self):
        """Send positive feedback for the last turn; return a status message."""
        if self.last_conv_id is None:
            return "点赞失败,你还没发送过消息"
        self._post_appraisal("good")
        return "👍点赞成功,感谢反馈~"

    def dislike(self):
        """Send negative feedback for the last turn; return a status message."""
        if self.last_conv_id is None:
            return "点踩失败,你还没发送过消息"
        self._post_appraisal("bad")
        return "👎点踩成功,感谢反馈~"

    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
        """Pass the user's input through unchanged.

        ``use_websearch``, ``files`` and ``reply_language`` are ignored; no
        text is appended for display and the context is never marked limited.

        Returns:
            ``(limited_context, fake_inputs, display_append, real_inputs, chatbot)``.
        """
        fake_inputs = real_inputs
        display_append = ""
        limited_context = False
        return limited_context, fake_inputs, display_append, real_inputs, chatbot

    def handle_file_upload(self, files, chatbot, language):
        """Read uploaded files as images and show the active image in the chat.

        Every named file is passed to ``try_read_image``; the last one decides
        which image (if any) is active. ``language`` is unused.

        Returns:
            ``(None, chatbot, None)``, where ``chatbot`` has the image appended
            as a user entry when one is active.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
            if self.image_path is not None:
                chatbot = chatbot + [((self.image_path,), None)]
        return None, chatbot, None

    def _get_imp_style_inputs(self):
        """Render ``self.history`` into a single Imp-format prompt string.

        Format: ``<system> USER: <msg> ASSISTANT: <reply></s> ... ASSISTANT:``.

        When ``image_flag`` is set and an image is stored, the image token is
        prepended to the newest user message *inside the history itself*, so
        follow-up turns keep referring to that image; the flag is then cleared.
        Only one image tensor is passed to the model, so the returned prompt
        keeps at most one image token (the newest), and none when no image is
        stored.
        """
        token = self._IMAGE_TOKEN
        context = (
            "A chat between a curious user and an artificial intelligence assistant. "
            "This artificial intelligence assistant is a multimodal chatbot named as Imp, "
            "and developed by MILVLG team from Hangzhou Dianzi University. "
            "Imp gives helpful, detailed, and polite answers to the user's questions."
        )
        last_index = len(self.history) - 1
        for index, message in enumerate(self.history):
            if message["role"] == "user":
                if self.image_flag and index == last_index:
                    if self.image_bytes is not None:
                        message["content"] = token + message["content"]
                    self.image_flag = False
                context += " USER: " + message["content"].strip()
            else:
                context += " ASSISTANT: " + message["content"].strip() + "</s>"
        context += " ASSISTANT:"

        if self.image_bytes is None:
            # Nothing to feed the vision tower: drop every stale placeholder.
            return context.replace(token, "")
        # Keep only the newest placeholder, which matches self.image_bytes.
        head, sep, tail = context.rpartition(token)
        return head.replace(token, "") + sep + tail

    def get_answer_at_once(self):
        """Generate a complete reply to the current history with Imp.

        Returns:
            ``(response, len(response))``: the decoded reply text and its
            length in characters.
        """
        prompt = self._get_imp_style_inputs()
        logging.info(prompt)
        # Put the prompt on the model's device so generate() does not mix
        # CPU and accelerator tensors.
        input_ids = imp_tokenizer(prompt, return_tensors="pt").input_ids.to(imp_model.device)
        image_tensor = None
        if "<image>" in prompt and self.image_bytes is not None:
            image_tensor = imp_model.image_preprocess(self.image_bytes)
        output_ids = imp_model.generate(
            input_ids,
            max_new_tokens=3000,
            images=image_tensor,
            do_sample=self.temperature > 0,
            top_p=self.top_p,
            temperature=self.temperature,
            num_return_sequences=1,
            use_cache=True,
        )[0]
        # generate() returns prompt + completion; decode only the new tokens.
        response = imp_tokenizer.decode(
            output_ids[input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        return response, len(response)
|