from .ChromaDB import ChromaDB
import os

from .utils import luotuo_openai_embedding, tiktokenizer

from .utils import response_postprocess


def get_text_from_data(data):
    if "text" in data:
        return data['text']
    elif "enc_text" in data:
        from .utils import base64_to_string
        return base64_to_string(data['enc_text'])
    else:
        print("warning! failed to get text from data ", data)
        return ""

class ChatHaruhi:

    def __init__(self, system_prompt=None,
                 role_name=None, role_from_hf=None,
                 role_from_jsonl=None,
                 story_db=None, story_text_folder=None,
                 llm='openai',
                 embedding='luotuo_openai',
                 max_len_story=None, max_len_history=None,
                 verbose=False):
        super(ChatHaruhi, self).__init__()

        self.verbose = verbose

        # constants
        self.story_prefix_prompt = "Classic scenes for the role are as follows:\n"
        self.k_search = 19
        self.narrator = ['旁白', '', 'scene', 'Scene', 'narrator', 'Narrator']
        self.dialogue_divide_token = '\n###\n'
        self.dialogue_bra_token = '「'
        self.dialogue_ket_token = '」'
        # default to an empty prompt so later code does not hit an AttributeError
        # when neither system_prompt nor a role source is provided
        self.system_prompt = ""
        if system_prompt:
            self.system_prompt = self.check_system_prompt(system_prompt)
        # TODO: the embedding should be defined separately, so refactor this part later
        if llm == 'openai':
            # self.llm = LangChainGPT()
            self.llm, self.tokenizer = self.get_models('openai')
        elif llm == 'debug':
            self.llm, self.tokenizer = self.get_models('debug')
        elif llm == 'spark':
            self.llm, self.tokenizer = self.get_models('spark')
        elif llm == 'GLMPro':
            self.llm, self.tokenizer = self.get_models('GLMPro')
        elif llm == 'ChatGLM2GPT':
            self.llm, self.tokenizer = self.get_models('ChatGLM2GPT')
            self.story_prefix_prompt = '\n'
        elif llm == "BaiChuan2GPT":
            self.llm, self.tokenizer = self.get_models('BaiChuan2GPT')
        elif llm == "BaiChuanAPIGPT":
            self.llm, self.tokenizer = self.get_models('BaiChuanAPIGPT')
        elif llm == "ernie3.5":
            self.llm, self.tokenizer = self.get_models('ernie3.5')
        elif llm == "ernie4.0":
            self.llm, self.tokenizer = self.get_models('ernie4.0')
        elif "qwen" in llm:
            self.llm, self.tokenizer = self.get_models(llm)
        else:
            print(f'warning! undefined llm {llm}, use openai instead.')
            self.llm, self.tokenizer = self.get_models('openai')
        if embedding == 'luotuo_openai':
            self.embedding = luotuo_openai_embedding
        elif embedding == 'bge_en':
            from .utils import get_bge_embedding
            self.embedding = get_bge_embedding
        elif embedding == 'bge_zh':
            from .utils import get_bge_zh_embedding
            self.embedding = get_bge_zh_embedding
        else:
            print(f'warning! undefined embedding {embedding}, use luotuo_openai instead.')
            self.embedding = luotuo_openai_embedding
        if role_name:
            # TODO: move into a function
            from .role_name_to_file import get_folder_role_name

            # map the user-facing role_name to its folder name and download url
            role_name, url = get_folder_role_name(role_name)

            unzip_folder = f'./temp_character_folder/temp_{role_name}'
            db_folder = os.path.join(unzip_folder, f'content/{role_name}')
            system_prompt = os.path.join(unzip_folder, 'content/system_prompt.txt')

            if not os.path.exists(unzip_folder):
                # not yet downloaded
                # url = f'https://github.com/LC1332/Haruhi-2-Dev/raw/main/data/character_in_zip/{role_name}.zip'
                import requests, zipfile, io
                r = requests.get(url)
                z = zipfile.ZipFile(io.BytesIO(r.content))
                z.extractall(unzip_folder)

            if self.verbose:
                print(f'loading pre-defined character {role_name}...')

            self.db = ChromaDB()
            self.db.load(db_folder)
            self.system_prompt = self.check_system_prompt(system_prompt)
        elif role_from_hf:
            # TODO: move into a function
            from datasets import load_dataset

            if role_from_hf.count("/") == 1:
                dataset = load_dataset(role_from_hf)
                datas = dataset["train"]
            elif role_from_hf.count("/") >= 2:
                split_index = role_from_hf.index('/')
                second_split_index = role_from_hf.index('/', split_index + 1)
                dataset_name = role_from_hf[:second_split_index]
                split_name = role_from_hf[second_split_index + 1:]

                fname = split_name + '.jsonl'
                dataset = load_dataset(dataset_name, data_files={'train': fname})
                datas = dataset["train"]

            if embedding == 'luotuo_openai':
                embed_name = 'luotuo_openai'
            elif embedding == 'bge_en':
                embed_name = 'bge_en_s15'
            elif embedding == 'bge_zh':
                embed_name = 'bge_zh_s15'
            else:
                print('warning! unknown embedding name ', embedding, ' while loading role')
                embed_name = 'luotuo_openai'

            texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)

            self.build_story_db_from_vec(texts, vecs)
        elif role_from_jsonl:
            import json
            datas = []
            with open(role_from_jsonl, encoding="utf-8") as f:
                for line in f:
                    try:
                        # parse the JSON data line by line
                        data = json.loads(line)
                        datas.append(data)
                    except json.JSONDecodeError:
                        print("warning! failed to load json line ", line)

            if embedding == 'luotuo_openai':
                embed_name = 'luotuo_openai'
            elif embedding == 'bge_en':
                embed_name = 'bge_en_s15'
            elif embedding == 'bge_zh':
                embed_name = 'bge_zh_s15'
            else:
                print('warning! unknown embedding name ', embedding, ' while loading role')
                embed_name = 'luotuo_openai'

            texts, vecs, self.system_prompt = self.extract_text_vec_from_datas(datas, embed_name)

            self.build_story_db_from_vec(texts, vecs)
        elif story_db:
            self.db = ChromaDB()
            self.db.load(story_db)
        elif story_text_folder:
            # print("Building story database from texts...")
            self.db = self.build_story_db(story_text_folder)
        else:
            self.db = None
            print('warning! no story database configured: neither story_db nor story_text_folder was provided.')
            # raise ValueError("Either story_db or story_text_folder must be provided")
        self.max_len_story, self.max_len_history = self.get_tokenlen_setting('openai')

        # user settings override the default token budgets
        if max_len_history is not None:
            self.max_len_history = max_len_history
        if max_len_story is not None:
            self.max_len_story = max_len_story

        self.dialogue_history = []

    def extract_text_vec_from_datas(self, datas, embed_name):
        # extract texts and embedding vectors from a huggingface dataset
        # returns texts, vecs, system_prompt
        from .utils import base64_to_float_array

        texts = []
        vecs = []
        system_prompt = ""  # fallback so the return never hits an unbound local
        for data in datas:
            if data[embed_name] == 'system_prompt':
                system_prompt = get_text_from_data(data)
            elif data[embed_name] == 'config':
                pass
            else:
                vec = base64_to_float_array(data[embed_name])
                text = get_text_from_data(data)
                vecs.append(vec)
                texts.append(text)
        return texts, vecs, system_prompt
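
    # For reference, a sketch of the record shapes the method above expects,
    # inferred from the loop; the sample values are illustrative only:
    #   {"text": "...", "luotuo_openai": "<base64 floats>"}   ordinary story chunk
    #   {"text": "...", "luotuo_openai": "system_prompt"}     row carrying the system prompt
    #   {...,           "luotuo_openai": "config"}            config row, skipped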

    def check_system_prompt(self, system_prompt):
        # if system_prompt ends with .txt, read the file as utf-8
        # otherwise, return the string directly
        if system_prompt.endswith('.txt'):
            with open(system_prompt, 'r', encoding='utf-8') as f:
                return f.read()
        else:
            return system_prompt
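
    # Both call styles below are equivalent from the caller's point of view;
    # the file path is hypothetical:
    #   self.check_system_prompt("You are Haruhi Suzumiya ...")  # literal prompt text
    #   self.check_system_prompt("./prompts/haruhi.txt")         # prompt read from file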

    def get_models(self, model_name):
        # TODO: if the caller only needs the tokenizer, skip initializing the llm
        # returns the (llm, tokenizer) pair for model_name
        if model_name == 'openai':
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)
        elif model_name == 'debug':
            from .PrintLLM import PrintLLM
            return (PrintLLM(), tiktokenizer)
        elif model_name == 'spark':
            from .SparkGPT import SparkGPT
            return (SparkGPT(), tiktokenizer)
        elif model_name == 'GLMPro':
            from .GLMPro import GLMPro
            return (GLMPro(), tiktokenizer)
        elif model_name == 'ernie3.5':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(), tiktokenizer)
        elif model_name == 'ernie4.0':
            from .ErnieGPT import ErnieGPT
            return (ErnieGPT(model="ernie-bot-4"), tiktokenizer)
        elif model_name == "ChatGLM2GPT":
            from .ChatGLM2GPT import ChatGLM2GPT, GLM_tokenizer
            return (ChatGLM2GPT(), GLM_tokenizer)
        elif model_name == "BaiChuan2GPT":
            from .BaiChuan2GPT import BaiChuan2GPT, BaiChuan_tokenizer
            return (BaiChuan2GPT(), BaiChuan_tokenizer)
        elif model_name == "BaiChuanAPIGPT":
            from .BaiChuanAPIGPT import BaiChuanAPIGPT
            return (BaiChuanAPIGPT(), tiktokenizer)
        elif "qwen" in model_name:
            if model_name == "qwen118k_raw":
                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
                return (Qwen118k2GPT(model="Qwen/Qwen-1_8B-Chat"), Qwen_tokenizer)
            # look up fine-tuned qwen models published under the silk-road org
            from huggingface_hub import HfApi
            from huggingface_hub.hf_api import ModelFilter
            qwen_api = HfApi()
            qwen_models = qwen_api.list_models(
                filter=ModelFilter(model_name=model_name),
                author="silk-road"
            )
            qwen_models_id = [qwen_model.id for qwen_model in qwen_models]
            if "silk-road/" + model_name in qwen_models_id:
                from .Qwen118k2GPT import Qwen118k2GPT, Qwen_tokenizer
                return (Qwen118k2GPT(model="silk-road/" + model_name), Qwen_tokenizer)
            else:
                print(f'warning! undefined model {model_name}, use openai instead.')
                from .LangChainGPT import LangChainGPT
                return (LangChainGPT(), tiktokenizer)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            from .LangChainGPT import LangChainGPT
            return (LangChainGPT(), tiktokenizer)

    def get_tokenlen_setting(self, model_name):
        # return the story / history token length settings
        if model_name == 'openai':
            return (1500, 1200)
        else:
            print(f'warning! undefined model {model_name}, use openai instead.')
            return (1500, 1200)

    def build_story_db_from_vec(self, texts, vecs):
        self.db = ChromaDB()
        self.db.init_from_docs(vecs, texts)

    def build_story_db(self, text_folder):
        # read every text file in the folder and extract its embedding vector
        db = ChromaDB()

        strs = []
        # scan all txt files in text_folder
        for file in os.listdir(text_folder):
            # only consider file names ending with .txt
            if file.endswith(".txt"):
                file_path = os.path.join(text_folder, file)
                with open(file_path, 'r', encoding='utf-8') as f:
                    strs.append(f.read())

        if self.verbose:
            print(f'starting embedding extraction for {len(strs)} files')

        vecs = []
        # TODO: add a unit test for batched embedding,
        #       switch to embedding code that supports list (batch) input,
        #       and replace the for loop below with it.
        #       Luotuo-bert-en has also been released, so openai can be avoided.
        for mystr in strs:
            vecs.append(self.embedding(mystr))

        db.init_from_docs(vecs, strs)

        return db
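
    # Expected on-disk layout for build_story_db, assuming one classic scene
    # per .txt file (the folder and file names below are hypothetical):
    #   story_texts/
    #       scene_001.txt
    #       scene_002.txt
    # Files without a .txt extension are ignored.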

    def save_story_db(self, db_path):
        self.db.save(db_path)

    def generate_prompt(self, text, role):
        from langchain.schema import (
            AIMessage,
            HumanMessage,
            SystemMessage
        )
        messages = self.generate_messages(text, role)
        prompt = ""
        for msg in messages:
            # flatten system / user / assistant messages into one prompt string
            if isinstance(msg, (HumanMessage, AIMessage, SystemMessage)):
                prompt += msg.content + "\n"
        return prompt

    def generate_messages(self, text, role):
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)

        # add story
        query = self.get_query_string(text, role)
        self.add_story(query)
        self.last_query = query

        # add query
        self.llm.user_message(query)

        return self.llm.messages

    def append_response(self, response, last_query=None):
        if last_query is None:
            last_query_record = ""
            if hasattr(self, "last_query"):
                last_query_record = self.last_query
        else:
            last_query_record = last_query

        # record dialogue history
        self.dialogue_history.append((last_query_record, response))

    def chat(self, text, role):
        # add system prompt
        self.llm.initialize_message()
        self.llm.system_message(self.system_prompt)

        # add story
        query = self.get_query_string(text, role)
        self.add_story(query)

        # add history
        self.add_history()

        # add query
        self.llm.user_message(query)

        # get response
        response_raw = self.llm.get_response()

        response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)

        # record dialogue history
        self.dialogue_history.append((query, response))

        return response

    def get_query_string(self, text, role):
        if role in self.narrator:
            return role + ":" + text
        else:
            return f"{role}:{self.dialogue_bra_token}{text}{self.dialogue_ket_token}"

    def add_story(self, query):
        if self.db is None:
            return

        query_vec = self.embedding(query)

        stories = self.db.search(query_vec, self.k_search)

        story_string = self.story_prefix_prompt
        sum_story_token = self.tokenizer(story_string)

        # greedily pack retrieved stories until the story token budget is hit
        for story in stories:
            story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)
            if sum_story_token + story_token > self.max_len_story:
                break
            else:
                sum_story_token += story_token
                story_string += story + self.dialogue_divide_token

        self.llm.user_message(story_string)

    def add_history(self):
        if len(self.dialogue_history) == 0:
            return

        # walk the history backwards, counting how many turns fit in the budget
        sum_history_token = 0
        flag = 0
        for query, response in reversed(self.dialogue_history):
            current_count = 0
            if query is not None:
                current_count += self.tokenizer(query)
            if response is not None:
                current_count += self.tokenizer(response)
            sum_history_token += current_count
            if sum_history_token > self.max_len_history:
                break
            else:
                flag += 1

        if flag == 0:
            print('warning! no history added. the last dialogue is too long.')
            # return early: dialogue_history[-0:] would wrongly add the whole history
            return

        for (query, response) in self.dialogue_history[-flag:]:
            if query is not None:
                self.llm.user_message(query)
            if response is not None:
                self.llm.ai_message(response)
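

# ------------------------------------------------------------------
# Usage sketch, kept as a comment because this module uses relative
# imports and is meant to be imported rather than run directly. The
# package import path, the folder name, and the prompt text below are
# assumptions for illustration; the default luotuo_openai embedding
# needs the relevant API credentials (e.g. OPENAI_API_KEY) set in the
# environment:
#
#   from chatharuhi import ChatHaruhi
#
#   chatbot = ChatHaruhi(
#       system_prompt='You are Haruhi Suzumiya from the light novel.',
#       story_text_folder='./story_texts',   # folder of .txt scene files
#       llm='openai',
#   )
#   print(chatbot.chat('Do you believe in aliens?', 'Kyon'))
# ------------------------------------------------------------------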