# Spaces: Sleeping
import dataclasses
from enum import auto, Enum
from typing import Any, List, Optional, Tuple
class SeparatorStyle(Enum):
    """Different separator style.

    SINGLE: every turn is prefixed with ``sep + " "``.
    TWO: completed turns alternately end with ``sep`` and ``sep2``.
    """
    SINGLE = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Stores the system prompt, the role names and the list of
    ``[role, message]`` turns, and renders them into a single prompt string
    according to ``sep_style``.
    """

    system: str
    roles: List[str]
    # Each entry is ``[role, message]``; ``message`` may be ``None`` for a
    # turn that has not been filled in yet (see ``get_prompt``).
    messages: List[List[Optional[str]]]
    # Number of leading messages skipped by ``to_gradio_chatbot``.
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Second separator; only read by ``SeparatorStyle.TWO``.
    sep2: Optional[str] = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the conversation into a single prompt string.

        A turn whose message is falsy (e.g. ``None``) is rendered as
        ``"role:"`` with no message and no trailing separator.

        Returns:
            The system prompt followed by every turn in ``messages``.

        Raises:
            ValueError: If ``sep_style`` is not a known ``SeparatorStyle``.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            parts = [self.system]
            for role, message in self.messages:
                if message:
                    parts.append(self.sep + " " + role + ": " + message)
                else:
                    parts.append(self.sep + " " + role + ":")
            return "".join(parts)
        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            parts = [self.system + seps[0]]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # Even-indexed turns end with sep, odd-indexed with sep2.
                    parts.append(role + ": " + message + seps[i % 2])
                else:
                    parts.append(role + ":")
            return "".join(parts)
        raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append a ``[role, message]`` turn; ``message`` may be ``None``."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return the messages after ``offset`` as two-element pairs.

        Each even-indexed message (counted from ``offset``) starts a new
        ``[msg, None]`` pair and the following message fills its second slot.
        The pairs are presumably ``[user, assistant]`` for a Gradio chatbot
        widget; that depends on how callers order the turns.
        """
        ret = []
        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new Conversation with its own copy of the turn list.

        Every ``[role, message]`` pair is copied, so appending to or editing
        turns of the copy leaves the original untouched. ``roles`` is shared,
        and ``skip_next`` is not carried over (it takes its default).
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id)

    def dict(self):
        """Return the conversation's fields as a plain dictionary.

        ``sep_style`` and ``skip_next`` are omitted. ``roles`` and
        ``messages`` are returned by reference, not copied.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }
# Conversation template used by the script below: a USER/ASSISTANT chat in
# SeparatorStyle.TWO, so even-indexed (user) turns end with " " and
# odd-indexed (assistant) turns end with "</s>".
conv = Conversation(
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    roles=("USER", "ASSISTANT"),
    offset=0,
    messages=[],
    system=(
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    ),
)
# NOTE(review): `tokenizer` and `model` are not defined in this source; they
# are presumably a Hugging Face tokenizer and decoder-only causal LM created
# elsewhere — confirm before running.
conv.append_message(conv.roles[0], "Why would Microsoft take this down?")
# A None message renders as a bare "ASSISTANT:" for the model to complete.
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
result = model.generate(**inputs, max_new_tokens=1000)
# For decoder-only models, generate() returns the prompt tokens followed by
# the new tokens; drop the prompt so only the model's reply is decoded.
prompt_length = inputs["input_ids"].shape[1]
generated_ids = result[0][prompt_length:]
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print(generated_text)