🤭 Please refer to https://github.com/svjack/Genshin-Impact-Character-Chat for more information
Install
pip install peft transformers bitsandbytes ipykernel rapidfuzz
Run with transformers
import json
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Tuple, Literal
class Roles(Enum):
system = "system"
user = "user"
assistant = "assistant"
tool = "tool"
class MessagesFormatterType(Enum):
"""
Enum representing different types of predefined messages formatters.
"""
MISTRAL = 1
@dataclass
class PromptMarkers:
start: str
end: str
class MessagesFormatter:
def __init__(
self,
pre_prompt: str,
prompt_markers: Dict[Roles, PromptMarkers],
include_sys_prompt_in_first_user_message: bool,
default_stop_sequences: List[str],
use_user_role_for_function_call_result: bool = True,
strip_prompt: bool = True,
bos_token: str = "<s>",
eos_token: str = "</s>"
):
self.pre_prompt = pre_prompt
self.prompt_markers = prompt_markers
self.include_sys_prompt_in_first_user_message = include_sys_prompt_in_first_user_message
self.default_stop_sequences = default_stop_sequences
self.use_user_role_for_function_call_result = use_user_role_for_function_call_result
self.strip_prompt = strip_prompt
self.bos_token = bos_token
self.eos_token = eos_token
self.added_system_prompt = False
def get_bos_token(self) -> str:
return self.bos_token
def format_conversation(
self,
messages: List[Dict[str, str]],
response_role: Literal[Roles.user, Roles.assistant] | None = None,
) -> Tuple[str, Roles]:
formatted_messages = self.pre_prompt
last_role = Roles.assistant
self.added_system_prompt = False
for message in messages:
role = Roles(message["role"])
content = self._format_message_content(message["content"], role)
if role == Roles.system:
formatted_messages += self._format_system_message(content)
last_role = Roles.system
elif role == Roles.user:
formatted_messages += self._format_user_message(content)
last_role = Roles.user
elif role == Roles.assistant:
formatted_messages += self._format_assistant_message(content)
last_role = Roles.assistant
elif role == Roles.tool:
formatted_messages += self._format_tool_message(content)
last_role = Roles.tool
return self._format_response(formatted_messages, last_role, response_role)
def _format_message_content(self, content: str, role: Roles) -> str:
if self.strip_prompt:
return content.strip()
return content
def _format_system_message(self, content: str) -> str:
formatted_message = self.prompt_markers[Roles.system].start + content + self.prompt_markers[Roles.system].end
self.added_system_prompt = True
if self.include_sys_prompt_in_first_user_message:
formatted_message = self.prompt_markers[Roles.user].start + formatted_message
return formatted_message
def _format_user_message(self, content: str) -> str:
if self.include_sys_prompt_in_first_user_message and self.added_system_prompt:
self.added_system_prompt = False
return content + self.prompt_markers[Roles.user].end
return self.prompt_markers[Roles.user].start + content + self.prompt_markers[Roles.user].end
def _format_assistant_message(self, content: str) -> str:
return self.prompt_markers[Roles.assistant].start + content + self.prompt_markers[Roles.assistant].end
def _format_tool_message(self, content: str) -> str:
if isinstance(content, list):
content = "\n".join(json.dumps(m, indent=2) for m in content)
if self.use_user_role_for_function_call_result:
return self._format_user_message(content)
else:
return self.prompt_markers[Roles.tool].start + content + self.prompt_markers[Roles.tool].end
def _format_response(
self,
formatted_messages: str,
last_role: Roles,
response_role: Literal[Roles.user, Roles.assistant] | None = None,
) -> Tuple[str, Roles]:
if response_role is None:
response_role = Roles.assistant if last_role != Roles.assistant else Roles.user
prompt_start = self.prompt_markers[response_role].start.strip() if self.strip_prompt else self.prompt_markers[
response_role].start
return formatted_messages + prompt_start, response_role
mixtral_prompt_markers = {
Roles.system: PromptMarkers("", """\n\n"""),
Roles.user: PromptMarkers("""[INST] """, """ [/INST]"""),
Roles.assistant: PromptMarkers("""""", """</s>"""),
Roles.tool: PromptMarkers("", ""),
}
mixtral_formatter = MessagesFormatter(
"",
mixtral_prompt_markers,
True,
["</s>"],
)
from transformers import TextStreamer, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3",)
mis_model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", load_in_4bit = True)
mis_model = PeftModel.from_pretrained(mis_model,
"svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_lora_small"
)
mis_model = mis_model.eval()
streamer = TextStreamer(tokenizer)
def mistral_hf_predict(messages, mis_model = mis_model,
tokenizer = tokenizer, streamer = streamer,
do_sample = True,
top_p = 0.95,
top_k = 40,
max_new_tokens = 512,
max_input_length = 3500,
temperature = 0.9,
repetition_penalty = 1.0,
device = "cuda"):
#encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
#model_inputs = encodeds.to(device)
prompt, _ = mixtral_formatter.format_conversation(messages)
model_inputs = tokenizer.encode(prompt, return_tensors="pt").to(device)
generated_ids = mis_model.generate(model_inputs, max_new_tokens=max_new_tokens,
do_sample=do_sample,
streamer = streamer,
top_p = top_p,
top_k = top_k,
temperature = temperature,
repetition_penalty = repetition_penalty,
)
out = tokenizer.batch_decode(generated_ids)[0].split("[/INST]")[-1].replace("</s>", "").strip()
return out
from rapidfuzz import fuzz
from IPython.display import clear_output
def run_step_infer_times(x, times = 5, temperature = 0.01,
repetition_penalty = 1.0,
sim_val = 70
):
req = []
for _ in range(times):
clear_output(wait = True)
out = mistral_hf_predict([
{
"role": "system",
"content": ""
},
{
"role": "user",
"content": x
},
],
repetition_penalty = repetition_penalty,
temperature = temperature,
max_new_tokens = 2070,
max_input_length = 6000,
)
if req:
val = max(map(lambda x: fuzz.ratio(x, out), req))
#print(val)
#print(req)
if val < sim_val:
req.append(out.strip())
x = x.strip() + "\n" + out.strip()
else:
req.append(out.strip())
x = x.strip() + "\n" + out.strip()
return req
out_l = run_step_infer_times(
'''
故事标题:归乡
故事背景:在须弥城门口,派蒙与纳西妲偶遇并帮助一只昏迷的元素生命找寻家园。过程中揭示了这只生物并非普通的蕈兽,而是元素生物,并且它们曾受到过‘末日’的影响,家园被侵蚀。纳西妲回忆起晶体里的力量可能与一个预言有关,为了拯救它们的家园,她必须解决‘禁忌知识’问题,但这个过程对她自身也会产生干扰。
参与角色:派蒙、纳西妲、浮游水蕈兽、旅行者
''',
temperature=0.1,
repetition_penalty = 1.0,
times = 10
)
clear_output(wait = True)
print("\n".join(out_l))
Output
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙在须弥城门口发现纳西妲,并询问她是否来寻找‘元素生物’。纳西妲确认后,两人决定一起寻找。过程中,纳西妲解释了这只生物的特殊性以及它们的家园被‘末日’影响。'}
{'参与者1': '派蒙', '参与者2': '浮游水蕈兽', '当前故事背景': '派蒙与纳西妲找到昏迷的元素生物,并通过沟通发现它们的特殊性。纳西妲提出要帮助它们找回家园,而这涉及到‘禁忌知识’和晶体的力量。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲向旅行者解释了‘元素生物’的特殊性以及它们的家园被‘末日’影响,并提出要解决‘禁忌知识’问题以拯救它们。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙对纳西妲回忆中晶体力量的含义感到疑惑,纳西妲解释这可能与一个预言有关,但她自己对此有所犹豫。'}
{'参与者1': '纳西妲', '参与者2': '浮游水蕈兽', '当前故事背景': '纳西妲试图通过沟通与晶体交流,但发现自己的意识被污染,这让她对‘禁忌知识’有所怀疑。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙和纳西妲讨论了‘禁忌知识’可能对纳西妲的影响,以及她的决定要面对并解决这个问题。'}
{'参与者1': '纳西妲', '参与者2': '浮游水蕈兽', '当前故事背景': '纳西妲提出要带回晶体,并希望旅行者能帮助她解决‘禁忌知识’问题,以拯救元素生物的家园。'}
train_2024-05-30-19-38-44
This model is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.3 on the dpo_genshin_impact_plot_engine_step_short_json dataset.
Model description
More information needed
Intended uses & limitations
More information needed
Training and evaluation data
More information needed
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 5e-05
- train_batch_size: 2
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 8
- total_train_batch_size: 16
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: cosine
- num_epochs: 3.0
- mixed_precision_training: Native AMP
Training results
Framework versions
- PEFT 0.11.1
- Transformers 4.41.1
- Pytorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1
- Downloads last month
- 1
Model tree for svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_lora_small
Base model
mistralai/Mistral-7B-v0.3
Finetuned
mistralai/Mistral-7B-Instruct-v0.3