# Source: devmodetest2/perm/agents/supervisor.py
# Uploaded by tengel ("Upload 56 files"), commit 9c9a39f (verified).
import logging

from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnableSequence
def strip_prompt(info):
    """Extract the model's answer from a completion that may echo the prompt.

    Some models emit the whole ``[INST] ... [/INST]`` prompt before their
    answer. This keeps only the text after the last ``"[/INST] "`` marker and
    wraps it under the ``"next"`` key.

    Args:
        info: The LLM output: either an object exposing ``.content`` (e.g. a
            chat message) or a plain string. It is not modified.

    Returns:
        ``{"next": answer}``, where ``answer`` has surrounding whitespace
        removed. If the marker is absent, the whole content is the answer, so
        the return shape is always the same dict.
    """
    # Accept both chat-model messages and raw string completions.
    content = getattr(info, "content", info)
    logging.getLogger(__name__).debug("supervisor raw output: %r", content)
    eot_token = "[/INST] "
    # rfind: an echoed multi-turn prompt can contain several [/INST] markers;
    # the model's own answer follows the last one.
    i = content.rfind(eot_token)
    if i != -1:
        content = content[i + len(eot_token):]
    # Strip trailing newlines/spaces so the value matches an option name exactly.
    return {"next": content.strip()}
class Supervisor:
    """LLM-driven router that asks the model which worker should act next.

    The prompt lists the worker names and asks the model to pick one of
    ``["FINISH", *members]``. The runnable is
    ``prompt | llm | RunnableLambda(strip_prompt)``.

    Attributes:
        members: Worker names offered to the model (a per-instance list).
        llm: The model the chain was built with.
        prompt: ``ChatPromptTemplate`` with ``options`` and ``members`` bound
            via ``partial``.
        chain: The composed runnable described above.
    """

    # Class-level default only; it is never mutated. Each instance builds its
    # own list in __init__, so workers do not leak between instances.
    members = []

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    def __init__(self, llm, members):
        """Create the supervisor and build its routing chain.

        Args:
            llm: Model runnable placed between the prompt and ``strip_prompt``.
            members: Iterable of worker names the supervisor may route to.
        """
        self.llm = llm
        # Not ``self.members += members``: ``+=`` on the inherited class list
        # extends that shared list in place, so every instance would see
        # every other instance's workers. Build a fresh list instead.
        self.members = list(self.members) + list(members)
        self._build_chain()

    def _build_chain(self):
        """(Re)build ``self.prompt`` and ``self.chain`` from ``self.members``."""
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("human", self.system_prompt),
                # NOTE(review): these filler "ok" assistant turns presumably
                # keep strict human/assistant alternation for chat templates
                # that require it; confirm against the target model.
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(options=str(self.get_options()), members=", ".join(self.members))
        self.chain = self.prompt | self.llm | RunnableLambda(strip_prompt)

    def add_member(self, member):
        """Add a worker and rebuild the chain so the LLM is offered it.

        The worker names are bound into the prompt via ``partial``, so the
        chain must be rebuilt. A chain object fetched earlier with
        ``get_chain()`` keeps the old worker list.
        """
        self.members.append(member)
        self._build_chain()

    def get_members(self):
        """Return the worker list (the live list, not a copy)."""
        return self.members

    def get_options(self):
        """Return the answers offered to the model: ``FINISH`` plus all workers."""
        return ["FINISH"] + self.members

    def get_chain(self):
        """Return the current ``prompt | llm | strip_prompt`` runnable."""
        return self.chain

    def invoke(self, query):
        """Run the chain on a single user message and return its output.

        Args:
            query: Text of the user request.

        Returns:
            Whatever the final ``strip_prompt`` step produces.
        """
        return self.chain.invoke({"messages": [HumanMessage(content=query)]})