"""
This script defines a PromptTemplate class used to build conversation/prompt
templates. It formats prompts for training and inference by combining the
system message, context (user and news), chat history, and the user's question.
"""


import dataclasses
from typing import Dict, List, Union


@dataclasses.dataclass
class PromptTemplate:
    """A class that manages prompt templates"""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The template for the system context
    context_template: str = "{user_context}\n{news_context}"
    # The template for the conversation history
    chat_history_template: str = "{chat_history}"
    # The template of the user question
    question_template: str = "{question}"
    # The template of the system answer
    answer_template: str = "{answer}"
    # The system message
    system_message: str = ""
    # Separator inserted between prompt sections
    sep: str = "\n"
    # End-of-sequence token appended to the prompt
    eos: str = "</s>"

    @property
    def input_variables(self) -> List[str]:
        """Returns a list of input variables for the prompt template"""

        return ["user_context", "news_context", "chat_history", "question", "answer"]

    @property
    def train_raw_template(self) -> str:
        """Returns the training prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"
        answer = f"{self.sep}{self.answer_template}"

        return f"{system}{context}{chat_history}{question}{answer}{self.eos}"

    @property
    def infer_raw_template(self) -> str:
        """Returns the inference prompt template format"""

        system = self.system_template.format(system_message=self.system_message)
        context = f"{self.sep}{self.context_template}"
        chat_history = f"{self.sep}{self.chat_history_template}"
        question = f"{self.sep}{self.question_template}"

        return f"{system}{context}{chat_history}{question}{self.eos}"

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a training sample"""

        prompt = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": prompt, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Formats the data sample to a testing sample"""

        prompt = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": prompt, "payload": sample}


# Global Templates registry
templates: Dict[str, PromptTemplate] = {}


def register_llm_template(template: PromptTemplate):
    """Register a new template to the global templates registry"""

    templates[template.name] = template


def get_llm_template(name: str) -> PromptTemplate:
    """Returns the template assigned to the given name"""

    return templates[name]


##### Register Templates #####
# - Mistral 7B Instruct v0.2 Template
register_llm_template(
    PromptTemplate(
        name="mistral",
        system_template="<s>{system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template="{user_context}\n{news_context}",
        chat_history_template="Summary: {chat_history}",
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
        sep="\n",
        eos=" </s>",
    )
)

# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
register_llm_template(
    PromptTemplate(
        name="falcon",
        system_template=">>INTRODUCTION<< {system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
        sep="\n",
        eos="<|endoftext|>",
    )
)
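

# Illustrative usage sketch (not part of the original module): look up a
# registered template by name and render an inference prompt. The sample
# values below are hypothetical placeholders for the fields listed in
# `input_variables`.
if __name__ == "__main__":
    template = get_llm_template("falcon")
    sample = {
        "user_context": "The investor holds a long position in AAPL.",
        "news_context": "Apple reported record quarterly revenue.",
        "chat_history": "",
        "question": "Should I rebalance my portfolio?",
    }
    print(template.format_infer(sample)["prompt"])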