Spaces:
Paused
Paused
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.runnables import RunnableLambda, RunnableSequence | |
| from langchain_core.messages import HumanMessage | |
def strip_prompt(info):
    """Extract the routing decision from a model reply.

    If the reply text contains the ``"[/INST] "`` marker, everything up to and
    including its *last* occurrence is dropped (presumably an echoed
    instruction prompt -- confirm against the model back-end in use). If the
    marker is absent the text is used unchanged.

    The result always has the same shape, so downstream consumers do not have
    to distinguish between a message object and a dict. The input message is
    not modified.

    Args:
        info: Message-like object exposing a string ``content`` attribute.

    Returns:
        dict: ``{"next": <reply text after the last marker>}``.
    """
    print(info)  # debug trace of the raw reply
    eot_token = "[/INST] "
    # rpartition splits on the LAST occurrence; when the marker is missing it
    # returns ("", "", content), so the tail is the full text in that case.
    _, _, answer = info.content.rpartition(eot_token)
    return {"next": answer}
class Supervisor:
    """LLM-backed router that decides which worker should act next.

    The supervisor builds a chat prompt listing the available workers, pipes
    it through ``llm`` and then through ``strip_prompt``.

    Attributes:
        llm: The runnable/model the prompt is piped into.
        members: Per-instance list of worker names.
        prompt: The chat prompt with ``members`` and ``options`` pre-filled.
        chain: ``prompt | llm | RunnableLambda(strip_prompt)``.
    """

    # Instruction text; ``{members}`` is filled in when the prompt is built.
    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )

    def __init__(self, llm, members):
        """Create a supervisor for ``members`` driven by ``llm``.

        Args:
            llm: Runnable that receives the formatted prompt.
            members: Iterable of worker names.
        """
        self.llm = llm
        # Copy into a fresh per-instance list: a class-level list extended
        # with ``+=`` would be shared (and accumulated) across all instances,
        # and the caller's own list must not be aliased either.
        self.members = list(members)
        self._build_chain()

    def _build_chain(self):
        """(Re)build ``self.prompt`` and ``self.chain`` from current members."""
        self.prompt = ChatPromptTemplate.from_messages(
            [
                # The instructions are sent in the "human" role followed by a
                # canned "ok" -- presumably because the target chat format has
                # no system role; confirm before changing.
                ("human", self.system_prompt),
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(
            options=str(self.get_options()),
            members=", ".join(self.members),
        )
        self.chain = self.prompt | self.llm | RunnableLambda(strip_prompt)

    def add_member(self, member):
        """Register another worker and refresh the prompt and chain.

        The member and option lists are baked into the prompt as strings, so
        the prompt has to be rebuilt for the new worker to become selectable.
        A chain previously obtained via ``get_chain()`` is not updated; fetch
        it again after adding members.
        """
        self.members.append(member)
        self._build_chain()

    def get_members(self):
        """Return the list of worker names (the live list, not a copy)."""
        return self.members

    def get_options(self):
        """Return the valid routing answers: ``"FINISH"`` plus every worker."""
        return ["FINISH"] + self.members

    def get_chain(self):
        """Return the current runnable chain."""
        return self.chain

    def invoke(self, query):
        """Run the chain on a single user query and return its result.

        Args:
            query: Text of the user request.

        Returns:
            Whatever the chain produces (the output of ``strip_prompt``).
        """
        # Pass the prompt variable explicitly rather than relying on the
        # prompt coercing a bare list into its single remaining variable.
        return self.chain.invoke({"messages": [HumanMessage(content=query)]})