|
from .utils import base64_to_float_array, base64_to_string |
|
|
|
def get_text_from_data( data ):
    """Return the text payload of a record.

    Prefers a plain 'text' field, falls back to decoding a base64
    'enc_text' field, and returns "" (with a warning) otherwise.
    """
    if "text" in data:
        return data['text']
    if "enc_text" in data:
        # stored base64-encoded; decode via project helper
        return base64_to_string( data['enc_text'] )
    print("warning! failed to get text from data ", data)
    return ""
|
|
|
def parse_rag(text):
    """Scan persona text for {{RAG...}} placeholders, one spec per line.

    Returns a list of dicts with keys:
      n         -- max number of stories to retrieve (-1 budget only caps)
      max_token -- per-query token cap (-1 means no per-query cap)
      query     -- "default" (use the user's message) or an explicit query
      lid       -- line index of the placeholder within the persona
    """
    specs = []
    for lid, line in enumerate(text.split("\n")):
        spec = None
        if "{{RAG对话}}" in line:
            spec = {"n": 1, "max_token": -1, "query": "default"}
        elif "{{RAG对话|" in line:
            # explicit query between the '|' and the closing braces
            spec = {"n": 1, "max_token": -1,
                    "query": line.split("|")[1].rstrip("}}")}
        elif "{{RAG多对话|" in line:
            fields = line.split("|")
            spec = {"n": int(fields[2].split("<=")[1].rstrip("}}")),
                    "max_token": int(fields[1].split("<=")[1]),
                    "query": "default"}
        if spec is not None:
            spec["lid"] = lid
            specs.append(spec)
    return specs
|
|
|
class ChatHaruhi:
    """Retrieval-augmented role-playing chat agent.

    The persona text may embed {{RAG对话}} / {{RAG多对话|...}} placeholder
    lines which are replaced at message-build time with classic scenes
    retrieved from a story database.
    """

    def __init__(self,
                 role_name = None,
                 user_name = None,
                 persona = None,
                 stories = None,
                 story_vecs = None,
                 role_from_hf = None,
                 role_from_jsonl = None,
                 llm = None,
                 llm_async = None,
                 user_name_in_message = "default",
                 verbose = None,
                 embed_name = None,
                 embedding = None,
                 db = None,
                 token_counter = "default",
                 max_input_token = 1800,
                 max_len_story_haruhi = 1000,
                 max_story_n_haruhi = 5
                 ):
        """Build a chat agent from one of several role sources.

        Role sources, tried in order: explicit persona + stories (+vecs);
        role_from_hf ("org/dataset[/role]"); role_from_jsonl (local file);
        persona + role_name only (no retrieval db); a preset "sugar"
        role_name. Raises ValueError when none of these is provided.
        """

        # verbose defaults to on when not specified
        self.verbose = True if verbose is None or verbose else False

        self.db = db
        self.embed_name = embed_name
        self.max_len_story_haruhi = max_len_story_haruhi
        self.max_story_n_haruhi = max_story_n_haruhi

        # pending user message; committed to history by append_message()
        self.last_query_msg = None

        if embedding is None:
            self.embedding = self.set_embedding_with_name( embed_name )
        else:
            # fix: a caller-supplied embedding function was previously ignored
            # (self.embedding was only assigned in the None branch)
            self.embedding = embedding

        if persona and role_name and stories and story_vecs and len(stories) == len(story_vecs):
            # fully precomputed role: persona + stories + vectors
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif persona and role_name and stories:
            # stories without vectors: embed them now
            story_vecs = self.extract_story_vecs(stories)
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.build_db(stories, story_vecs)
        elif role_from_hf:
            # load everything from a Hugging Face dataset
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_hf(role_from_hf)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif role_from_jsonl:
            # load everything from a local .jsonl export
            self.persona, new_role_name, self.stories, self.story_vecs = self.load_role_from_jsonl(role_from_jsonl)
            self.role_name = new_role_name if new_role_name else role_name
            self.user_name = user_name
            self.build_db(self.stories, self.story_vecs)
        elif persona and role_name:
            # persona-only role: no retrieval database
            self.persona, self.role_name, self.user_name = persona, role_name, user_name
            self.db = None
        elif role_name and self.check_sugar( role_name ):
            # preset role shipped with ChatHaruhi
            self.persona, self.role_name, self.stories, self.story_vecs = self.load_role_from_sugar( role_name )
            self.build_db(self.stories, self.story_vecs)
        else:
            raise ValueError("persona和role_name必须同时设置,或者role_name是ChatHaruhi的预设人物")

        self.llm, self.llm_async = llm, llm_async
        if not self.llm and self.verbose:
            print("warning, llm没有设置,仅get_message起作用,调用chat将回复idle message")

        self.user_name_in_message = user_name_in_message
        self.previous_user_pool = {user_name} if user_name else set()
        # "add" mode prefixes speaker names from the start
        self.current_user_name_in_message = user_name_in_message.lower() == "add"

        # fix: typo "idel" -> "idle" in the fallback reply
        self.idle_message = "idle message, you see this because self.llm has not been set."

        # fix: token_counter=None used to crash on .lower() before ever
        # reaching its None branch; test for None first
        if token_counter is None:
            # token counting disabled: everything counts as zero
            self.token_counter = lambda x: 0
        elif isinstance(token_counter, str) and token_counter.lower() == "default":
            from .utils import tiktoken_counter
            self.token_counter = tiktoken_counter
        else:
            # caller-provided counter: str -> int
            self.token_counter = token_counter
            if self.verbose:
                print("user set customized token_counter")

        self.max_input_token = max_input_token

        # committed conversation turns: {"speaker": ..., "content": ...}
        self.history = []
|
|
|
def check_sugar(self, role_name): |
|
from .sugar_map import sugar_role_names, enname2zhname |
|
return role_name in sugar_role_names |
|
|
|
def load_role_from_sugar(self, role_name): |
|
from .sugar_map import sugar_role_names, enname2zhname |
|
en_role_name = sugar_role_names[role_name] |
|
new_role_name = enname2zhname[en_role_name] |
|
role_from_hf = "silk-road/ChatHaruhi-RolePlaying/" + en_role_name |
|
persona, _, stories, story_vecs = self.load_role_from_hf(role_from_hf) |
|
|
|
return persona, new_role_name, stories, story_vecs |
|
|
|
def add_rag_prompt_after_persona( self ): |
|
rag_sentence = "{{RAG多对话|token<=" + str(self.max_len_story_haruhi) + "|n<=" + str(self.max_story_n_haruhi) + "}}" |
|
self.persona += "Classic scenes for the role are as follows:\n" + rag_sentence + "\n" |
|
|
|
def set_embedding_with_name(self, embed_name): |
|
if embed_name is None or embed_name == "bge_zh": |
|
from .embeddings import get_bge_zh_embedding |
|
self.embed_name = "bge_zh" |
|
return get_bge_zh_embedding |
|
elif embed_name == "foo": |
|
from .embeddings import foo_embedding |
|
return foo_embedding |
|
elif embed_name == "bce": |
|
from .embeddings import foo_bce |
|
return foo_bce |
|
elif embed_name == "openai" or embed_name == "luotuo_openai": |
|
from .embeddings import foo_openai |
|
return foo_openai |
|
|
|
def set_new_user(self, user): |
|
if len(self.previous_user_pool) > 0 and user not in self.previous_user_pool: |
|
if self.user_name_in_message.lower() == "default": |
|
if self.verbose: |
|
print(f'new user {user} included in conversation') |
|
self.current_user_name_in_message = True |
|
self.user_name = user |
|
self.previous_user_pool.add(user) |
|
|
|
def chat(self, user, text): |
|
self.set_new_user(user) |
|
message = self.get_message(user, text) |
|
if self.llm: |
|
response = self.llm(message) |
|
self.append_message(response) |
|
return response |
|
return None |
|
|
|
async def async_chat(self, user, text): |
|
self.set_new_user(user) |
|
message = self.get_message(user, text) |
|
if self.llm_async: |
|
response = await self.llm_async(message) |
|
self.append_message(response) |
|
return response |
|
|
|
def parse_rag_from_persona(self, persona, text = None): |
|
|
|
|
|
|
|
|
|
|
|
|
|
query_rags = parse_rag( persona ) |
|
|
|
if text is not None: |
|
for rag in query_rags: |
|
if rag['query'] == "default": |
|
rag['query'] = text |
|
|
|
return query_rags, self.token_counter(persona) |
|
|
|
def append_message( self, response , speaker = None ): |
|
if self.last_query_msg is not None: |
|
self.history.append(self.last_query_msg) |
|
self.last_query_msg = None |
|
|
|
if speaker is None: |
|
|
|
self.history.append({"speaker":"{{role}}","content":response}) |
|
|
|
else: |
|
self.history.append({"speaker":speaker,"content":response}) |
|
|
|
def check_recompute_stories_token(self): |
|
return len(self.db.metas) == len(self.db.stories) |
|
|
|
def recompute_stories_token(self): |
|
self.db.metas = [self.token_counter(story) for story in self.db.stories] |
|
|
|
def rag_retrieve( self, query, n, max_token, avoid_ids = [] ): |
|
|
|
query_vec = self.embedding(query) |
|
|
|
self.db.clean_flag() |
|
self.db.disable_story_with_ids( avoid_ids ) |
|
|
|
retrieved_ids = self.db.search( query_vec, n ) |
|
|
|
if self.check_recompute_stories_token(): |
|
self.recompute_stories_token() |
|
|
|
sum_token = 0 |
|
|
|
ans = [] |
|
|
|
for i in range(0, len(retrieved_ids)): |
|
if i == 0: |
|
sum_token += self.db.metas[retrieved_ids[i]] |
|
ans.append(retrieved_ids[i]) |
|
continue |
|
else: |
|
sum_token += self.db.metas[retrieved_ids[i]] |
|
if sum_token <= max_token: |
|
ans.append(retrieved_ids[i]) |
|
else: |
|
break |
|
|
|
return ans |
|
|
|
|
|
def rag_retrieve_all( self, query_rags, rest_limit ): |
|
|
|
retrieved_ids = [] |
|
rag_ids = [] |
|
|
|
for query_rag in query_rags: |
|
query = query_rag['query'] |
|
n = query_rag['n'] |
|
max_token = rest_limit |
|
if rest_limit > query_rag['max_token'] and query_rag['max_token'] > 0: |
|
max_token = query_rag['max_token'] |
|
|
|
rag_id = self.rag_retrieve( query, n, max_token, avoid_ids = retrieved_ids ) |
|
rag_ids.append( rag_id ) |
|
retrieved_ids += rag_id |
|
|
|
return rag_ids |
|
|
|
def append_history_under_limit(self, message, rest_limit): |
|
|
|
|
|
|
|
|
|
current_limit = rest_limit |
|
|
|
history_list = [] |
|
|
|
for item in reversed(self.history): |
|
current_token = self.token_counter(item['content']) |
|
current_limit -= current_token |
|
if current_limit < 0: |
|
break |
|
else: |
|
history_list.append(item) |
|
|
|
history_list = list(reversed(history_list)) |
|
|
|
|
|
|
|
for item in history_list: |
|
if item['speaker'] == "{{role}}": |
|
message.append({"role":"assistant","content":item['content']}) |
|
else: |
|
message.append({"role":"user","content":item['content']}) |
|
|
|
return message |
|
|
|
def get_message(self, user, text): |
|
query_token = self.token_counter(text) |
|
|
|
|
|
query_rags, persona_token = self.parse_rag_from_persona( self.persona, text ) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
rest_limit = self.max_input_token - persona_token - query_token |
|
|
|
if self.verbose: |
|
print(f"query_rags: {query_rags} rest_limit = { rest_limit }") |
|
|
|
rag_ids = self.rag_retrieve_all( query_rags, rest_limit ) |
|
|
|
|
|
augmented_persona = self.augment_persona( self.persona, rag_ids, query_rags ) |
|
|
|
system_prompt = self.package_system_prompt( self.role_name, augmented_persona ) |
|
|
|
token_for_system = self.token_counter( system_prompt ) |
|
|
|
rest_limit = self.max_input_token - token_for_system - query_token |
|
|
|
message = [{"role":"system","content":system_prompt}] |
|
|
|
message = self.append_history_under_limit( message, rest_limit ) |
|
|
|
|
|
|
|
message.append({"role":"user","content":text}) |
|
|
|
self.last_query_msg = {"speaker":user,"content":text} |
|
|
|
return message |
|
|
|
    def package_system_prompt(self, role_name, augmented_persona):
        """Wrap the augmented persona into the roleplay system prompt.

        The returned multi-line string (including its blank lines) is sent
        verbatim as the system message, with role_name substituted twice.
        """
        bot_name = role_name
        return f"""You are now in roleplay conversation mode. Pretend to be {bot_name} whose persona follows:

{augmented_persona}



You will stay in-character whenever possible, and generate responses as if you were {bot_name}"""
|
|
|
|
|
def augment_persona(self, persona, rag_ids, query_rags): |
|
lines = persona.split("\n") |
|
for rag_id, query_rag in zip(rag_ids, query_rags): |
|
lid = query_rag['lid'] |
|
new_text = "" |
|
for id in rag_id: |
|
new_text += "###\n" + self.db.stories[id].strip() + "\n" |
|
new_text = new_text.strip() |
|
lines[lid] = new_text |
|
return "\n".join(lines) |
|
|
|
def load_role_from_jsonl( self, role_from_jsonl ): |
|
import json |
|
datas = [] |
|
with open(role_from_jsonl, 'r') as f: |
|
for line in f: |
|
try: |
|
datas.append(json.loads(line)) |
|
except: |
|
continue |
|
|
|
column_name = "" |
|
|
|
from .embeddings import embedname2columnname |
|
|
|
if self.embed_name in embedname2columnname: |
|
column_name = embedname2columnname[self.embed_name] |
|
else: |
|
print('warning! unkown embedding name ', self.embed_name ,' while loading role') |
|
column_name = 'luotuo_openai' |
|
|
|
stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name) |
|
|
|
return persona, None, stories, story_vecs |
|
|
|
|
|
def load_role_from_hf(self, role_from_hf): |
|
|
|
|
|
|
|
from datasets import load_dataset |
|
|
|
if role_from_hf.count("/") == 1: |
|
dataset = load_dataset(role_from_hf) |
|
datas = dataset["train"] |
|
elif role_from_hf.count("/") >= 2: |
|
split_index = role_from_hf.index('/') |
|
second_split_index = role_from_hf.index('/', split_index+1) |
|
dataset_name = role_from_hf[:second_split_index] |
|
split_name = role_from_hf[second_split_index+1:] |
|
|
|
fname = split_name + '.jsonl' |
|
dataset = load_dataset(dataset_name,data_files={'train':fname}) |
|
datas = dataset["train"] |
|
|
|
column_name = "" |
|
|
|
from .embeddings import embedname2columnname |
|
|
|
if self.embed_name in embedname2columnname: |
|
column_name = embedname2columnname[self.embed_name] |
|
else: |
|
print('warning! unkown embedding name ', self.embed_name ,' while loading role') |
|
column_name = 'luotuo_openai' |
|
|
|
stories, story_vecs, persona = self.extract_text_vec_from_datas(datas, column_name) |
|
|
|
return persona, None, stories, story_vecs |
|
|
|
def extract_text_vec_from_datas(self, datas, column_name): |
|
|
|
|
|
|
|
|
|
|
|
texts = [] |
|
vecs = [] |
|
for data in datas: |
|
if data[column_name] == 'system_prompt': |
|
system_prompt = get_text_from_data( data ) |
|
elif data[column_name] == 'config': |
|
pass |
|
else: |
|
vec = base64_to_float_array( data[column_name] ) |
|
text = get_text_from_data( data ) |
|
vecs.append( vec ) |
|
texts.append( text ) |
|
return texts, vecs, system_prompt |
|
|
|
def extract_story_vecs(self, stories): |
|
|
|
|
|
if self.verbose: |
|
print(f"re-extract vector for {len(stories)} stories") |
|
|
|
story_vecs = [] |
|
|
|
from .embeddings import embedshortname2model_name |
|
from .embeddings import device |
|
|
|
if device.type != "cpu" and self.embed_name in embedshortname2model_name: |
|
|
|
model_name = embedshortname2model_name[self.embed_name] |
|
|
|
from .utils import get_general_embeddings_safe |
|
story_vecs = get_general_embeddings_safe( stories, model_name = model_name ) |
|
|
|
else: |
|
from tqdm import tqdm |
|
for story in tqdm(stories): |
|
story_vecs.append(self.embedding(story)) |
|
|
|
return story_vecs |
|
|
|
def build_db(self, stories, story_vecs): |
|
|
|
if self.db is None: |
|
from .NaiveDB import NaiveDB |
|
self.db = NaiveDB() |
|
self.db.build_db(stories, story_vecs) |
|
|
|
|