from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
def strip_prompt(info):
    """Pull the routing decision out of the raw model reply.

    Models served with Mistral/Llama-style instruction formatting may echo
    the prompt back; keep only the text after the last "[/INST] " marker.
    Always returns {"next": ...} so downstream consumers see one shape.
    """
    print(info)  # debug: log the raw model reply
    eot_token = "[/INST] "
    i = info.content.rfind(eot_token)
    if i != -1:
        info.content = info.content[i + len(eot_token):]
    return {"next": info.content}
class Supervisor:
    """Routes a multi-agent conversation: given the history so far, decides
    which worker should act next, or FINISH when the request is complete."""

    system_prompt = (
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    )
    def __init__(self, llm, members):
        # Keep members on the instance: a mutable class-level list would be
        # shared (and mutated) across every Supervisor.
        self.members = list(members)
        # Models without a system role (e.g. Mistral's [INST] format) receive
        # the instructions as a human turn acknowledged by a stub "ok" turn.
        self.prompt = ChatPromptTemplate.from_messages(
            [
                ("human", self.system_prompt),
                ("assistant", "ok"),
                MessagesPlaceholder(variable_name="messages"),
                ("assistant", "ok"),
                (
                    "human",
                    "Given the conversation above, who should act next?"
                    " Or should we FINISH? Select one of: {options}",
                ),
            ]
        ).partial(options=str(self.get_options()), members=", ".join(self.members))
        self.chain = self.prompt | llm | RunnableLambda(strip_prompt)
    def add_member(self, member):
        # Note: {members}/{options} were filled in at construction time, so
        # workers added later are not offered to the model.
        self.members.append(member)

    def get_members(self):
        return self.members

    def get_options(self):
        # The model may pick any worker, or FINISH to end the conversation.
        return ["FINISH"] + self.members

    def get_chain(self):
        return self.chain

    def invoke(self, query):
        # Fill the "messages" placeholder and return the routing decision.
        return self.chain.invoke({"messages": [HumanMessage(content=query)]})
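
# Minimal usage sketch (assumption: ChatOllama serving a Mistral model, whose
# replies may echo the [INST] prompt; the worker names are placeholders).
if __name__ == "__main__":
    from langchain_community.chat_models import ChatOllama

    llm = ChatOllama(model="mistral")
    supervisor = Supervisor(llm, ["Researcher", "Coder"])
    # strip_prompt normalizes the reply to {"next": "<worker or FINISH>"}.
    print(supervisor.invoke("Research the topic, then write a short summary."))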