Spaces:
Runtime error
Runtime error
File size: 3,417 Bytes
88bce96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from typing import Dict, List, Literal, TypedDict
from models import Model
from pybars import Compiler
# Module-level pybars compiler instance; ChatModel.setSysPrompt uses it to
# compile system-prompt templates.
compiler = Compiler()
# One message of a conversation: "role" says who is speaking (restricted to
# "user", "assistant" or "system") and "content" holds the message text.
Turn = TypedDict(
    "Turn",
    {
        "role": Literal["user", "assistant", "system"],
        "content": str,
    },
)
def chatmsg(message: str, role: Literal["user", "assistant", "system"]):
    """Build one chat turn as a plain dict.

    Args:
        message: Text of the turn; stored under the ``"content"`` key.
        role: Speaker of the turn; stored under the ``"role"`` key.

    Returns:
        A new dict with the keys ``"role"`` and ``"content"``, in that order.
    """
    return dict(role=role, content=message)
# Type alias used in the annotations below: a conversation is a list of Turn dicts.
conversation=List[Turn]
class ChatModel:
    """Base interface for a chat wrapper around a model and a system prompt.

    The conversation-related methods all raise ``NotImplementedError`` here
    and are meant to be supplied by subclasses. ``setSysPrompt`` reads
    ``self.name``, which this class does not define itself, so a subclass
    (or the instance) has to provide it before construction.
    """

    def __init__(self, model: Model, sysprompt: str):
        """Store ``model`` and render ``sysprompt`` into ``self.sysprompt``."""
        self.setModel(model)
        self.setSysPrompt(sysprompt)

    def __call__(self, msg: str):
        """Handle one incoming message; not implemented in the base class."""
        raise NotImplementedError

    def getconversation(self) -> conversation:
        """Return the conversation so far; not implemented in the base class."""
        raise NotImplementedError

    def conversationend(self) -> bool:
        """Tell whether the conversation is over; not implemented in the base class."""
        raise NotImplementedError

    def setconversation(self, conversation: conversation):
        """Replace the conversation; not implemented in the base class."""
        raise NotImplementedError

    def setSysPrompt(self, sysprompt: str):
        """Render ``sysprompt`` as a pybars template and keep the result.

        The template is rendered with a context whose ``model`` entry is
        ``self.name`` and with an ``eq`` helper that compares its two
        arguments using ``==``. The rendered text is stored in
        ``self.sysprompt`` and printed to stdout.
        """

        def equals(this, left, right):
            # ``this`` is the first argument the helper receives; it is unused.
            return left == right

        template = compiler.compile(sysprompt)
        context = {"model": self.name}
        self.sysprompt = template(context, helpers={"eq": equals})
        print(self.name + " SystemPrompt:\n" + self.sysprompt)

    def setModel(self, model: Model):
        """Remember ``model`` as ``self.model``."""
        self.model = model
class SwapChatModel(ChatModel):
    """Chat model that asks the wrapped model to write the ``user`` turns.

    Each incoming message is recorded with the ``assistant`` role, and the
    model's completion is recorded with the ``user`` role. The conversation
    counts as finished once a turn whose content is exactly
    ``"End of conversation."`` has been recorded.
    """

    # Read by ChatModel.setSysPrompt as the ``model`` value of the template
    # context and as the prefix of the printed system prompt.
    name = "SwapChat"

    def __init__(self, model: Model, sysprompt: str):
        """Initialise the base class and start with an empty conversation."""
        super().__init__(model, sysprompt)
        self.conversation = []

    def __call__(self, msg: str):
        """Record ``msg`` and append the model's reply to the conversation.

        Does nothing when the conversation has already ended. Otherwise
        ``msg`` is appended as an ``assistant`` turn, the model is prompted
        for the following ``user`` turn, and the completion is appended. If
        the completion contains ``"<|end"``, only the text before that marker
        is kept and an ``"End of conversation."`` turn is appended after it.

        Returns:
            None; results are visible through ``getconversation()``.
        """
        if self.conversationend():
            return
        self.conversation.append(chatmsg(msg, "assistant"))
        # Prompt layout: start text, system prompt, history, then an opened
        # ``user`` turn for the model to complete. The exact text comes from
        # the Model helpers, which are defined outside this file.
        prompt = "".join([
            self.model.start(),
            self.model.conv([chatmsg(self.sysprompt, "system")]),
            self.model.conv(self.conversation),
            self.model.starttok("user"),
        ])
        ret = self.model(
            prompt,
            stop=[".", "\n \n", "?\n", ".\n", "tile|>", "\n"],
            max_tokens=100,
        )
        comp = ret["choices"][0]["text"]
        if "<|end" in comp:
            # Cut at the first marker. ``removesuffix`` only strips the marker
            # when the text ends with exactly "<|end", so anything following
            # it (e.g. the rest of the token) used to be stored with the turn.
            reply = comp.split("<|end", 1)[0]
            self.conversation.append(chatmsg(reply, "user"))
            self.conversation.append(chatmsg("End of conversation.", "user"))
        else:
            self.conversation.append(chatmsg(comp, "user"))

    def getconversation(self) -> conversation:
        """Return the live conversation list (not a copy)."""
        return self.conversation

    def conversationend(self) -> bool:
        """Return True if any turn's content is ``"End of conversation."``."""
        return any(
            turn["content"] == "End of conversation."
            for turn in self.conversation
        )

    def setconversation(self, conversation: conversation):
        """Replace the stored conversation with ``conversation`` (no copy)."""
        self.conversation = conversation
class InquiryChatModel(SwapChatModel):
    """SwapChatModel that first asks the model whether the chat is complete.

    Before generating a reply, ``inquire`` puts a true/false question to the
    model; a completion containing "true" ends the conversation instead.
    """

    # Read by ChatModel.setSysPrompt as the ``model`` value of the template
    # context and as the prefix of the printed system prompt.
    name = "InquiryChat"

    def inquire(self, msg):
        """Ask the model whether the conversation, including ``msg``, is over.

        The prompt holds the system prompt, the history, ``msg`` as an
        ``assistant`` turn and an opened ``system`` turn containing the
        question. The completion is printed. If it contains "true"
        (case-insensitive), ``msg`` and an ``"End of conversation."`` turn
        are appended to the conversation.
        """
        prompt = "".join([
            self.model.start(),
            self.model.conv([chatmsg(self.sysprompt, "system")]),
            self.model.conv(self.conversation),
            self.model.conv([chatmsg(msg, "assistant")]),
            self.model.starttok("system"),
            "Is this conversation complete(true/false)?\n",
        ])
        ret = self.model(
            prompt,
            stop=[".", "\n \n", "?\n", ".\n", "tile|>", "\n"],
            max_tokens=10,
        )
        verdict = ret["choices"][0]["text"]
        print("system prompt:", verdict)
        if "true" in verdict.lower():
            # NOTE(review): msg is stored with the "user" role here, while the
            # prompt above and SwapChatModel.__call__ treat incoming messages
            # as "assistant" turns -- confirm which role is intended.
            self.conversation.append(chatmsg(msg, "user"))
            self.conversation.append(chatmsg("End of conversation.", "user"))

    def __call__(self, msg: str):
        """Run the completeness check, then the normal reply generation.

        Does nothing when the conversation has already ended, so a finished
        conversation neither triggers another model call nor gains turns
        after its end marker. If ``inquire`` ends the conversation, the
        parent ``__call__`` returns without generating a reply.
        """
        if self.conversationend():
            return
        self.inquire(msg)
        super().__call__(msg)
# The chat model classes defined in this module, collected in one list.
models = [SwapChatModel, InquiryChatModel]