# coding=utf-8
import logging
import os
import random
import shutil
import warnings
from pathlib import Path
from threading import Thread
from typing import Dict, List, Tuple, Union

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextIteratorStreamer,
)

from src.logger import LoggerFactory
from src.prompt_concat import CreateTestDataset, GetManualTestSamples
from src.utils import decode_csv_to_json, load_json, save_to_json

logger = LoggerFactory.create_logger(name="test", level=logging.INFO)
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

MODEL_PATH = os.environ.get('MODEL_PATH', 'IndexTeam/Index-1.9B-Character')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float16, device_map="auto",
                                             trust_remote_code=True)
character_path = "./character"


def _resolve_path(path: Union[str, Path]) -> Path:
    return Path(path).expanduser().resolve()


# Alternative local setup (kept for reference): load the model path from config/config.json
# instead of the MODEL_PATH / TOKENIZER_PATH environment variables.
# config_data = load_json("config/config.json")
# model_path = config_data["huggingface_local_path"]
# tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto",
#                                              trust_remote_code=True)


def generate_with_question(question, role_name, role_file_path):
    """Retrieve role-specific RAG samples for the current dialog history."""
    question_in = "\n".join(["\n".join(pair) for pair in question])
    g = GetManualTestSamples(
        role_name=role_name,
        role_data_path=f"./character/{role_file_path}.json",
        save_samples_dir="./character",
        save_samples_path=role_file_path + "_rag.json",
        prompt_path="./prompt/dataset_character.txt",
        max_seq_len=4000,
    )
    g.get_qa_samples_by_query(
        questions_query=question_in,
        keep_retrieve_results_flag=True,
    )


def create_datasets(role_name, role_file_path):
    """Build the final prompt dataset for the role from the retrieved RAG samples."""
    testset = []
    role_samples_path = os.path.join("./character", role_file_path + "_rag.json")
    c = CreateTestDataset(
        role_name=role_name,
        role_samples_path=role_samples_path,
        role_data_path=role_samples_path,
        prompt_path="./prompt/dataset_character.txt",
    )
    res = c.load_samples()
    testset.extend(res)
    save_to_json(testset, f"./character/{role_file_path}_测试问题.json")


@spaces.GPU
def hf_gen(dialog: List, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Stream the model's answer for the given dialog and sampling parameters."""
    generate_with_question(dialog, role_name, role_file_path)
    create_datasets(role_name, role_file_path)

    json_data = load_json(f"{character_path}/{role_file_path}_测试问题.json")[0]
    text = json_data["input_text"]
    inputs = tokenizer(text, return_tensors="pt")
    if torch.cuda.is_available():
        model.to("cuda")
        inputs = inputs.to("cuda")
    streamer = TextIteratorStreamer(tokenizer, **tokenizer.init_kwargs)
    generation_kwargs = dict(
        inputs,
        do_sample=True,
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        repetition_penalty=float(repetition_penalty),
        max_new_tokens=int(max_dec_len),
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Run generation in a background thread so the streamer can be consumed here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    answer = ""
    for new_text in streamer:
        answer += new_text
        # The streamer echoes the prompt, so strip it before yielding.
        yield answer[len(text):]


@spaces.GPU
def generate(chat_history: List, query, role_name, role_desc, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Generate an answer after the "Submit" button is hit.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. List that stores all QA records.
        query (str): Query of the current round.
        role_name (str): Name of the role to chat with.
        role_desc (str): Description of the role.
        role_file_path (str): Base name of the role's data files under ./character.
        top_k (int): Number of highest-probability tokens kept for sampling.
        top_p (float): Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): Strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): Penalty applied to already-generated tokens to discourage repetition.
        max_dec_len (int): Maximum number of tokens to generate.

    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n], [q_n+1, a_n+1]]. chat_history plus the QA of the current round.
    """
    assert query != "", "Input must not be empty!!!"
    # apply chat template
    chat_history.append([f"user:{query}", ""])
    # The built-in role 三三 ships with the demo and uses a fixed data file.
    if role_name == "三三":
        role_file_path = "三三"
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history


@spaces.GPU
def regenerate(chat_history: List, role_name, role_description, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
    """Re-generate the answer for the last round's query.

    Args:
        chat_history (List): [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. List that stores all QA records.
        role_name (str): Name of the role to chat with.
        role_description (str): Description of the role.
        role_file_path (str): Base name of the role's data files under ./character.
        top_k (int): Number of highest-probability tokens kept for sampling.
        top_p (float): Only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
        temperature (float): Strictly positive float value used to modulate the logits distribution.
        repetition_penalty (float): Penalty applied to already-generated tokens to discourage repetition.
        max_dec_len (int): Maximum number of tokens to generate.

    Yields:
        List: [[q_1, a_1], [q_2, a_2], ..., [q_n, a_n]]. chat_history with the last answer regenerated.
    """
    assert len(chat_history) >= 1, "History is empty. Nothing to regenerate!!"
    if len(chat_history[-1]) > 1:
        chat_history[-1][1] = ""
    # apply chat template
    if role_name == "三三":
        role_file_path = "三三"
    for answer in hf_gen(chat_history, role_name, role_file_path, top_k, top_p, temperature, repetition_penalty, max_dec_len):
        chat_history[-1][1] = role_name + ":" + answer
        yield gr.update(value=""), chat_history


def clear_history():
    """Clear all chat history.

    Returns:
        List: empty chat history
    """
    torch.cuda.empty_cache()
    return []


def delete_current_user(user_role_path):
    """Delete all files generated for a user-created role."""
    try:
        role_upload_path = os.path.join(character_path, user_role_path + ".csv")
        role_path = os.path.join(character_path, user_role_path + ".json")
        rag_path = os.path.join(character_path, user_role_path + "_rag.json")
        question_path = os.path.join(character_path, user_role_path + "_测试问题.json")
        files_to_delete = [role_upload_path, role_path, rag_path, question_path]
        for file_path in files_to_delete:
            os.remove(file_path)
    except Exception as e:
        print(e)


# launch gradio demo
with gr.Blocks(theme="soft") as demo:
    gr.Markdown("""# Index-1.9B RolePlay Gradio Demo""")

    with gr.Row():
        with gr.Column(scale=1):
            top_k = gr.Slider(0, 10, value=5, step=1, label="top_k")
            top_p = gr.Slider(0, 1, value=0.8, step=0.1, label="top_p")
            temperature = gr.Slider(0.1, 2.0, value=0.85, step=0.1, label="temperature")
            repetition_penalty = gr.Slider(0.1, 2.0, value=1.0, step=0.1, label="repetition_penalty")
            max_dec_len = gr.Slider(1, 4096, value=512, step=1, label="max_new_tokens")
            file_input = gr.File(label="Upload the role's dialog corpus (.csv)")
            role_description = gr.Textbox(label="Description of the role you created", placeholder="Enter a role description", lines=2)
            upload_button = gr.Button("Create role!")
            new_path = gr.State()

            def generate_file(file_obj, role_info):
                """Copy the uploaded corpus under a randomized name and convert it to the role JSON."""
                random.seed()
                # Random suffix to avoid filename collisions between users.
                alphabet = 'abcdefghijklmnopqrstuvwxyz!@#$%^&*()'
                random_char = "".join(random.choice(alphabet) for _ in range(10))
                role_name = os.path.basename(file_obj).split(".")[0]
                new_path = role_name + random_char
                new_save_path = os.path.join(character_path, new_path + ".csv")
                shutil.copy(file_obj, new_save_path)
                new_file_path = os.path.join(character_path, new_path)
                decode_csv_to_json(
                    os.path.join(character_path, new_path + ".csv"),
                    role_name,
                    role_info,
                    new_file_path + ".json",
                )
                gr.Info(f"{role_name} created successfully")
                return new_path

            upload_button.click(generate_file, inputs=[file_input, role_description], outputs=new_path)

        with gr.Column(scale=10):
            chatbot = gr.Chatbot(bubble_full_width=False, height=400, label='Index-1.9B RolePlay')
            with gr.Row():
                role_name = gr.Textbox(label="Name of the role to chat with", value="三三",
                                       placeholder="If you have not created a role, you can simply enter 三三. If you have already created the corresponding role, enter that role's name here!",
                                       lines=2)
                user_input = gr.Textbox(label="User question", placeholder="Enter your question!", lines=2)
            with gr.Row():
                submit = gr.Button("🚀 Submit")
                clear = gr.Button("🧹 Clear")
                regen = gr.Button("🔄 Regenerate")

    submit.click(generate,
                 inputs=[chatbot, user_input, role_name, role_description, new_path, top_k, top_p,
                         temperature, repetition_penalty, max_dec_len],
                 outputs=[user_input, chatbot])
    regen.click(regenerate,
                inputs=[chatbot, role_name, role_description, new_path, top_k, top_p,
                        temperature, repetition_penalty, max_dec_len],
                outputs=[user_input, chatbot])
    clear.click(clear_history, inputs=[], outputs=[chatbot])

demo.queue().launch()