File size: 3,417 Bytes
88bce96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from typing import Dict, List, Literal, TypedDict


from models import Model
from pybars import Compiler
# Single module-level pybars template compiler, shared by every ChatModel's setSysPrompt.
compiler = Compiler()

# One recorded chat turn (functional TypedDict form; same keys and value
# types as the original class-based declaration).
Turn = TypedDict(
    "Turn",
    {
        "role": Literal["user", "assistant", "system"],
        "content": str,
    },
)

def chatmsg(message:str, role:Literal["user", "assistant", "system"]):
    """Build a single chat-turn dict from a message body and its role."""
    return dict(role=role, content=message)

# Type alias for a chat transcript. NOTE(review): lowercase name is also used
# as a parameter name in setconversation() below, which shadows the alias there.
conversation=List[Turn]

class ChatModel:
    """Abstract base for chat-style model wrappers.

    Subclasses must supply a class-level ``name`` attribute (assigned after
    the class body elsewhere in this file) and implement the conversation
    methods below; the base only handles model/system-prompt setup.
    """

    def __init__(self, model: Model, sysprompt: str):
        self.setModel(model)
        self.setSysPrompt(sysprompt)

    def __call__(self, msg: str):
        """Feed one message into the conversation — subclass responsibility."""
        raise NotImplementedError

    def getconversation(self) -> conversation:
        """Return the accumulated transcript — subclass responsibility."""
        raise NotImplementedError

    def conversationend(self) -> bool:
        """Report whether the dialogue has finished — subclass responsibility."""
        raise NotImplementedError

    def setconversation(self, conversation: conversation):
        """Replace the transcript wholesale — subclass responsibility."""
        raise NotImplementedError

    def setSysPrompt(self, sysprompt: str):
        """Render the handlebars template *sysprompt* (with the model's
        ``name`` exposed as ``model`` and an ``eq`` helper) and store it."""

        def _eq(this, a, b):
            # Handlebars "eq" helper: plain equality test.
            return a == b

        template = compiler.compile(sysprompt)
        self.sysprompt = template({"model": self.name}, helpers={"eq": _eq})
        print(self.name + " SystemPrompt:\n" + self.sysprompt)

    def setModel(self, model: Model):
        """Attach the underlying completion model."""
        self.model = model

class SwapChatModel(ChatModel):
    """Chat driver that swaps roles: the caller's messages are recorded as
    "assistant" turns and the model's completions as "user" turns, so the
    wrapped model plays the user side of the dialogue.
    """

    # Class attribute (was assigned post-hoc via ``SwapChatModel.name=...``);
    # read by ChatModel.setSysPrompt when rendering the template.
    name = "SwapChat"

    def __init__(self, model: Model, sysprompt: str):
        super().__init__(model, sysprompt)
        self.conversation: conversation = []

    def __call__(self, msg: str):
        """Append *msg* as an "assistant" turn, query the model, and record
        its completion as a "user" turn. No-op once the conversation ended.
        """
        # Fix: reuse conversationend() instead of duplicating its check inline.
        if self.conversationend():
            return
        self.conversation.append(chatmsg(msg, "assistant"))
        prompt = "".join([
            self.model.start(),
            self.model.conv([chatmsg(self.sysprompt, "system")]),
            self.model.conv(self.conversation),
            self.model.starttok("user"),
        ])
        # NOTE(review): "tile|>" in the stop list looks like a truncated/typo'd
        # end-of-turn token — confirm against the model's token vocabulary.
        ret = self.model(prompt, stop=[".", "\n \n", "?\n", ".\n", "tile|>", "\n"], max_tokens=100)
        comp = ret["choices"][0]["text"]
        if "<|end" in comp:
            # Model emitted (part of) an end token: strip a *trailing* marker and
            # close the conversation. NOTE(review): the membership test matches
            # "<|end" anywhere, but removesuffix only strips it at the end — a
            # mid-string marker is left in place; confirm that is intended.
            self.conversation.append(chatmsg(comp.removesuffix("<|end"), "user"))
            self.conversation.append(chatmsg("End of conversation.", "user"))
        else:
            self.conversation.append(chatmsg(comp, "user"))

    def getconversation(self) -> conversation:
        """Return the transcript accumulated so far."""
        return self.conversation

    def conversationend(self) -> bool:
        """True once an "End of conversation." turn has been recorded."""
        return any(turn["content"] == "End of conversation." for turn in self.conversation)

    def setconversation(self, conversation: conversation):
        """Replace the transcript wholesale."""
        self.conversation = conversation


class InquiryChatModel(SwapChatModel):
    """SwapChatModel variant that, before forwarding each message, asks the
    model (via a "system" turn) whether the conversation is complete and, if
    so, records the end-of-conversation marker.
    """

    # Class attribute (was assigned post-hoc via ``InquiryChatModel.name=...``);
    # read by ChatModel.setSysPrompt when rendering the template.
    name = "InquiryChat"

    # Fix: the redundant __init__ that only forwarded to super() with the
    # identical signature was removed; the inherited one is equivalent.

    def inquire(self, msg):
        """Ask the model whether the dialogue is complete; on a "true" answer,
        append *msg* and the "End of conversation." marker to the transcript."""
        prompt = "".join([
            self.model.start(),
            self.model.conv([chatmsg(self.sysprompt, "system")]),
            self.model.conv(self.conversation),
            self.model.conv([chatmsg(msg, "assistant")]),
            self.model.starttok("system"),
            "Is this conversation complete(true/false)?\n",
        ])
        ret = self.model(prompt, stop=[".", "\n \n", "?\n", ".\n", "tile|>", "\n"], max_tokens=10)
        verdict = ret["choices"][0]["text"]
        print("system prompt:", verdict)
        if "true" in verdict.lower():
            self.conversation.append(chatmsg(msg, "user"))
            self.conversation.append(chatmsg("End of conversation.", "user"))

    def __call__(self, msg: str):
        """Run the completeness check, then the normal swap-chat step (which
        becomes a no-op if inquire() just ended the conversation)."""
        self.inquire(msg)
        super().__call__(msg)
# Registry of the available chat-model classes — presumably consumed by the
# application entry point; TODO confirm against the caller.
models=[SwapChatModel,InquiryChatModel]