# Hugging Face Spaces page residue (not part of the program); the Space status read:
# Runtime error
import os | |
os.environ["CUDA_VISIBLE_DEVICES"] = "0" # 指定要使用的GPU设备编号 | |
from transformers import pipeline | |
import argparse | |
import openai | |
import tiktoken | |
import torch | |
from scipy.spatial.distance import cosine | |
from transformers import AutoModel, AutoTokenizer | |
from argparse import Namespace | |
from langchain.chat_models import ChatOpenAI | |
import gradio as gr | |
import random | |
import time | |
from langchain.prompts.chat import ( | |
ChatPromptTemplate, | |
SystemMessagePromptTemplate, | |
AIMessagePromptTemplate, | |
HumanMessagePromptTemplate, | |
) | |
from langchain.schema import ( | |
AIMessage, | |
HumanMessage, | |
SystemMessage | |
) | |
from text import Text | |
# Module-level cache so the (large) luotuo-bert checkpoint is instantiated at most once
# per process instead of on every call.
_LUOTUO_BERT_MODEL = None


def download_models():
    """Load (downloading on first use) the ``silk-road/luotuo-bert`` embedding model.

    The result is memoized: the first call builds the model with
    ``AutoModel.from_pretrained`` and every later call returns that same
    instance. Before, each call re-instantiated the full checkpoint, and this
    function is called repeatedly (module start-up and embedding code).

    Returns:
        The model object returned by ``AutoModel.from_pretrained``.
    """
    global _LUOTUO_BERT_MODEL
    if _LUOTUO_BERT_MODEL is None:
        # Extra options consumed by the repository's remote code (trust_remote_code=True);
        # their exact meaning is defined there (presumably SimCSE-style settings — TODO confirm).
        model_args = Namespace(do_mlm=None, pooler_type="cls", temp=0.05, mlp_only_train=False,
                               init_embeddings_model=None)
        _LUOTUO_BERT_MODEL = AutoModel.from_pretrained(
            "silk-road/luotuo-bert", trust_remote_code=True, model_args=model_args)
    return _LUOTUO_BERT_MODEL
# OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY2")
# openai.api_key = '<your-openai-api-key>'  # enter your OpenAI API token here (never commit a real key)
# os.environ["OPENAI_API_KEY"] = openai.api_key
# Output folder under the current working directory (checked / created below).
folder_name = "Suzumiya"
current_directory = os.getcwd()
new_directory = os.path.join(current_directory, folder_name)

# Cache files and asset locations handed to the `Text` helper.
pkl_path = './pkl/texts.pkl'
text_image_pkl_path = './pkl/text_image.pkl'
dict_path = "../characters/haruhi/text_image_dict.txt"
dict_text_pkl_path = './pkl/dict_text.pkl'
image_path = "../characters/haruhi/images"

# Embedding model, also passed into `Text`.
model = download_models()
text = Text(
    "../characters/haruhi/texts",
    text_image_pkl_path=text_image_pkl_path,
    dict_text_pkl_path=dict_text_pkl_path,
    model=model,
    num_steps=50,
    pkl_path=pkl_path,
    dict_path=dict_path,
    image_path=image_path,
)

# Make sure the output folder exists, reporting whether it was created.
if os.path.exists(new_directory):
    print(f"文件夹 '{folder_name}' 已经存在。")
else:
    os.makedirs(new_directory)
    print(f"文件夹 '{folder_name}' 创建成功!")

# tiktoken encoder used below to count tokens against the story/history budgets.
enc = tiktoken.get_encoding("cl100k_base")
class Run:
    """Retrieval-augmented role-play chat ("Chat Haruhi") with a Gradio front end.

    Story snippets are loaded from a folder of ``.txt`` files and every
    blank-line-separated paragraph is embedded with luotuo-bert. For each user
    query the most similar stories are retrieved, packed into the prompt under a
    token budget, and sent together with the trimmed chat history to the OpenAI
    chat model.
    """

    def __init__(self, **params):
        """Build the runner from keyword parameters (all required).

        Args:
            folder: directory of ``.txt`` story files used as in-context examples.
            system_prompt: path of a text file containing the system prompt.
            max_len_story: tiktoken budget for the retrieved stories.
            max_len_history: tiktoken budget for the kept chat history.
            save_path: directory where conversations are written.

        Original TODO notes (translated):
            * command-line arguments
            * a lines folder that records dialogue lines
            * store the system prompt as a txt file so it can be switched
            * configurable max_len_story and max_len_history
            * configurable save_path
            * a Colab script that clones the converted project and runs it,
              so other users can try it easily
        """
        self.folder = params['folder']
        with open(params['system_prompt'], 'r', encoding='utf-8') as f:
            self.system_prompt = f.read()
        self.max_len_story = params['max_len_story']
        self.max_len_history = params['max_len_history']
        self.save_path = params['save_path']
        self.titles, self.title_to_text = self.read_prompt_data()
        # Embed every story paragraph once, up front; queries are compared against these.
        self.embeddings, self.embed_to_title = self.title_text_embedding(self.titles, self.title_to_text)

    def get_completion_from_messages(self, messages, model="gpt-3.5-turbo", temperature=0):
        """Send ``messages`` to the OpenAI chat-completion API and return the reply text.

        Uses the legacy ``openai.ChatCompletion`` interface (``openai`` < 1.0).

        Args:
            messages: list of ``{'role': ..., 'content': ...}`` dicts.
            model: chat model name.
            temperature: sampling temperature (controls output randomness).

        Returns:
            The ``content`` of the first choice's message.
        """
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        return response.choices[0].message["content"]

    def read_prompt_data(self):
        """Read the in-context-learning story files from ``self.folder``.

        Every ``*.txt`` file becomes one story whose title is the file name
        without its extension. Files are visited in sorted order so the title
        list (and therefore embedding order and tie-breaking) is reproducible.

        Returns:
            (titles, title_to_text): the list of titles and a title -> text dict.
        """
        titles = []
        title_to_text = {}
        for file in sorted(os.listdir(self.folder)):
            if not file.endswith('.txt'):
                continue
            title_name = file[:-len('.txt')]
            titles.append(title_name)
            with open(os.path.join(self.folder, file), 'r', encoding='utf-8') as f:
                title_to_text[title_name] = f.read()
        return titles, title_to_text

    def _embedding_backend(self):
        """Return the (tokenizer, model) pair, loading them on first use only."""
        tokenizer = getattr(self, '_embed_tokenizer', None)
        model = getattr(self, '_embed_model', None)
        if tokenizer is None or model is None:
            tokenizer = AutoTokenizer.from_pretrained("silk-road/luotuo-bert")
            model = download_models()
            self._embed_tokenizer = tokenizer
            self._embed_model = model
        return tokenizer, model

    def get_embedding(self, text):
        """Return the luotuo-bert embedding of ``text`` (first row of ``pooler_output``).

        The input is cut to its first 512 characters and the tokenized sequence
        is additionally truncated to 512 tokens. Without token-level truncation,
        512 CJK characters plus the special tokens overflow BERT's usual
        512-position limit (assumed to apply to luotuo-bert — TODO confirm).
        The tokenizer/model are loaded once per instance instead of per call.
        """
        tokenizer, model = self._embedding_backend()
        if len(text) > 512:
            text = text[:512]
        inputs = tokenizer([text], padding=True, truncation=True, max_length=512, return_tensors="pt")
        with torch.no_grad():
            # sent_emb=True is consumed by the model's remote code.
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True, sent_emb=True).pooler_output
        return embeddings[0]

    def title_text_embedding(self, titles, title_to_text):
        """Embed every paragraph of every story.

        Each story text is split on blank lines. Every non-blank paragraph is
        embedded separately and mapped back to its title.

        Returns:
            (embeddings, embed_to_title): parallel lists; ``embed_to_title[i]``
            is the title whose paragraph produced ``embeddings[i]``.
        """
        embeddings = []
        embed_to_title = []
        for title in titles:
            for paragraph in title_to_text[title].split('\n\n'):
                # Blank chunks (e.g. from trailing blank lines) carry no content and
                # would only add spurious matches.
                if not paragraph.strip():
                    continue
                embeddings.append(self.get_embedding(paragraph))
                embed_to_title.append(title)
        return embeddings, embed_to_title

    def get_cosine_similarity(self, embed1, embed2):
        """Cosine similarity of two 1-D embedding tensors (a 0-d tensor)."""
        return torch.nn.functional.cosine_similarity(embed1, embed2, dim=0)

    def retrieve_title(self, query_embed, embeddings, embed_to_title, k):
        """Return up to ``k`` distinct titles ranked by paragraph similarity to the query.

        Paragraph embeddings are ranked by cosine similarity to ``query_embed``
        (highest first). Titles are collected in that order, skipping titles
        already taken; ties keep the original paragraph order.
        """
        if k <= 0:
            return []
        scores = [float(self.get_cosine_similarity(query_embed, embed)) for embed in embeddings]
        # Rank indices directly. Looking scores back up with list.index() was O(n^2)
        # and always resolved equal scores to the first match, hiding other titles.
        ranked = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        top_k_title = []
        for idx in ranked:
            title = embed_to_title[idx]
            if title in top_k_title:
                continue
            top_k_title.append(title)
            if len(top_k_title) == k:
                break
        return top_k_title

    def organize_story_with_maxlen(self, selected_sample):
        """Concatenate retrieved stories under ``self.max_len_story`` tiktoken tokens.

        Stories are appended in the given (ranking) order. Packing stops at the
        first story that would push the running count past the budget. The
        header line and separators are not counted.

        Returns:
            (story, final_selected): the assembled text and the titles included.
        """
        maxlen = self.max_len_story
        story = "凉宫春日的经典桥段如下:\n"
        count = 0
        final_selected = []
        print(selected_sample)
        for sample_topic in selected_sample:
            sample_story = self.title_to_text[sample_topic]
            sample_len = len(enc.encode(sample_story))
            if sample_len + count > maxlen:
                break
            story += sample_story
            story += '\n'
            count += sample_len
            final_selected.append(sample_topic)
        return story, final_selected

    def organize_message(self, story, history_chat, history_response, new_query):
        """Build an OpenAI-style message list: system, story, history turns, new query.

        User turns are sent with role ``'user'`` and past replies with role
        ``'assistant'``. If the two history lists differ in length, the history
        is discarded with a warning.
        """
        messages = [{'role': 'system', 'content': self.system_prompt}, {'role': 'user', 'content': story}]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat, history_response = [], []
        for chat, response in zip(history_chat, history_response):
            messages.append({'role': 'user', 'content': chat})
            # Previously mislabelled as 'user'; the model's own replies are 'assistant'.
            messages.append({'role': 'assistant', 'content': response})
        messages.append({'role': 'user', 'content': new_query})
        return messages

    def keep_tail(self, history_chat, history_response):
        """Keep the most recent turns that fit into ``self.max_len_history`` tokens.

        The newest turn is always kept, even if it alone exceeds the budget.
        Older turns are added newest-first until one more would exceed the
        budget. Mismatched list lengths reset the history (with a warning).

        Returns:
            (history_chat, history_response) tails of equal length.
        """
        max_len = self.max_len_history
        n = len(history_chat)
        if n == 0:
            return [], []
        if n != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            return [], []
        token_len = [len(enc.encode(chat)) + len(enc.encode(res))
                     for chat, res in zip(history_chat, history_response)]
        keep_k = 1
        count = token_len[-1]
        for turn_len in reversed(token_len[:-1]):
            count += turn_len
            if count > max_len:
                break
            keep_k += 1
        return history_chat[-keep_k:], history_response[-keep_k:]

    def organize_message_langchain(self, story, history_chat, history_response, new_query):
        """Build the LangChain message list: system prompt, story, history, new query.

        Past user turns become ``HumanMessage`` and past replies ``AIMessage``.
        If the two history lists differ in length, the history is discarded
        with a warning.
        """
        messages = [
            SystemMessage(content=self.system_prompt),
            HumanMessage(content=story)
        ]
        if len(history_chat) != len(history_response):
            print('warning, unmatched history_char length, clean and start new chat')
            history_chat, history_response = [], []
        for chat, response in zip(history_chat, history_response):
            messages.append(HumanMessage(content=chat))
            messages.append(AIMessage(content=response))
        messages.append(HumanMessage(content=new_query))
        return messages

    def get_response(self, user_message, chat_history_tuple):
        """Generate the character's reply to ``user_message``.

        Steps:
            1. Trim the history to the token budget.
            2. Embed the query and retrieve the 7 most similar stories.
            3. Pack the stories under the story budget.
            4. Query ``ChatOpenAI`` (temperature 0).

        Args:
            user_message: the new user utterance.
            chat_history_tuple: sequence of ``(user, reply)`` pairs; ``None`` or
                empty means no history.

        Returns:
            The model's reply text.
        """
        history_chat = []
        history_response = []
        for cha, res in chat_history_tuple or []:
            history_chat.append(cha)
            history_response.append(res)
        history_chat, history_response = self.keep_tail(history_chat, history_response)
        print('history done')
        new_query = user_message
        query_embed = self.get_embedding(new_query)
        print("2")
        selected_sample = self.retrieve_title(query_embed, self.embeddings, self.embed_to_title, 7)
        print("3")
        story, selected_sample = self.organize_story_with_maxlen(selected_sample)
        # TODO: visualize the selected samples later
        print('当前辅助sample:', selected_sample)
        messages = self.organize_message_langchain(story, history_chat, history_response, new_query)
        chat = ChatOpenAI(temperature=0)
        return_msg = chat(messages)
        return return_msg.content

    def save_response(self, chat_history_tuple):
        """Write the whole conversation to a new timestamped file in ``self.save_path``.

        Each user message and each reply is followed by a ``---`` separator line.
        The file is UTF-8 so Chinese/Japanese text is written on any platform.
        """
        file_name = f"conversation_{time.time()}.txt"
        with open(os.path.join(self.save_path, file_name), "w", encoding="utf-8") as file:
            for cha, res in chat_history_tuple:
                file.write(cha)
                file.write("\n---\n")
                file.write(res)
                file.write("\n---\n")

    def create_gradio(self):
        """Build the Gradio chat UI and launch it (``debug=True``, ``share=True``).

        Requires the module-level globals ``text`` (used by the image button) and
        ``pipe`` (the zh->ja translation pipeline, created in the ``__main__`` block).
        NOTE(review): the Row/column nesting below was reconstructed from an
        unindented source — verify the layout.
        """
        with gr.Blocks() as demo:
            gr.Markdown(
                """
## Chat凉宫春日 ChatHaruhi
此版本为测试版本,非正式版本,正式版本功能更多,敬请期待
"""
            )
            # Hidden box that receives the latest reply; the image button uses it as its query.
            image_input = gr.Textbox(visible=False)
            japanese_input = gr.Textbox(visible=False)
            with gr.Row():
                chatbot = gr.Chatbot()
                image_output = gr.Image()
            role_name = gr.Textbox(label="角色名", placeholder="输入角色名")
            msg = gr.Textbox(label="输入")
            with gr.Row():
                clear = gr.Button("Clear")
                sub = gr.Button("Submit")
                image_button = gr.Button("给我一个图")
            japanese_output = gr.Textbox(interactive=False)

            def respond(role_name, user_message, chat_history):
                # Guard against an empty/None history (e.g. right after "Clear").
                chat_history = chat_history or []
                input_message = role_name + ':「' + user_message + '」'
                bot_message = self.get_response(input_message, chat_history)
                chat_history.append((input_message, bot_message))
                self.save_response(chat_history)
                jp_text = pipe(f'<-zh2ja-> {bot_message}')[0]['translation_text']
                # Outputs: clear the input box, updated history, image query, Japanese text.
                return "", chat_history, bot_message, jp_text

            clear.click(lambda: None, None, chatbot, queue=False)
            msg.submit(respond, [role_name, msg, chatbot], [msg, chatbot, image_input, japanese_output])
            sub.click(fn=respond, inputs=[role_name, msg, chatbot], outputs=[msg, chatbot, image_input, japanese_output])
            image_button.click(text.text_to_image, inputs=image_input, outputs=image_output)
        demo.launch(debug=True, share=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="-----[Chat凉宫春日]-----")
    parser.add_argument("--folder", default="../characters/haruhi/texts", help="text folder")
    parser.add_argument("--system_prompt", default="../characters/haruhi/system_prompt.txt", help="store system_prompt")
    parser.add_argument("--max_len_story", default=1500, type=int)
    parser.add_argument("--max_len_history", default=1200, type=int)
    parser.add_argument("--save_path", default=os.path.join(os.getcwd(), "Suzumiya"))
    options = parser.parse_args()
    params = {
        "folder": options.folder,
        "system_prompt": options.system_prompt,
        "max_len_story": options.max_len_story,
        "max_len_history": options.max_len_history,
        "save_path": options.save_path
    }
    # Conversations are written into save_path on every turn, so a user-supplied
    # directory must exist before the UI starts.
    os.makedirs(options.save_path, exist_ok=True)
    # Global `pipe`: Chinese->Japanese translation used by the UI. Use GPU 0 when CUDA
    # is available, otherwise fall back to CPU (device=-1) instead of crashing at start-up.
    pipe = pipeline(model="engmatic-earth/mt5-zh-ja-en-trimmed-fine-tuned-v1",
                    device=0 if torch.cuda.is_available() else -1, max_length=120)
    run = Run(**params)
    run.create_gradio()
    # Manual, UI-less smoke test kept for reference:
    # history_chat = []
    # history_response = []
    # chat_timer = 5
    # new_query = '鲁鲁:你好我是新同学鲁鲁'
    # query_embed = run.get_embedding(new_query)
    # titles, title_to_text = run.read_prompt_data()
    # embeddings, embed_to_title = run.title_text_embedding(titles, title_to_text)
    # selected_sample = run.retrieve_title(query_embed, embeddings, embed_to_title, 7)
    # print('限制长度之前:', selected_sample)
    # story, selected_sample = run.organize_story_with_maxlen(selected_sample)
    # print('当前辅助sample:', selected_sample)
    # messages = run.organize_message(story, history_chat, history_response, new_query)
    # response = run.get_completion_from_messages(messages)
    # print(response)
    # history_chat.append(new_query)
    # history_response.append(response)
    # history_chat, history_response = run.keep_tail(history_chat, history_response)
    # print(history_chat, history_response)