Spaces:
Runtime error
Runtime error
File size: 4,638 Bytes
bb59984 40e829d bb59984 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
"""
This script defines a PromptTemplate class that assists in generating
conversation/prompt templates. The script facilitates formatting prompts
for inference and training by combining various context elements and user inputs.
"""
import dataclasses
from typing import Dict, List, Union
@dataclasses.dataclass
class PromptTemplate:
    """Describes how the pieces of an LLM prompt are assembled.

    Each ``*_template`` attribute is a ``str.format`` pattern. The two
    ``*_raw_template`` properties stitch the patterns together (joined by
    ``sep`` and terminated by ``eos``) into a single raw template string,
    and the ``format_*`` methods fill in its placeholders from a data sample.
    """

    # Registry key identifying this template
    name: str
    # Pattern for the system prompt
    system_template: str = "{system_message}"
    # Pattern for the system context
    context_template: str = "{user_context}\n{news_context}"
    # Pattern for the conversation history
    chat_history_template: str = "{chat_history}"
    # Pattern for the user question
    question_template: str = "{question}"
    # Pattern for the system answer
    answer_template: str = "{answer}"
    # Message substituted into system_template
    system_message: str = ""
    # Separator placed between sections
    sep: str = "\n"
    # End-of-sequence token appended to the raw templates
    eos: str = "</s>"

    @property
    def input_variables(self) -> List[str]:
        """Placeholder names that the format_* methods substitute."""
        return ["user_context", "news_context", "chat_history", "question", "answer"]

    @property
    def train_raw_template(self):
        """Raw training template: system, context, history, question, answer, eos."""
        sections = [
            self.system_template.format(system_message=self.system_message),
            self.context_template,
            self.chat_history_template,
            self.question_template,
            self.answer_template,
        ]
        return self.sep.join(sections) + self.eos

    @property
    def infer_raw_template(self):
        """Raw inference template: like training, but without the answer section."""
        sections = [
            self.system_template.format(system_message=self.system_message),
            self.context_template,
            self.chat_history_template,
            self.question_template,
        ]
        return self.sep.join(sections) + self.eos

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Fill the training template from ``sample``; returns prompt + payload."""
        filled = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": filled, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Fill the inference template from ``sample``; returns prompt + payload."""
        filled = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": filled, "payload": sample}
# Global registry mapping a template name to its PromptTemplate instance
templates: Dict[str, PromptTemplate] = {}


def register_llm_template(template: PromptTemplate):
    """Register ``template`` in the global registry, keyed by its ``name``.

    Re-registering an existing name silently replaces the previous entry.
    """
    templates[template.name] = template


def get_llm_template(name: str) -> PromptTemplate:
    """Return the template registered under ``name``.

    Raises:
        KeyError: if ``name`` is unknown; the message lists the names that
            are currently registered, to aid debugging misconfigured runs.
    """
    try:
        return templates[name]
    except KeyError:
        raise KeyError(
            f"Unknown LLM template {name!r}; available templates: {sorted(templates)}"
        ) from None
##### Register Templates #####
# - Mistral 7B Instruct v0.2 Template
register_llm_template(
    PromptTemplate(
        name="mistral",
        # "<s>" is prepended before the system message — presumably the BOS
        # token expected by the Mistral tokenizer; confirm against the model card
        system_template="<s>{system_message}",
        system_message="You are a helpful assistant, with financial expertise, and you do not answer questions which contain illegal or harmful information.",
        context_template="{user_context}\n{news_context}",
        # Chat history is injected as a running summary rather than verbatim turns
        chat_history_template="Summary: {chat_history}",
        # [INST] ... [/INST] is the Mistral-Instruct user-turn delimiter
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
        sep="\n",
        # NOTE(review): leading space before the EOS token looks deliberate
        # (token-boundary behavior) — verify against the tokenizer output
        eos=" </s>",
    )
)
# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
register_llm_template(
    PromptTemplate(
        name="falcon",
        # The >>SECTION<< markers below are special tokens defined in the
        # Falcon tokenizer spec linked above
        system_template=">>INTRODUCTION<< {system_message}",
        system_message="You are a helpful assistant, with financial expertise.",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
        sep="\n",
        # Falcon's end-of-text token
        eos="<|endoftext|>",
    )
)
|