"""Gradio chat demo for the ChatHaruhi (Chat-Haruhi-Suzumiya) role-play bot.

The script embeds every dialogue line found in a folder of ``.txt`` scene
files, retrieves the scenes most similar to a user query, packs them into a
prompt together with the recent chat history, and asks an OpenAI chat model
to answer in character.  An optional text-to-image lookup returns the
illustration whose caption is closest to a given text.
"""
import argparse
import collections
import json
import os
import pickle
import random
import time
from argparse import Namespace

import gradio as gr
import numpy as np
import openai
import tiktoken
import torch
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.schema import AIMessage, HumanMessage, SystemMessage
from PIL import Image
from scipy.spatial.distance import cosine
from torch import cosine_similarity
from transformers import AutoModel, AutoTokenizer

openai.proxy = ""
# SECURITY: never commit an API key to source control.  The key is taken from
# the OPENAI_API_KEY environment variable, which langchain's ChatOpenAI reads
# as well.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")

folder_name = "Suzumiya"
current_directory = os.getcwd()
new_directory = os.path.join(current_directory, folder_name)
device = torch.device("cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not os.path.exists(new_directory):
    os.makedirs(new_directory)
    print(f"文件夹 '{folder_name}' 创建成功!")
else:
    print(f"文件夹 '{folder_name}' 已经存在。")

enc = tiktoken.get_encoding("cl100k_base")

# Full-width colon.  Written as an escape so that Unicode normalisation of the
# source file cannot silently turn it into the ASCII colon.
FULLWIDTH_COLON = "\uff1a"


def _ensure_parent_dir(path):
    """Create the directory that will contain *path* if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


class Run:
    """Retrieval-augmented role-play chat pipeline plus its Gradio front end."""

    # Width of one embedding vector as stored on disk and reshaped below.
    EMBEDDING_DIM = 1536
    EMBEDDING_MODEL_NAME = "silk-road/luotuo-bert"

    def __init__(self, **params):
        """Store the configuration passed as keyword arguments.

        Original design notes:

        * wire up the command-line arguments
        * a dialogue folder that holds the recorded lines
        * the system prompt lives in a txt file so it can be swapped
        * ``max_len_story`` and ``max_len_history`` are configurable
        * ``save_path`` is configurable
        * a colab script should be able to clone the project and run it

        Every key read below is required; a missing key raises ``KeyError``.
        ``system_prompt`` may be either the prompt text itself or the path of
        a text file containing it.
        """
        self.title_to_text_pkl_path = params['title_to_text_pkl_path']
        self.text_image_pkl_path = params['text_image_pkl_path']
        self.dict_text_pkl_path = params['dict_text_pkl_path']
        self.num_steps = params['num_steps']
        self.texts_pkl_path = params['texts_pkl_path']
        self.embeds_path = params['embeds_path']
        self.embeds2_path = params['embeds2_path']
        self.dict_path = params['dict_path']
        self.image_path = params['image_path']
        self.maps_pkl_path = params['maps_pkl_path']
        self.folder = params['folder']
        self.system_prompt = self._resolve_system_prompt(params['system_prompt'])
        self.max_len_story = params['max_len_story']
        self.max_len_history = params['max_len_history']
        self.save_path = params['save_path']
        # Lazily created by get_embedding / download_models and then reused.
        self._tokenizer = None
        self._model = None

    @staticmethod
    def _resolve_system_prompt(value):
        """Return the prompt text, reading it from disk when *value* is a file path."""
        if isinstance(value, str) and os.path.isfile(value):
            with open(value, 'r', encoding='utf-8') as f:
                return f.read()
        return value

    def read_text(self):
        """Extract every dialogue line from ``self.folder`` and persist the artefacts.

        Writes the ``title_text -> index`` lines to ``texts_pkl_path``, the
        embeddings split in two halves across ``embeds_path`` and
        ``embeds2_path`` (JSON lines), and pickles the raw scene texts and the
        per-line records.

        Returns:
            ``(text_embeddings, data)`` where ``text_embeddings`` is always an
            empty list (kept for interface compatibility) and ``data`` is the
            list of ``{"titles", "id", "text"}`` records.
        """
        text_embeddings = []
        title_to_text = {}
        texts = []
        seen = set()  # O(1) duplicate detection
        data = []
        next_id = 0
        for file in os.listdir(self.folder):
            if not file.endswith('.txt'):
                continue
            title_name = file[:-4]
            with open(os.path.join(self.folder, file), 'r', encoding='utf-8') as fr:
                title_to_text[title_name] = fr.read()
            for line in title_to_text[title_name].strip().split('\n'):
                line = line.strip()
                ch = FULLWIDTH_COLON if FULLWIDTH_COLON in line else ':'
                parts = line.split(ch)
                if len(parts) < 2:
                    # Blank line or no "speaker:utterance" separator.
                    continue
                if '旁白' in line:
                    text = parts[1].strip()
                else:
                    # Drop the surrounding quotation brackets.
                    text = parts[1][1:-1]
                key = title_name + "_" + text
                if key in seen:
                    # Duplicates would make the embeddings and the records
                    # disagree in length.
                    continue
                seen.add(key)
                texts.append(key)
                data.append({
                    "titles": file.split('.')[0],
                    "id": str(next_id),
                    "text": text,
                })
                next_id += 1

        if texts:
            embeddings = self.get_embedding(texts)
            if embeddings.dim() == 1:
                # get_embedding returns a bare vector for a single text.
                embeddings = embeddings.unsqueeze(0)
        else:
            embeddings = []

        for path in (self.texts_pkl_path, self.embeds_path, self.embeds2_path):
            _ensure_parent_dir(path)
        with open(self.texts_pkl_path, 'w+', encoding='utf-8') as f1:
            for i, text in enumerate(texts):
                json.dump({text: i}, f1, ensure_ascii=False)
                f1.write('\n')
        half = len(embeddings) / 2
        with open(self.embeds_path, 'w+', encoding='utf-8') as f2, \
                open(self.embeds2_path, 'w+', encoding='utf-8') as f3:
            for i, embed in enumerate(embeddings):
                target = f2 if i < half else f3
                json.dump({i: embed.numpy().tolist()}, target, ensure_ascii=False)
                target.write('\n')

        self.store(self.title_to_text_pkl_path, title_to_text)
        self.store(self.maps_pkl_path, data)
        return text_embeddings, data

    def store(self, path, data):
        """Pickle *data* to *path*, creating the parent directory if needed."""
        _ensure_parent_dir(path)
        with open(path, 'wb+') as f:
            pickle.dump(data, f)

    def _load_text_embeddings(self):
        """Read the JSON-lines text index and both embedding halves into one dict."""
        with open(self.texts_pkl_path, 'r', encoding='utf-8') as f:
            texts = [list(json.loads(line).keys())[0] for line in f]
        embeds = []
        for path in (self.embeds_path, self.embeds2_path):
            with open(path, 'r', encoding='utf-8') as f:
                embeds.extend(list(json.loads(line).values()) for line in f)
        return dict(zip(texts, embeds))

    def load(self, load_texts=False, load_maps=False, load_dict_text=False, load_text_image=False,
             load_title_to_text=False):
        """Load one stored artefact, selected by the first flag that is true.

        Flags are checked in signature order.  Returns the loaded object, or
        ``None`` (after printing a message) when the matching path is unset
        or no flag was given.
        """
        if load_texts:
            if not self.texts_pkl_path:
                print("No texts_pkl_path")
                return None
            return self._load_text_embeddings()

        pickled_sources = (
            (load_maps, self.maps_pkl_path, "maps_pkl_path"),
            (load_dict_text, self.dict_text_pkl_path, "dict_text_pkl_path"),
            (load_text_image, self.text_image_pkl_path, "text_image_pkl_path"),
            (load_title_to_text, self.title_to_text_pkl_path, "title_to_text_pkl_path"),
        )
        for requested, path, label in pickled_sources:
            if not requested:
                continue
            if not path:
                print(f"No {label}")
                return None
            with open(path, 'rb') as f:
                # NOTE(security): pickle executes arbitrary code on load; only
                # point these paths at files produced by this program.
                return pickle.load(f)

        print("Please specify the loading file!")
        return None

    def text_to_image(self, text, save_dict_text=False):
        """Return the image whose caption is most similar to *text*.

        The query is compared with the stored caption embeddings; the best
        caption is mapped to an image name through the text-image dictionary
        and the file is opened from ``self.image_path``.

        Args:
            text: query text.
            save_dict_text: when true, first rebuild and pickle the
                caption-to-image mapping (from ``self.dict_path``, which
                alternates caption and image-name lines) and the caption
                embeddings.

        Returns:
            A ``PIL.Image`` object, or ``None`` when paths are unset or the
            image file is missing.
        """
        if save_dict_text:
            with open(self.dict_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
            text_image = {
                caption.strip(): image.strip()
                for caption, image in zip(lines[::2], lines[1::2])
            }
            self.store(self.text_image_pkl_path, text_image)

            keys_embeddings = {key: self.get_embedding(key) for key in text_image}
            self.store(self.dict_text_pkl_path, keys_embeddings)

        if not (self.dict_path and self.image_path):
            print("No path")
            return None

        text_image = self.load(load_text_image=True)
        captions = list(text_image.keys())
        query_similarity = self.get_cosine_similarity([text] + captions, get_image=True)
        best_caption = captions[int(query_similarity.argmax(dim=0))]
        image = text_image[best_caption] + '.jpg'
        image_file = os.path.join(self.image_path, image)
        if os.path.isfile(image_file):
            return Image.open(image_file)
        print("Image doesn't exist")
        return None

    def text_to_text(self, text):
        """Return the stored ``title_text`` key most similar to *text*."""
        pkl = self.load(load_texts=True)
        # With get_texts=True only the query (element 0) is used; the
        # candidates come from the stored embeddings.
        texts_similarity = self.get_cosine_similarity([text], get_texts=True)
        return list(pkl.keys())[int(texts_similarity.argmax(dim=0))]

    def get_completion_from_messages(self, messages, model="gpt-3.5-turbo", temperature=0):
        """Thin wrapper around the OpenAI chat-completion endpoint.

        Args:
            messages: list of ``{"role", "content"}`` dicts.
            model: model name.
            temperature: sampling temperature (randomness of the output).

        Returns:
            The content string of the first choice.
        """
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        return response.choices[0].message["content"]

    def download_models(self):
        """Return the embedding model, downloading/loading it on first use only."""
        model = getattr(self, "_model", None)
        if model is None:
            model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
                                   init_embeddings_model=None)
            model = AutoModel.from_pretrained(self.EMBEDDING_MODEL_NAME, trust_remote_code=True,
                                              model_args=model_args).to(device)
            self._model = model
        return model

    def get_embedding(self, texts):
        """Embed one string or a list of strings.

        Each text is truncated to ``self.num_steps`` characters; the caller's
        list is not modified.

        Returns:
            A single vector when exactly one text was given, otherwise a
            tensor with one row per text.
        """
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.EMBEDDING_MODEL_NAME)
            self._tokenizer = tokenizer
        model = self.download_models()

        texts = texts if isinstance(texts, list) else [texts]
        texts = [text[:self.num_steps] for text in texts]

        inputs = tokenizer(texts, padding=True, truncation=False, return_tensors="pt")
        inputs = inputs.to(device)
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True,
                               sent_emb=True).pooler_output
        return embeddings[0] if len(texts) == 1 else embeddings

    def get_cosine_similarity(self, texts, get_image=False, get_texts=False):
        """Cosine similarity between ``texts[0]`` (the query) and a candidate set.

        Candidates are the stored caption embeddings (``get_image``), the
        stored dialogue embeddings (``get_texts``), or otherwise freshly
        computed embeddings of ``texts[1:]``.  Using stored embeddings avoids
        recomputing them for every query.

        Returns:
            A 1-D tensor with one similarity per candidate, in candidate order.
        """
        dim = self.EMBEDDING_DIM
        if get_image or get_texts:
            stored = self.load(load_dict_text=True) if get_image else self.load(load_texts=True)
            matrix = np.concatenate(
                [np.asarray(value, dtype=np.float32).reshape(-1, dim) for value in stored.values()],
                axis=0,
            )
            candidates = torch.from_numpy(matrix)
        else:
            candidates = self.get_embedding(texts[1:]).reshape(-1, dim)

        query_embedding = self.get_embedding(texts[0]).reshape(1, -1)
        # Stored vectors come from JSON (float64); match the query dtype.
        return cosine_similarity(query_embedding, candidates.to(query_embedding.dtype))

    def retrieve_title(self, query_text, k):
        """Return up to *k* distinct scene titles, best match first."""
        texts_pkl = self.load(load_texts=True)
        embed_to_title = [title_text.split('_')[0] for title_text in texts_pkl.keys()]

        similarities = self.get_cosine_similarity([query_text], get_texts=True).numpy().tolist()
        # Rank indices rather than values so tied scores map to the right row.
        ranked = sorted(range(len(similarities)), key=similarities.__getitem__, reverse=True)

        top_k_title = []
        for index in ranked:
            title = embed_to_title[index]
            if title in top_k_title:
                continue
            top_k_title.append(title)
            if len(top_k_title) == k:
                break
        return top_k_title

    def organize_story_with_maxlen(self, selected_sample):
        """Concatenate the selected scenes until the token budget is reached.

        Scenes are added in order; the first scene that would push the total
        past ``self.max_len_story`` tokens stops the process.

        Returns:
            ``(story, final_selected)``: the prompt text and the titles used.
        """
        maxlen = self.max_len_story
        title_to_text = self.load(load_title_to_text=True)

        story = "凉宫春日的经典桥段如下:\n"
        used_tokens = 0
        final_selected = []
        print(selected_sample)
        for sample_topic in selected_sample:
            sample_story = title_to_text[sample_topic]
            sample_len = len(enc.encode(sample_story))
            if used_tokens + sample_len > maxlen:
                break
            story += sample_story + '\n'
            used_tokens += sample_len
            final_selected.append(sample_topic)
        return story, final_selected

    def organize_message(self, story, history_chat, history_response, new_query):
        """Build the OpenAI-style message list for one request.

        History is dropped entirely when the two history lists differ in
        length.
        """
        messages = [{'role': 'system', 'content': self.system_prompt},
                    {'role': 'user', 'content': story}]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat = []
            history_response = []
        for chat, response in zip(history_chat, history_response):
            messages.append({'role': 'user', 'content': chat})
            # Model replies must carry the assistant role.
            messages.append({'role': 'assistant', 'content': response})
        messages.append({'role': 'user', 'content': new_query})
        return messages

    def keep_tail(self, history_chat, history_response):
        """Keep the most recent turns that fit in ``self.max_len_history`` tokens.

        The latest turn is always kept, even if it alone exceeds the budget.
        Returns two empty lists when there is no history or the lists differ
        in length.
        """
        max_len = self.max_len_history
        if not history_chat:
            return [], []
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            return [], []

        token_len = [
            len(enc.encode(chat)) + len(enc.encode(response))
            for chat, response in zip(history_chat, history_response)
        ]
        keep_k = 1
        count = token_len[-1]
        for length in reversed(token_len[:-1]):
            count += length
            if count > max_len:
                break
            keep_k += 1
        return history_chat[-keep_k:], history_response[-keep_k:]

    def organize_message_langchain(self, story, history_chat, history_response, new_query):
        """Build the langchain message list for one request.

        History is dropped entirely when the two history lists differ in
        length.
        """
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=story),
        ]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat = []
            history_response = []
        for chat, response in zip(history_chat, history_response):
            messages.append(HumanMessage(content=chat))
            messages.append(AIMessage(content=response))
        messages.append(HumanMessage(content=new_query))
        return messages

    def get_response(self, user_message, chat_history_tuple):
        """Answer *user_message* given the ``(user, bot)`` history pairs."""
        history_chat = []
        history_response = []
        for cha, res in chat_history_tuple or []:
            history_chat.append(cha)
            history_response.append(res)
        history_chat, history_response = self.keep_tail(history_chat, history_response)
        print('history done')

        new_query = user_message
        selected_sample = self.retrieve_title(new_query, 7)
        print("备选辅助:", selected_sample)
        story, selected_sample = self.organize_story_with_maxlen(selected_sample)
        # TODO: visualize the selected samples later
        print('当前辅助sample:', selected_sample)

        messages = self.organize_message_langchain(story, history_chat, history_response, new_query)
        print(f"messages:{messages}")
        chat = ChatOpenAI(temperature=0)
        return chat(messages).content

    def save_response(self, chat_history_tuple):
        """Write the whole conversation to a new timestamped file in ``save_path``."""
        os.makedirs(self.save_path, exist_ok=True)
        with open(f"{self.save_path}/conversation_{time.time()}.txt", "w", encoding="utf-8") as file:
            for cha, res in chat_history_tuple:
                file.write(cha)
                file.write("\n---\n")
                file.write(res)
                file.write("\n---\n")

    def create_gradio(self):
        """Build the Gradio UI and launch it (blocks until the server stops)."""
        with gr.Blocks() as demo:
            # NOTE(review): the two link targets were missing from the source
            # and were restored from the upstream project; confirm them.
            gr.Markdown(
                """
                ## Chat凉宫春日 ChatHaruhi
                项目地址 [https://github.com/LC1332/Chat-Haruhi-Suzumiya](https://github.com/LC1332/Chat-Haruhi-Suzumiya)
                骆驼项目地址 [https://github.com/LC1332/Luotuo-Chinese-LLM](https://github.com/LC1332/Luotuo-Chinese-LLM)
                此版本为图文版本,非最终版本,将上线更多功能,敬请期待
                """
            )
            # Hidden box carrying the last bot reply to the image lookup.
            image_input = gr.Textbox(visible=False)
            with gr.Row():
                chatbot = gr.Chatbot()
                image_output = gr.Image()
            role_name = gr.Textbox(label="角色名", placeholder="输入角色名")
            msg = gr.Textbox(label="输入")
            with gr.Row():
                clear = gr.Button("Clear")
                sub = gr.Button("Submit")
                image_button = gr.Button("给我一个图")

            def respond(role_name, user_message, chat_history):
                """Sanitise the input, query the bot and extend the history."""
                chat_history = chat_history or []
                role_name = "阿虚" if role_name in ['', ' '] else role_name
                role_name = role_name[:10]
                user_message = user_message[:200]
                # These characters delimit "speaker:「utterance」" lines and
                # must not appear inside the user-supplied parts.
                for char in (FULLWIDTH_COLON, ':', '「', '」', '\n'):
                    role_name = role_name.replace(char, 'x')
                    user_message = user_message.replace(char, ' ')
                input_message = role_name + ':「' + user_message + '」'
                print(f"chat_history:{chat_history}")
                bot_message = self.get_response(input_message, chat_history)
                chat_history.append((input_message, bot_message))
                self.save_response(chat_history)
                return "", chat_history, bot_message

            msg.submit(respond, [role_name, msg, chatbot], [msg, chatbot, image_input])
            clear.click(lambda: None, None, chatbot, queue=False)
            sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot, image_input])
            image_button.click(self.text_to_image, inputs=image_input, outputs=image_output)
        demo.launch(debug=True, share=True)


def build_arg_parser():
    """Return the command-line parser; every option maps to one ``Run`` parameter."""
    parser = argparse.ArgumentParser(description="-----[Chat凉宫春日]-----")
    parser.add_argument("--folder", default="../characters/haruhi/texts", help="text folder")
    parser.add_argument("--system_prompt", default="../characters/haruhi/system_prompt.txt",
                        help="store system_prompt")
    parser.add_argument("--max_len_story", default=1500, type=int)
    parser.add_argument("--max_len_history", default=1200, type=int)
    parser.add_argument("--save_path", default=os.getcwd() + "/Suzumiya")
    parser.add_argument("--texts_pkl_path", default="./pkl/texts.jsonl")
    parser.add_argument("--embeds_path", default="./pkl/embeds.jsonl")
    parser.add_argument("--embeds2_path", default="./pkl/embeds2.jsonl")
    parser.add_argument("--maps_pkl_path", default="./pkl/maps.pkl")
    parser.add_argument("--title_to_text_pkl_path", default='./pkl/title_to_text.pkl')
    parser.add_argument("--dict_text_pkl_path", default="./pkl/dict_text.pkl")
    parser.add_argument("--text_image_pkl_path", default="./pkl/text_image.pkl")
    parser.add_argument("--dict_path", default="../characters/haruhi/text_image_dict.txt")
    parser.add_argument("--image_path", default="../characters/haruhi/images")
    parser.add_argument("--num_steps", default=510, type=int)
    return parser


def main():
    """Parse the command line, rebuild the text index and start the UI."""
    options = build_arg_parser().parse_args()
    # Option names are exactly the keyword names Run expects.
    run = Run(**vars(options))
    run.read_text()
    # run.text_to_image("hello", save_dict_text=True)  # one-off: rebuild image index
    run.create_gradio()


if __name__ == '__main__':
    main()