Spaces:
Running
Running
File size: 2,056 Bytes
8ff63e4 e38f79f 8ff63e4 e38f79f 8ff63e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
"""An abstraction layer for prompting different models."""
from __future__ import annotations
import enum
from fastchat.model.model_adapter import get_conversation_template
class Task(enum.Enum):
    """Enumeration of the supported system prompt styles.

    Each member's value is the string identifier of that style, so a
    member can be looked up from its identifier, e.g. ``Task("chat")``.
    """

    # Conversational styles.
    CHAT = "chat"
    CHAT_CONCISE = "chat-concise"

    # Single-instruction styles.
    INSTRUCT = "instruct"
    INSTRUCT_CONCISE = "instruct-concise"
# Shared opening text of the two chat-style prompts. The trailing space is
# part of the prompt text and must be preserved.
_CHAT_PREAMBLE = (
    "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. "
)

# Shared opening text of the two instruction-style prompts.
_INSTRUCT_PREAMBLE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request. "
)

# Maps every Task member to its system prompt text. The "concise" variants
# are the base prompt followed by one extra sentence asking for brevity.
SYSTEM_PROMPTS = {
    Task.CHAT: _CHAT_PREAMBLE,
    Task.CHAT_CONCISE: _CHAT_PREAMBLE + "The assistant's answers are very concise. ",
    Task.INSTRUCT: _INSTRUCT_PREAMBLE,
    Task.INSTRUCT_CONCISE: _INSTRUCT_PREAMBLE + "The response should be very concise. ",
}
def get_system_prompt(task: Task | str) -> str:
    """Return the system prompt text registered for `task`.

    Args:
        task: A `Task` member, or the string value of one (for example
            ``"chat-concise"``). Strings are converted with ``Task(task)``.

    Returns:
        The entry of `SYSTEM_PROMPTS` for the resolved task.

    Raises:
        ValueError: If `task` is a string that is not the value of any
            `Task` member.
        KeyError: If the resolved task has no entry in `SYSTEM_PROMPTS`.
    """
    # Only strings are converted; anything else is used as the lookup key as-is.
    resolved = Task(task) if isinstance(task, str) else task
    return SYSTEM_PROMPTS[resolved]
def apply_model_characteristics(
    prompt: str,
    model_name: str,
    system_prompt: str | None = None,
) -> tuple[str, str | None, list[int]]:
    """Format `prompt` with the conversation template of `model_name`.

    Args:
        prompt: Text appended as the first message of the conversation.
        model_name: Name passed to `get_conversation_template` to select
            the template.
        system_prompt: When not None, replaces the template's system
            message. When None, the template's own system message is kept.

    Returns:
        A 3-tuple of:
          - the prompt string rendered by the template,
          - the template's stop string, or None when it is unset or empty,
          - the template's stop token IDs, or an empty list when the
            template provides none.
    """
    template = get_conversation_template(model_name)
    if system_prompt is not None:
        template.system_message = system_prompt

    # Discard any messages that come with the template and reset its offset,
    # so the rendered prompt contains only the two messages appended below.
    template.messages = []
    template.offset = 0

    # roles[0] carries the prompt; roles[1] gets an empty message.
    # NOTE(review): presumably roles[0] is the user and roles[1] the assistant,
    # leaving the assistant turn open for generation; confirm against FastChat.
    template.append_message(template.roles[0], prompt)
    template.append_message(template.roles[1], "")

    # Normalize a missing or empty stop string to None.
    stop_str = template.stop_str or None
    rendered = template.get_prompt()
    stop_token_ids = template.stop_token_ids or []
    return rendered, stop_str, stop_token_ids
|