# Hugging Face Spaces page status: Runtime error
""" | |
This module defines a PromptTemplate class and a small registry of named
templates. A template formats prompts for training and inference by combining
a system message, context, chat history, the user question and (for training)
the answer.
""" | |
import dataclasses | |
from typing import Dict, List, Union | |
@dataclasses.dataclass
class PromptTemplate:
    """A named prompt template used to format training and inference prompts.

    A prompt is assembled from these parts, in order: the system prompt
    (``system_template`` filled with ``system_message``), the context, the chat
    history, the question and - for training prompts only - the answer. Every
    part after the system prompt is preceded by ``sep``, and ``eos`` is appended
    at the very end.

    The ``*_template`` fields are ``str.format`` templates whose placeholders
    are listed by :meth:`input_variables`.
    """

    # The name of this template
    name: str
    # The template of the system prompt; may use the {system_message} placeholder
    system_template: str = "{system_message}"
    # The template for the system context
    context_template: str = "{user_context}\n{news_context}"
    # The template for the conversation history
    chat_history_template: str = "{chat_history}"
    # The template of the user question
    question_template: str = "{question}"
    # The template of the system answer (only used in training prompts)
    answer_template: str = "{answer}"
    # The system message substituted into system_template
    system_message: str = ""
    # Separator inserted before every part that follows the system prompt
    sep: str = "\n"
    # End-of-sequence marker appended to both raw templates
    eos: str = "</s>"

    def input_variables(self) -> List[str]:
        """Return the placeholder names the raw templates expect to be filled."""
        return ["user_context", "news_context", "chat_history", "question", "answer"]

    def _raw_prefix(self) -> str:
        """Return the part shared by the training and inference templates.

        This is the system prompt followed by the context, chat history and
        question templates, each preceded by ``sep``.
        """
        # The system message is substituted eagerly; the remaining placeholders
        # are left for format_train / format_infer to fill in.
        # NOTE(review): a literal "{" or "}" inside system_message would later be
        # interpreted as a placeholder by str.format; escape it if that can occur.
        system = self.system_template.format(system_message=self.system_message)
        parts = (self.context_template, self.chat_history_template, self.question_template)
        return system + "".join(f"{self.sep}{part}" for part in parts)

    @property
    def train_raw_template(self) -> str:
        """The training prompt template: shared prefix, answer, then ``eos``."""
        return f"{self._raw_prefix()}{self.sep}{self.answer_template}{self.eos}"

    @property
    def infer_raw_template(self) -> str:
        """The inference prompt template: shared prefix, then ``eos`` (no answer)."""
        # NOTE(review): eos is appended right after the question, i.e. before the
        # model produces its answer; confirm this is intended for generation.
        return f"{self._raw_prefix()}{self.eos}"

    def format_train(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Format a data sample into a training prompt.

        Args:
            sample: Must contain "user_context", "news_context", "question" and
                "answer"; "chat_history" is optional and defaults to "".

        Returns:
            A dict with the formatted "prompt" and the untouched sample as "payload".

        Raises:
            KeyError: if a required key is missing from ``sample``.
        """
        prompt = self.train_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
            answer=sample["answer"],
        )
        return {"prompt": prompt, "payload": sample}

    def format_infer(self, sample: Dict[str, str]) -> Dict[str, Union[str, Dict]]:
        """Format a data sample into an inference prompt (no answer included).

        Args:
            sample: Must contain "user_context", "news_context" and "question";
                "chat_history" is optional and defaults to "".

        Returns:
            A dict with the formatted "prompt" and the untouched sample as "payload".

        Raises:
            KeyError: if a required key is missing from ``sample``.
        """
        prompt = self.infer_raw_template.format(
            user_context=sample["user_context"],
            news_context=sample["news_context"],
            chat_history=sample.get("chat_history", ""),
            question=sample["question"],
        )
        return {"prompt": prompt, "payload": sample}
# Global registry mapping template names to PromptTemplate instances;
# written by register_llm_template() and read by get_llm_template().
templates: Dict[str, PromptTemplate] = dict()
def register_llm_template(template: PromptTemplate):
    """Add ``template`` to the global ``templates`` registry under its ``name``.

    Registering another template with the same name replaces the earlier one.
    """
    templates.update({template.name: template})
def get_llm_template(name: str) -> PromptTemplate:
    """Return the template registered under ``name``.

    Raises:
        KeyError: if no template is registered under ``name``. The message
            lists the registered names so typos are easy to spot.
    """
    try:
        return templates[name]
    except KeyError:
        available = ", ".join(sorted(templates)) or "<none>"
        # Still a KeyError so existing callers keep working; `from None` drops
        # the redundant inner traceback.
        raise KeyError(
            f"Unknown LLM template {name!r}. Available templates: {available}"
        ) from None
##### Register Templates #####
# - Mistral 7B Instruct v0.2 Template
# The system prompt opens with "<s>", the question is wrapped in
# "[INST] ... [/INST]", and prompts are terminated with " </s>".
# NOTE(review): Mistral's published instruct format places instruction text
# inside [INST] ... [/INST]; here the system message and context sit outside
# it. Confirm this matches the data the model was fine-tuned on.
register_llm_template(
    PromptTemplate(
        name="mistral",
        system_message="You are a helpful assistant, with financial expertise.",
        system_template="<s>{system_message}",
        context_template="{user_context}\n{news_context}",
        chat_history_template="Summary: {chat_history}",
        question_template="[INST] {question} [/INST]",
        answer_template="{answer}",
        eos=" </s>",
        sep="\n",
    )
)
# - FALCON (spec: https://huggingface.co/tiiuae/falcon-7b/blob/main/tokenizer.json)
# Each part is introduced by a ">>MARKER<<" string (presumably the special
# tokens defined in the linked tokenizer spec - verify there), and
# "<|endoftext|>" terminates the prompt.
register_llm_template(
    PromptTemplate(
        name="falcon",
        system_message="You are a helpful assistant, with financial expertise.",
        system_template=">>INTRODUCTION<< {system_message}",
        context_template=">>DOMAIN<< {user_context}\n{news_context}",
        chat_history_template=">>SUMMARY<< {chat_history}",
        question_template=">>QUESTION<< {question}",
        answer_template=">>ANSWER<< {answer}",
        eos="<|endoftext|>",
        sep="\n",
    )
)