🤭 Please refer to https://github.com/svjack/Genshin-Impact-Character-Chat for more information.

Install

pip install peft transformers bitsandbytes ipykernel rapidfuzz

Run with transformers

import json
from dataclasses import dataclass
from enum import Enum
from typing import List, Dict, Tuple, Literal

class Roles(Enum):
    system = "system"
    user = "user"
    assistant = "assistant"
    tool = "tool"

class MessagesFormatterType(Enum):
    """
    Enum representing different types of predefined messages formatters.
    """

    MISTRAL = 1

@dataclass
class PromptMarkers:
    start: str
    end: str

class MessagesFormatter:
    def __init__(
            self,
            pre_prompt: str,
            prompt_markers: Dict[Roles, PromptMarkers],
            include_sys_prompt_in_first_user_message: bool,
            default_stop_sequences: List[str],
            use_user_role_for_function_call_result: bool = True,
            strip_prompt: bool = True,
            bos_token: str = "<s>",
            eos_token: str = "</s>"
    ):
        self.pre_prompt = pre_prompt
        self.prompt_markers = prompt_markers
        self.include_sys_prompt_in_first_user_message = include_sys_prompt_in_first_user_message
        self.default_stop_sequences = default_stop_sequences
        self.use_user_role_for_function_call_result = use_user_role_for_function_call_result
        self.strip_prompt = strip_prompt
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.added_system_prompt = False

    def get_bos_token(self) -> str:
        return self.bos_token

    def format_conversation(
            self,
            messages: List[Dict[str, str]],
            response_role: Literal[Roles.user, Roles.assistant] | None = None,
    ) -> Tuple[str, Roles]:
        formatted_messages = self.pre_prompt
        last_role = Roles.assistant
        self.added_system_prompt = False
        for message in messages:
            role = Roles(message["role"])
            content = self._format_message_content(message["content"], role)

            if role == Roles.system:
                formatted_messages += self._format_system_message(content)
                last_role = Roles.system
            elif role == Roles.user:
                formatted_messages += self._format_user_message(content)
                last_role = Roles.user
            elif role == Roles.assistant:
                formatted_messages += self._format_assistant_message(content)
                last_role = Roles.assistant
            elif role == Roles.tool:
                formatted_messages += self._format_tool_message(content)
                last_role = Roles.tool

        return self._format_response(formatted_messages, last_role, response_role)

    def _format_message_content(self, content: str, role: Roles) -> str:
        if self.strip_prompt:
            return content.strip()
        return content

    def _format_system_message(self, content: str) -> str:
        formatted_message = self.prompt_markers[Roles.system].start + content + self.prompt_markers[Roles.system].end
        self.added_system_prompt = True
        if self.include_sys_prompt_in_first_user_message:
            formatted_message = self.prompt_markers[Roles.user].start + formatted_message
        return formatted_message

    def _format_user_message(self, content: str) -> str:
        if self.include_sys_prompt_in_first_user_message and self.added_system_prompt:
            self.added_system_prompt = False
            return content + self.prompt_markers[Roles.user].end
        return self.prompt_markers[Roles.user].start + content + self.prompt_markers[Roles.user].end

    def _format_assistant_message(self, content: str) -> str:
        return self.prompt_markers[Roles.assistant].start + content + self.prompt_markers[Roles.assistant].end

    def _format_tool_message(self, content: str) -> str:
        if isinstance(content, list):
            content = "\n".join(json.dumps(m, indent=2) for m in content)
        if self.use_user_role_for_function_call_result:
            return self._format_user_message(content)
        else:
            return self.prompt_markers[Roles.tool].start + content + self.prompt_markers[Roles.tool].end

    def _format_response(
            self,
            formatted_messages: str,
            last_role: Roles,
            response_role: Literal[Roles.user, Roles.assistant] | None = None,
    ) -> Tuple[str, Roles]:
        if response_role is None:
            response_role = Roles.assistant if last_role != Roles.assistant else Roles.user

        if self.strip_prompt:
            prompt_start = self.prompt_markers[response_role].start.strip()
        else:
            prompt_start = self.prompt_markers[response_role].start
        return formatted_messages + prompt_start, response_role

mixtral_prompt_markers = {
    Roles.system: PromptMarkers("", "\n\n"),
    Roles.user: PromptMarkers("[INST] ", " [/INST]"),
    Roles.assistant: PromptMarkers("", "</s>"),
    Roles.tool: PromptMarkers("", ""),
}

mixtral_formatter = MessagesFormatter(
    pre_prompt="",
    prompt_markers=mixtral_prompt_markers,
    include_sys_prompt_in_first_user_message=True,
    default_stop_sequences=["</s>"],
)
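
A quick sanity check with toy messages (hypothetical content) shows the [INST] ... [/INST] layout the formatter produces; note that the system prompt is folded into the first user turn:

prompt, role = mixtral_formatter.format_conversation([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
])
print(repr(prompt))
# '[INST] You are a helpful assistant.\n\nHello [/INST]'
print(role)
# Roles.assistant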

from transformers import TextStreamer, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# 4-bit loading requires bitsandbytes and a CUDA-capable GPU.
mis_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3", load_in_4bit=True
)
# Attach the LoRA adapter on top of the base model.
mis_model = PeftModel.from_pretrained(
    mis_model,
    "svjack/DPO_Genshin_Impact_Mistral_Plot_Engine_Step_Json_Short_lora_small",
)
mis_model = mis_model.eval()

streamer = TextStreamer(tokenizer)

def mistral_hf_predict(messages, mis_model=mis_model,
                       tokenizer=tokenizer, streamer=streamer,
                       do_sample=True,
                       top_p=0.95,
                       top_k=40,
                       max_new_tokens=512,
                       max_input_length=3500,
                       temperature=0.9,
                       repetition_penalty=1.0,
                       device="cuda"):
    # Build the prompt with the Mistral formatter defined above
    # (rather than tokenizer.apply_chat_template).
    prompt, _ = mixtral_formatter.format_conversation(messages)
    model_inputs = tokenizer.encode(prompt, return_tensors="pt")
    # Keep only the most recent max_input_length tokens.
    model_inputs = model_inputs[:, -max_input_length:].to(device)

    generated_ids = mis_model.generate(
        model_inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        streamer=streamer,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
    )
    # Everything after the last [/INST] is the model's reply.
    out = tokenizer.batch_decode(generated_ids)[0].split("[/INST]")[-1].replace("</s>", "").strip()
    return out
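
A quick smoke test of the helper (hypothetical one-off message; the stepwise plot loop below is the intended usage):

print(mistral_hf_predict([
    {"role": "system", "content": ""},
    {"role": "user", "content": "你好"},  # "Hello"
]))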

from rapidfuzz import fuzz
from IPython.display import clear_output

def run_step_infer_times(x, times=5, temperature=0.01,
                         repetition_penalty=1.0,
                         sim_val=70):
    req = []
    for _ in range(times):
        clear_output(wait=True)
        out = mistral_hf_predict([
                {
                    "role": "system",
                    "content": ""
                },
                {
                    "role": "user",
                    "content": x
                },
            ],
            repetition_penalty=repetition_penalty,
            temperature=temperature,
            max_new_tokens=2070,
            max_input_length=6000,
        )
        # Keep the step only if it is not a near-duplicate of one already
        # collected (fuzz.ratio returns a 0-100 similarity score).
        if not req or max(fuzz.ratio(prev, out) for prev in req) < sim_val:
            req.append(out.strip())
        # Feed the output back into the prompt so the next call
        # continues the plot from this step.
        x = x.strip() + "\n" + out.strip()
    return req
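
run_step_infer_times drives the model one plot step at a time: each output is appended to the prompt so the next call continues the story, while fuzz.ratio filters out near-duplicate steps scoring at or above sim_val. The prompt below is a Chinese story outline (title "Homecoming": at the gate of Sumeru City, Paimon and Nahida help an unconscious elemental creature search for its home; participants: Paimon, Nahida, a Floating Hydro Fungus, and the Traveler).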

out_l = run_step_infer_times(
'''
故事标题:归乡
故事背景:在须弥城门口,派蒙与纳西妲偶遇并帮助一只昏迷的元素生命找寻家园。过程中揭示了这只生物并非普通的蕈兽,而是元素生物,并且它们曾受到过‘末日’的影响,家园被侵蚀。纳西妲回忆起晶体里的力量可能与一个预言有关,为了拯救它们的家园,她必须解决‘禁忌知识’问题,但这个过程对她自身也会产生干扰。
参与角色:派蒙、纳西妲、浮游水蕈兽、旅行者
''',
    temperature=0.1,
    repetition_penalty=1.0,
    times=10
)
clear_output(wait = True)

print("\n".join(out_l))

Output

{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙在须弥城门口发现纳西妲,并询问她是否来寻找‘元素生物’。纳西妲确认后,两人决定一起寻找。过程中,纳西妲解释了这只生物的特殊性以及它们的家园被‘末日’影响。'}
{'参与者1': '派蒙', '参与者2': '浮游水蕈兽', '当前故事背景': '派蒙与纳西妲找到昏迷的元素生物,并通过沟通发现它们的特殊性。纳西妲提出要帮助它们找回家园,而这涉及到‘禁忌知识’和晶体的力量。'}
{'参与者1': '纳西妲', '参与者2': '旅行者', '当前故事背景': '纳西妲向旅行者解释了‘元素生物’的特殊性以及它们的家园被‘末日’影响,并提出要解决‘禁忌知识’问题以拯救它们。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙对纳西妲回忆中晶体力量的含义感到疑惑,纳西妲解释这可能与一个预言有关,但她自己对此有所犹豫。'}
{'参与者1': '纳西妲', '参与者2': '浮游水蕈兽', '当前故事背景': '纳西妲试图通过沟通与晶体交流,但发现自己的意识被污染,这让她对‘禁忌知识’有所怀疑。'}
{'参与者1': '派蒙', '参与者2': '纳西妲', '当前故事背景': '派蒙和纳西妲讨论了‘禁忌知识’可能对纳西妲的影响,以及她的决定要面对并解决这个问题。'}
{'参与者1': '纳西妲', '参与者2': '浮游水蕈兽', '当前故事背景': '纳西妲提出要带回晶体,并希望旅行者能帮助她解决‘禁忌知识’问题,以拯救元素生物的家园。'}
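
Each element of out_l is a Python dict literal with Chinese keys: 参与者1 / 参与者2 name the two participants and 当前故事背景 holds the current story background. A minimal sketch for consuming the steps, assuming each element contains exactly one such dict:

import ast

steps = [ast.literal_eval(s) for s in out_l]
for step in steps:
    # participants of this step, then its story background
    print(step["参与者1"], "<->", step["参与者2"])
    print(step["当前故事背景"])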

train_2024-05-30-19-38-44

This model is a fine-tuned version of mistralai/Mistral-7B-Instruct-v0.3 on the dpo_genshin_impact_plot_engine_step_short_json dataset.

Model description

More information needed

Intended uses & limitations

More information needed

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-05
  • train_batch_size: 2
  • eval_batch_size: 8
  • seed: 42
  • gradient_accumulation_steps: 8
  • total_train_batch_size: 16
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • num_epochs: 3.0
  • mixed_precision_training: Native AMP
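
For reference, a minimal sketch of how these settings map onto transformers TrainingArguments, assuming the standard HF Trainer stack (illustrative only; output_dir reuses the run name above and is not the actual launcher config):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="train_2024-05-30-19-38-44",
    learning_rate=5e-05,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=8,
    seed=42,
    gradient_accumulation_steps=8,  # effective train batch: 2 x 8 = 16
    lr_scheduler_type="cosine",
    num_train_epochs=3.0,
    fp16=True,  # "Native AMP" mixed precision
)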

Training results

Framework versions

  • PEFT 0.11.1
  • Transformers 4.41.1
  • Pytorch 2.3.0+cu121
  • Datasets 2.19.1
  • Tokenizers 0.19.1