|
from __future__ import annotations |
|
|
|
import base64 |
|
import json |
|
import logging |
|
import os |
|
import uuid |
|
from io import BytesIO |
|
|
|
import requests |
|
from PIL import Image |
|
|
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
|
from ..index_func import * |
|
from ..presets import * |
|
from ..utils import * |
|
from .base_model import BaseLLMModel |
|
from .. import shared |
|
|
|
# Hugging Face Hub id of the Imp checkpoint shared by every XMChat instance.
_IMP_CHECKPOINT = "MILVLG/imp-v1-3b"

# NOTE: both objects are created at import time, so importing this module
# downloads/loads the full model. trust_remote_code is required because the
# checkpoint ships its own modeling/tokenizer code.
imp_model = AutoModelForCausalLM.from_pretrained(
    _IMP_CHECKPOINT,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)
imp_tokenizer = AutoTokenizer.from_pretrained(
    _IMP_CHECKPOINT,
    trust_remote_code=True,
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class XMChat(BaseLLMModel):
    """Multimodal chat model that answers with the locally loaded Imp checkpoint.

    The class keeps the ``xmchat`` model name and the XMBot feedback endpoint
    used by ``like`` / ``dislike``. Replies, however, come from the
    module-level ``imp_model`` / ``imp_tokenizer``. An uploaded image is
    attached to the newest user turn through an ``<image>`` placeholder in the
    prompt.
    """

    # Marker placed in front of the user message that the current image belongs to.
    _IMAGE_PLACEHOLDER = "<image>\n"

    # Pixel modes Pillow's JPEG encoder can write directly. Every other mode
    # ("P" from GIF/PNG palettes, "RGBA", "LA", ...) must be converted first,
    # otherwise ``save(format="JPEG")`` raises ``OSError``.
    _JPEG_MODES = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

    def __init__(self, api_key, user_name="", common_model=None, common_tokenizer=None):
        """Create a chat session.

        Args:
            api_key: Stored on the instance; not used by the methods below.
            user_name: Forwarded to the base model as ``user``.
            common_model: Stored only. Generation uses the module-level
                ``imp_model``.
            common_tokenizer: Stored only. Generation uses the module-level
                ``imp_tokenizer``.
        """
        super().__init__(model_name="xmchat", user=user_name)
        self.api_key = api_key
        self.image_flag = False
        self.session_id = None
        self.reset()
        self.image_bytes = None
        self.image_path = None
        self.xm_history = []
        self.url = "https://xmbot.net/web"
        self.last_conv_id = None
        # NOTE(review): not read by get_answer_at_once, which uses max_new_tokens=3000.
        self.max_generation_token = 100

        self.common_model = common_model
        self.common_tokenizer = common_tokenizer
        # NOTE(review): _get_imp_style_inputs builds its own (different) system
        # text and does not read this attribute.
        self.system_prompt = "A chat between a curious user and an artificial intelligence assistant. This artificial intelligence assistant is a chatbot named as Imp, and developed by MILVLG team. Imp gives helpful, detailed, and polite answers to the user's questions."

    def reset(self, remain_system_prompt=False):
        """Start a new session, drop any pending image, and reset the base state.

        Args:
            remain_system_prompt: Passed through to ``BaseLLMModel.reset``. The
                original override silently discarded it.

        Returns:
            Whatever ``BaseLLMModel.reset`` returns.
        """
        logging.info("Reseting...")
        self.session_id = str(uuid.uuid4())
        self.last_conv_id = None
        self.image_bytes = None
        self.image_path = None
        self.image_flag = False
        return super().reset(remain_system_prompt=remain_system_prompt)

    def image_to_base64(self, image_path):
        """Return the image at *image_path* as a base64-encoded JPEG string.

        Images larger than 2048 px on either side are downscaled with Lanczos
        resampling, keeping the aspect ratio. Modes JPEG cannot store are
        converted to RGB, which discards any alpha channel.
        """
        max_dimension = 2048
        # The context manager closes the file handle Image.open keeps open.
        with Image.open(image_path) as source:
            img = source
            width, height = img.size
            scale_ratio = min(max_dimension / width, max_dimension / height)
            if scale_ratio < 1:
                new_size = (int(width * scale_ratio), int(height * scale_ratio))
                img = img.resize(new_size, Image.LANCZOS)
            if img.mode not in self._JPEG_MODES:
                img = img.convert("RGB")
            buffer = BytesIO()
            img.save(buffer, format='JPEG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def try_read_image(self, filepath):
        """Load *filepath* as the pending image if its extension looks like an image.

        On success:
            * ``image_bytes`` holds the opened PIL ``Image`` (despite the name,
              it is not raw bytes);
            * ``image_path`` holds the path;
            * ``image_flag`` is set, so the next prompt carries the image
              placeholder.

        Otherwise any pending image is discarded.
        """
        def is_image_file(filepath):
            valid_image_extensions = [
                ".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"]
            file_extension = os.path.splitext(filepath)[1].lower()
            return file_extension in valid_image_extensions

        if is_image_file(filepath):
            logging.info(f"读取图片文件: {filepath}")
            self.image_bytes = Image.open(filepath)
            self.image_path = filepath
            self.image_flag = True
        else:
            self.image_bytes = None
            self.image_path = None
            # Clear the flag too. Otherwise the next prompt would announce an
            # image that no longer exists and image_preprocess(None) would fail.
            self.image_flag = False

    def _send_appraisal(self, appraise):
        """POST a feedback verdict for ``last_conv_id`` to the XMBot endpoint."""
        data = {
            "uuid": self.last_conv_id,
            "appraise": appraise
        }
        # The timeout stops an unresponsive endpoint from blocking this call indefinitely.
        requests.post(self.url, json=data, timeout=10)

    def like(self):
        """Send positive feedback for the last conversation; return a status message."""
        if self.last_conv_id is None:
            return "点赞失败,你还没发送过消息"
        self._send_appraisal("good")
        return "👍点赞成功,感谢反馈~"

    def dislike(self):
        """Send negative feedback for the last conversation; return a status message."""
        if self.last_conv_id is None:
            return "点踩失败,你还没发送过消息"
        self._send_appraisal("bad")
        return "👎点踩成功,感谢反馈~"

    def prepare_inputs(self, real_inputs, use_websearch, files, reply_language, chatbot):
        """Pass the user input through unchanged.

        Web search, files and reply language are ignored by this model.
        Returns ``(limited_context, fake_inputs, display_append, real_inputs,
        chatbot)`` with ``limited_context=False`` and an empty
        ``display_append``.
        """
        fake_inputs = real_inputs
        display_append = ""
        limited_context = False
        return limited_context, fake_inputs, display_append, real_inputs, chatbot

    def handle_file_upload(self, files, chatbot, language):
        """Read uploaded image files and echo each accepted one into the chat.

        Returns ``(None, chatbot, None)``. *language* is not used here.
        """
        if files:
            for file in files:
                if file.name:
                    logging.info(f"尝试读取图像: {file.name}")
                    self.try_read_image(file.name)
                    if self.image_path is not None:
                        chatbot = chatbot + [((self.image_path,), None)]
        return None, chatbot, None

    def _prune_image_placeholders(self, prompt):
        """Keep only the last image placeholder in *prompt*, or none if no image is loaded.

        Only the current ``image_bytes`` is passed to the model. The
        placeholder, though, stays in every earlier user message that once
        carried an image, so without pruning a conversation with several
        uploads would announce more images than are supplied.
        """
        placeholder = self._IMAGE_PLACEHOLDER
        if self.image_bytes is None:
            return prompt.replace(placeholder, "")
        last = prompt.rfind(placeholder)
        if last == -1:
            return prompt
        return prompt[:last].replace(placeholder, "") + prompt[last:]

    def _get_imp_style_inputs(self):
        """Render ``self.history`` as the prompt string fed to Imp.

        Format: ``<system> USER: ... ASSISTANT: ...</s> USER: ... ASSISTANT:``.

        If a new image is pending (``image_flag``) and the last history entry
        is a user message, the image placeholder is prepended to that message.
        This is written back into ``self.history``, so later turns keep the
        image context.
        """
        context = (
            "A chat between a curious user and an artificial intelligence assistant. "
            "This artificial intelligence assistant is a multimodal chatbot named as Imp, "
            "and developed by MILVLG team from Hangzhou Dianzi University. "
            "Imp gives helpful, detailed, and polite answers to the user's questions."
        )
        last_index = len(self.history) - 1
        for index, message in enumerate(self.history):
            if message["role"] == "user":
                if self.image_flag and index == last_index:
                    message["content"] = self._IMAGE_PLACEHOLDER + message["content"]
                    self.image_flag = False
                context += ' USER: ' + message["content"].strip()
            else:
                context += ' ASSISTANT: ' + message["content"].strip() + '</s>'
        context += ' ASSISTANT:'
        return self._prune_image_placeholders(context)

    def get_answer_at_once(self):
        """Generate a full reply for the current history with the Imp model.

        Returns:
            ``(response_text, len(response_text))``. The second value is a
            character count, not a token count.
        """
        prompt = self._get_imp_style_inputs()
        logging.info(prompt)

        # Put the prompt ids on the same device as the model weights
        # (device_map="auto" may place the model on a GPU).
        input_ids = imp_tokenizer(prompt, return_tensors='pt').input_ids.to(imp_model.device)

        image_tensor = None
        if '<image>' in prompt and self.image_bytes is not None:
            image_tensor = imp_model.image_preprocess(self.image_bytes)

        generation_kwargs = {
            "max_new_tokens": 3000,
            "images": image_tensor,
            "num_return_sequences": 1,
            "use_cache": True,
        }
        if self.temperature > 0:
            generation_kwargs.update(
                do_sample=True,
                top_p=self.top_p,
                temperature=self.temperature,
            )
        else:
            # Greedy decoding: generate() ignores (and warns about) sampling knobs.
            generation_kwargs["do_sample"] = False

        output_ids = imp_model.generate(input_ids, **generation_kwargs)[0]
        # generate() returns prompt + continuation; decode only the new tokens.
        response = imp_tokenizer.decode(
            output_ids[input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        return response, len(response)
|
|