from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnableSequence
from langchain_core.messages import HumanMessage


def strip_prompt(info):
    # Strip the echoed prompt from the raw completion: keep only the text
    # after the final "[/INST] " marker and hand it on under the "next" key.
    eot_token = "[/INST] "
    i = info.content.rfind(eot_token)
    if i == -1:
        return {"next": info.content}
    info.content = info.content[i + len(eot_token):]
    return {"next": info.content}
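
# Illustrative only: given a raw completion from an [INST]-style instruct model
# (hypothetical message content shown, AIMessage is from langchain_core.messages),
# strip_prompt drops the echoed prompt and keeps the reply:
#
#   strip_prompt(AIMessage(content="<s>[INST] Who acts next? [/INST] Researcher"))
#   -> {"next": "Researcher"}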


class Supervisor:
    """Routes a multi-worker conversation by asking the LLM which worker acts next."""

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    def __init__(self, llm, members):
        self.members = list(members)
        # The system prompt is sent as a human turn followed by a short assistant
        # acknowledgement; the same pattern wraps the conversation history before
        # the final routing question.
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("human", self.system_prompt),
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(options=str(self.get_options()), members=", ".join(self.members))
        self.chain = self.prompt | llm | RunnableLambda(strip_prompt)

    def add_member(self, member):
        # Note: members added here are not reflected in the prompt built in __init__.
        self.members.append(member)

    def get_members(self):
        return self.members

    def get_options(self):
        return ["FINISH"] + self.members

    def get_chain(self):
        return self.chain

    def invoke(self, query):
        # Wrap the query as the "messages" history and return the routing decision.
        return self.chain.invoke({"messages": [HumanMessage(query)]})
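

# Usage sketch (assumptions: ChatOllama and the "mistral" model name are
# placeholders; any LangChain chat-model Runnable should slot in the same way).
if __name__ == "__main__":
    from langchain_community.chat_models import ChatOllama

    llm = ChatOllama(model="mistral")
    supervisor = Supervisor(llm, ["Researcher", "Coder"])
    # Expected shape of the result: {"next": "Researcher"} or {"next": "FINISH"}.
    print(supervisor.invoke("Summarise the latest sales figures"))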