File size: 2,030 Bytes
9c9a39f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_core.messages import HumanMessage


def strip_prompt(info):
    """Extract the routing answer from a raw LLM completion.

    Completions produced with an ``[INST] ... [/INST]`` chat format may echo
    the prompt back before the model's answer; everything up to and
    including the last ``[/INST]`` marker is discarded.

    Args:
        info: Either a message object exposing ``.content`` (chat-model
            output) or a plain string (completion-style LLM output). It is
            not modified.

    Returns:
        ``{"next": answer}`` where ``answer`` is the text after the last
        ``[/INST]`` marker (or the whole text when no marker is present),
        with surrounding whitespace removed. The shape is the same in both
        cases so a router can always read ``result["next"]``.
    """
    # Chat models return a message with ``.content``; plain LLMs return str.
    content = getattr(info, "content", info)

    eot_token = "[/INST]"
    # rfind: an echoed multi-turn prompt can contain several markers, and the
    # answer follows the last one.
    i = content.rfind(eot_token)
    if i != -1:
        content = content[i + len(eot_token):]

    # Strip so the worker name matches exactly (no leading space / trailing
    # newline left over from generation).
    return {"next": content.strip()}
   

class Supervisor:
    """Ask an LLM which worker should act next in a multi-worker conversation.

    ``__init__`` builds ``prompt | llm | RunnableLambda(strip_prompt)``: the
    prompt lists the registered workers and asks the model to pick one of
    them or ``FINISH``, and ``strip_prompt`` post-processes the completion.
    """

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers:  {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    def __init__(self, llm, members):
        """Build the routing prompt and chain.

        Args:
            llm: Model composed into the chain with ``|`` and invoked with
                the formatted prompt.
            members: Iterable of worker names the supervisor may route to.
        """
        # Per-instance copy: extending a shared class-level list in place
        # would leak workers from one Supervisor into every later one.
        self.members = list(members)
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("human", self.system_prompt),
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(
            # Callable partials are evaluated each time the prompt is
            # formatted, so workers added later via add_member() show up.
            options=self._format_options,
            members=self._format_members,
        )

        self.chain = self.prompt | llm | RunnableLambda(strip_prompt)

    def _format_options(self):
        """Return the option list as rendered into the prompt's ``{options}``."""
        return str(self.get_options())

    def _format_members(self):
        """Return the comma-separated worker names for ``{members}``."""
        return ", ".join(self.members)

    def add_member(self, member):
        """Register another worker; it appears in the prompt on the next call."""
        self.members.append(member)

    def get_members(self):
        """Return the live list of worker names (not a copy)."""
        return self.members

    def get_options(self):
        """Return ``["FINISH"]`` followed by the worker names."""
        return ["FINISH"] + self.members

    def get_chain(self):
        """Return the ``prompt | llm | strip_prompt`` runnable."""
        return self.chain

    def invoke(self, query):
        """Run the routing chain on a single user query and return its output.

        Args:
            query: Text of the user request, sent as one human message.
        """
        return self.chain.invoke({"messages": [HumanMessage(content=query)]})