from .utils import base64_to_float_array, base64_to_string


def get_text_from_data(data):
    if "text" in data:
        return data['text']
    elif "enc_text" in data:
        return base64_to_string(data['enc_text'])
    else:
        print("warning! failed to get text from data ", data)
        return ""


def parse_rag(text):
    # scan the persona line by line and collect every RAG marker
    lines = text.split("\n")

    ans = []

    for i, line in enumerate(lines):
        if "{{RAG对话}}" in line:
            # plain marker: one story, no token limit, query defaults to the user text
            ans.append({"n": 1, "max_token": -1, "query": "default", "lid": i})
        elif "{{RAG对话|" in line:
            # marker with an explicit query, e.g. {{RAG对话|some query}}
            query_info = line.split("|")[1].rstrip("}")
            ans.append({"n": 1, "max_token": -1, "query": query_info, "lid": i})
        elif "{{RAG多对话|" in line:
            # multi-story marker, e.g. {{RAG多对话|token<=1000|n<=5}}
            parts = line.split("|")
            max_token = int(parts[1].split("<=")[1])
            max_n = int(parts[2].split("<=")[1].rstrip("}"))
            ans.append({"n": max_n, "max_token": max_token, "query": "default", "lid": i})

    return ans
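# A usage sketch for parse_rag (the persona text below is illustrative):
# given
#
#     I am Haruhi Suzumiya.
#     {{RAG对话}}
#     {{RAG多对话|token<=1000|n<=5}}
#
# parse_rag returns one dict per marker line:
#
#     [{"n": 1, "max_token": -1,   "query": "default", "lid": 1},
#      {"n": 5, "max_token": 1000, "query": "default", "lid": 2}]
#
# "default" queries are later replaced by the user's message in
# parse_rag_from_persona, and "lid" records which persona line each marker
# occupies so that augment_persona can swap the retrieved stories in.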
class ChatHaruhi:

    def __init__(self, role_name=None, user_name=None,
                 persona=None, stories=None, story_vecs=None,
                 role_from_hf=None, role_from_jsonl=None,
                 llm=None,        # default message2response function
                 llm_async=None,  # default async message2response function
                 user_name_in_message="default",
                 verbose=None,
                 embed_name=None,
                 embedding=None,
                 db=None,
                 token_counter="default",
                 max_input_token=1800,
                 max_len_story_haruhi=1000,
                 max_story_n_haruhi=5):

        # verbose defaults to True when not set
        self.verbose = True if verbose is None else bool(verbose)

        self.db = db
        self.embed_name = embed_name

        # these two settings only apply to the legacy Haruhi sugar roles
        self.max_len_story_haruhi = max_len_story_haruhi
        self.max_story_n_haruhi = max_story_n_haruhi

        self.last_query_msg = None

        if embedding is None:
            self.embedding = self.set_embedding_with_name(embed_name)
        else:
            self.embedding = embedding

        if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
            # everything supplied externally; stories and story_vecs must have the same length
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif persona and role_name and stories:
            # extract story_vecs from stories, re-embedding them with self.embedding
            story_vecs = self.extract_story_vecs(stories)
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif role_from_hf:
            # load the role from Hugging Face
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif role_from_jsonl:
            # load the role from a jsonl file
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif persona and role_name:
            # persona only, with no RAG at all
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.db = None
        elif role_name and self.check_sugar(role_name):
            # a preset (sugar) role
            self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar(role_name)
            self.build_db(self.stories, self.story_vecs)
            # per the discussion with 江YH, every loading path should call
            # add_rag_prompt_after_persona() externally, to avoid confusion
            # self.add_rag_prompt_after_persona()
        else:
            raise ValueError("persona and role_name must be set together, or role_name must be one of ChatHaruhi's preset characters")

        self.llm, self.llm_async = llm, llm_async
        if not self.llm and self.verbose:
            print("warning: llm is not set, only get_message will work; calling chat will return the idle message")

        self.user_name_in_message = user_name_in_message
        self.previous_user_pool = set([user_name]) if user_name else set()
        self.current_user_name_in_message = user_name_in_message.lower() == "add"

        self.idle_message = "idle message, you see this because self.llm has not been set."

        if token_counter is None:
            self.token_counter = lambda x: 0
        elif isinstance(token_counter, str) and token_counter.lower() == "default":
            # TODO: change to load from utils
            from .utils import tiktoken_counter
            self.token_counter = tiktoken_counter
        else:
            self.token_counter = token_counter
            if self.verbose:
                print("user set customized token_counter")

        self.max_input_token = max_input_token

        self.history = []

    def check_sugar(self, role_name):
        from .sugar_map import sugar_role_names
        return role_name in sugar_role_names

    def load_role_from_sugar(self, role_name):
        from .sugar_map import sugar_role_names, enname2zhname
        en_role_name = sugar_role_names[role_name]
        new_role_name = enname2zhname[en_role_name]
        role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name
        persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf)
        return persona, new_role_name, stories, story_vecs

    def add_rag_prompt_after_persona(self):
        rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}"
        self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n"

    def set_embedding_with_name(self, embed_name):
        if embed_name is None or embed_name == "bge_zh":
            from .embeddings import get_bge_zh_embedding
            self.embed_name = "bge_zh"
            return get_bge_zh_embedding
        elif embed_name == "foo":
            from .embeddings import foo_embedding
            return foo_embedding
        elif embed_name == "bce":
            from .embeddings import foo_bce
            return foo_bce
        elif embed_name == "openai" or embed_name == "luotuo_openai":
            from .embeddings import foo_openai
            return foo_openai
        else:
            raise ValueError(f"unknown embedding name: {embed_name}")

    def set_new_user(self, user):
        if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool:
            if self.user_name_in_message.lower() == "default":
                if self.verbose:
                    print(f'new user {user} included in conversation')
                self.current_user_name_in_message = True
        self.user_name = user
        self.previous_user_pool.add(user)

    def chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm:
            response = self.llm(message)
            self.append_message(response)
            return response
        return self.idle_message

    async def async_chat(self, user, text):
        self.set_new_user(user)
        message = self.get_message(user, text)
        if self.llm_async:
            response = await self.llm_async(message)
            self.append_message(response)
            return response
        return self.idle_message

    def parse_rag_from_persona(self, persona, text=None):
        # each query_rag contains:
        #   "n"         how many stories are needed
        #   "max_token" token budget; -1 means unlimited
        #   "query"     the query text; "default" is replaced by the current user text
        #   "lid"       the persona line to replace; the whole line is replaced,
        #               ignoring anything else on it
        query_rags = parse_rag(persona)

        if text is not None:
            for rag in query_rags:
                if rag['query'] == "default":
                    rag['query'] = text

        return query_rags, self.token_counter(persona)

    def append_message(self, response, speaker=None):
        if self.last_query_msg is not None:
            self.history.append(self.last_query_msg)
            self.last_query_msg = None

        if speaker is None:
            # if speaker is None, treat the sentence as spoken by the role itself, {{role}}
            # the key is named "speaker" to distinguish it from the LLM message "role" field
            self.history.append({"speaker": "{{role}}", "content": response})
        else:
            self.history.append({"speaker": speaker, "content": response})

    def check_recompute_stories_token(self):
        # token counts need recomputing whenever metas is out of sync with stories
        return len(self.db.metas) != len(self.db.stories)

    def recompute_stories_token(self):
        self.db.metas = [self.token_counter(story) for story in self.db.stories]

    def rag_retrieve(self, query, n, max_token, avoid_ids=None):
        # returns a list of story ids
        if avoid_ids is None:
            avoid_ids = []

        query_vec = self.embedding(query)

        self.db.clean_flag()
        self.db.disable_story_with_ids(avoid_ids)

        retrieved_ids = self.db.search(query_vec, n)

        if self.check_recompute_stories_token():
            self.recompute_stories_token()

        sum_token = 0
        ans = []

        for i, story_id in enumerate(retrieved_ids):
            sum_token += self.db.metas[story_id]
            # the first retrieved story is always kept, even if it alone exceeds the budget
            if i == 0 or sum_token <= max_token:
                ans.append(story_id)
            else:
                break

        return ans
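    # A worked example of the budget logic in rag_retrieve (numbers illustrative):
    # suppose search returns ids [7, 2, 9] whose token counts in db.metas are
    # 400, 500 and 300, and max_token is 1000. Then:
    #   id 7: sum_token = 400  -> kept (the first hit is always kept)
    #   id 2: sum_token = 900  -> kept (900 <= 1000)
    #   id 9: sum_token = 1200 -> over budget, loop breaks
    # so rag_retrieve returns [7, 2].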
    def rag_retrieve_all(self, query_rags, rest_limit):
        # returns one list of story ids per query_rag
        retrieved_ids = []
        rag_ids = []

        for query_rag in query_rags:
            query = query_rag['query']
            n = query_rag['n']

            max_token = rest_limit
            if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0:
                max_token = query_rag['max_token']

            rag_id = self.rag_retrieve(query, n, max_token, avoid_ids=retrieved_ids)
            rag_ids.append(rag_id)
            retrieved_ids += rag_id

        return rag_ids

    def append_history_under_limit(self, message, rest_limit):
        # returns the message list with as much history appended as fits
        # tokens are counted from the newest message backwards, staying within rest_limit
        # if the speaker is {{role}}, the message role is "assistant"
        current_limit = rest_limit
        history_list = []

        for item in reversed(self.history):
            current_token = self.token_counter(item['content'])
            current_limit -= current_token
            if current_limit < 0:
                break
            history_list.append(item)

        history_list = list(reversed(history_list))

        # TODO: for multi-user chat, content will additionally carry "speaker: content" information
        for item in history_list:
            if item['speaker'] == "{{role}}":
                message.append({"role": "assistant", "content": item['content']})
            else:
                message.append({"role": "user", "content": item['content']})

        return message

    def get_message(self, user, text):
        query_token = self.token_counter(text)

        # first work out which rag stories are needed; each query_rag carries
        # "n", "max_token", "query" and "lid" as documented in parse_rag_from_persona
        query_rags, persona_token = self.parse_rag_from_persona(self.persona, text)

        rest_limit = self.max_input_token - persona_token - query_token

        if self.verbose:
            print(f"query_rags: {query_rags} rest_limit = {rest_limit}")

        rag_ids = self.rag_retrieve_all(query_rags, rest_limit)

        # substitute the stories behind rag_ids into the persona
        augmented_persona = self.augment_persona(self.persona, rag_ids, query_rags)

        system_prompt = self.package_system_prompt(self.role_name, augmented_persona)

        token_for_system = self.token_counter(system_prompt)
        rest_limit = self.max_input_token - token_for_system - query_token

        message = [{"role": "system", "content": system_prompt}]
        message = self.append_history_under_limit(message, rest_limit)

        # TODO: for multi-user chat, content will additionally carry "speaker: content" information
        message.append({"role": "user", "content": text})

        self.last_query_msg = {"speaker": user, "content": text}

        return message

    def package_system_prompt(self, role_name, augmented_persona):
        bot_name = role_name
        return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:
{augmented_persona}

You will stay in-character whenever possible, and generate responses as if you were {bot_name}"""
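    # A sketch of what augment_persona (below) produces (story texts are illustrative):
    # if persona line 2 is "{{RAG多对话|token<=1000|n<=5}}" and rag_retrieve_all
    # picked stories 7 and 2, that line is replaced in place with
    #
    #     ###
    #     <text of story 7>
    #     ###
    #     <text of story 2>
    #
    # and the augmented persona is then passed to package_system_prompt.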
    def augment_persona(self, persona, rag_ids, query_rags):
        lines = persona.split("\n")
        for rag_id, query_rag in zip(rag_ids, query_rags):
            lid = query_rag['lid']
            new_text = ""
            for story_id in rag_id:
                new_text += "###\n" + self.db.stories[story_id].strip() + "\n"
            new_text = new_text.strip()
            lines[lid] = new_text
        return "\n".join(lines)

    def load_role_from_jsonl(self, role_from_jsonl):
        import json
        datas = []
        with open(role_from_jsonl, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    datas.append(json.loads(line))
                except json.JSONDecodeError:
                    continue

        from .embeddings import embedname2columnname
        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name, ' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs

    def load_role_from_hf(self, role_from_hf):
        # load the role from a Hugging Face dataset
        from datasets import load_dataset

        if role_from_hf.count("/") == 1:
            dataset = load_dataset(role_from_hf)
            datas = dataset["train"]
        elif role_from_hf.count("/") >= 2:
            # "owner/dataset/split" -> dataset "owner/dataset" with data file "split.jsonl"
            split_index = role_from_hf.index('/')
            second_split_index = role_from_hf.index('/', split_index + 1)
            dataset_name = role_from_hf[:second_split_index]
            split_name = role_from_hf[second_split_index + 1:]
            fname = split_name + '.jsonl'
            dataset = load_dataset(dataset_name, data_files={'train': fname})
            datas = dataset["train"]

        from .embeddings import embedname2columnname
        if self.embed_name in embedname2columnname:
            column_name = embedname2columnname[self.embed_name]
        else:
            print('warning! unknown embedding name ', self.embed_name, ' while loading role')
            column_name = 'luotuo_openai'

        stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name)

        return persona, None, stories, story_vecs

    def extract_text_vec_from_datas(self, datas, column_name):
        # extract texts and embedding vectors from a Hugging Face dataset
        # a row whose embedding column reads 'system_prompt' carries the persona,
        # a row reading 'config' is skipped, and every other row is a story with
        # a base64-encoded embedding vector
        texts = []
        vecs = []
        system_prompt = ""
        for data in datas:
            if data[column_name] == 'system_prompt':
                system_prompt = get_text_from_data(data)
            elif data[column_name] == 'config':
                pass
            else:
                vec = base64_to_float_array(data[column_name])
                text = get_text_from_data(data)
                vecs.append(vec)
                texts.append(text)
        return texts, vecs, system_prompt

    def extract_story_vecs(self, stories):
        # re-embed every story with the configured embedding
        if self.verbose:
            print(f"re-extract vector for {len(stories)} stories")

        story_vecs = []

        from .embeddings import embedshortname2model_name
        from .embeddings import device

        if device.type != "cpu" and self.embed_name in embedshortname2model_name:
            # e.g. model_name = "BAAI/bge-small-zh-v1.5"
            model_name = embedshortname2model_name[self.embed_name]
            from .utils import get_general_embeddings_safe
            # batch embedding on GPU is much faster
            story_vecs = get_general_embeddings_safe(stories, model_name=model_name)
        else:
            from tqdm import tqdm
            for story in tqdm(stories):
                story_vecs.append(self.embedding(story))

        return story_vecs

    def build_db(self, stories, story_vecs):
        # construct the story database
        if self.db is None:
            from .NaiveDB import NaiveDB
            self.db = NaiveDB()
        self.db.build_db(stories, story_vecs)
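

if __name__ == "__main__":
    # Minimal usage sketch. Because of the relative imports above, run this as a
    # module, e.g. `python -m chatharuhi.ChatHaruhi` (package name assumed here).
    # `echo_llm` is a hypothetical stand-in for a real message2response function.
    def echo_llm(messages):
        # an llm callable maps an OpenAI-style message list to a reply string
        return "(echo) " + messages[-1]["content"]

    bot = ChatHaruhi(
        role_name="凉宫春日",
        persona="I am Haruhi Suzumiya, leader of the SOS Brigade.",  # no RAG markers, so no db is needed
        llm=echo_llm,
        token_counter=None,  # skip tiktoken; every message then counts as 0 tokens
    )
    print(bot.chat(user="Kyon", text="Hello!"))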