minichain / agent.py
Sasha Rush
.
8d0a608
# + tags=["hide_inp"]
# Markdown blurb handed to `show(...)` below as the demo's `description`.
# NOTE(review): the "Open In Colab" badge links to examples/bash.ipynb although
# this file is agent.py — looks copied from the bash example; confirm the
# intended notebook. The "(Adapted from LangChain )" text also appears to have
# lost a hyperlink (stray space before the closing parenthesis) — verify.
desc = """
### Agent
Chain that executes different tools based on model decisions. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb)
(Adapted from LangChain )
"""
# -
# $
from minichain import Id, prompt, OpenAI, show, transform, Mock, Break
from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool, ImageToMusicTool
# class ImageCaptioningTool:
# def run(self, inp):
# return "This is a picture of a smiling huggingface logo."
# description = "Image Captioning"
# Tool instances the agent can choose between. Order matters: `tool_use`
# passes a tool's position in this list to the backend as `tool_num`, and
# `agent` advertises each entry by class name and `description`.
tools = [
    StableDiffusionTool(),
    ImageCaptioningTool(),
    ImageToMusicTool(),
]
@prompt(OpenAI(stop=["Observation:"]), template_file="agent.pmpt.tpl")
def agent(model, query, history):
    """Ask the language model what to do next for ``query``.

    The ``agent.pmpt.tpl`` template receives three variables:

    * ``tools`` -- a ``(class name, description)`` pair for every entry in
      the module-level ``tools`` list,
    * ``input`` -- the user's query,
    * ``agent_scratchpad`` -- the transcript accumulated so far.

    The OpenAI backend is created with ``stop=["Observation:"]``.
    """
    tool_menu = []
    for tool in tools:
        # `__name__` is already a str, so it is used as-is.
        tool_menu.append((tool.__class__.__name__, tool.description))

    template_vars = dict(
        tools=tool_menu,
        input=query,
        agent_scratchpad=history,
    )
    return model(template_vars)
@transform()
def tool_parse(out):
    """Extract the requested tool and its input from the agent's reply.

    The reply is split on newlines. The text after the last ``"?"`` on the
    first line is the yes/no answer; when it is exactly ``"Yes"`` the second
    and third lines are each read as ``"<label>: <value>"`` and the values
    (whitespace-stripped) are taken as the tool name and the tool input.

    Args:
        out: Raw text produced by the ``agent`` prompt.

    Returns:
        A ``(tool, command)`` tuple when a tool is requested, otherwise a
        ``Break()`` instance. ``Break()`` is also returned when the answer is
        "Yes" but the reply has fewer than three lines.
    """
    lines = out.split("\n")
    answer = lines[0].split("?")[-1].strip()

    # A "Yes" with no tool / input lines has nothing usable to dispatch;
    # indexing lines[1] / lines[2] would raise IndexError, so it is handled
    # the same way as "no tool requested".
    if answer != "Yes" or len(lines) < 3:
        return Break()

    # maxsplit=1 keeps any further ":" (e.g. in URLs) inside the value.
    tool = lines[1].split(":", 1)[-1].strip()
    command = lines[2].split(":", 1)[-1].strip()
    return tool, command
@prompt(tools)
def tool_use(model, usage):
    """Dispatch a parsed ``(tool name, command)`` request to the matching tool.

    ``usage`` is unpacked into a selector and a command. The selector is
    compared against the class name of each entry in the module-level
    ``tools`` list; the first match wins and ``model`` is called with the
    command and that entry's index as ``tool_num``. If nothing matches, the
    one-element tuple ``("",)`` is returned.
    """
    selector, command = usage

    # Lazily yields indices of matching tools; only the first is consumed.
    matching_indices = (
        index
        for index, tool in enumerate(tools)
        if selector == tool.__class__.__name__
    )
    tool_index = next(matching_indices, None)

    if tool_index is None:
        return ("",)
    return model(command, tool_num=tool_index)
@transform()
def append(history, new, observation):
    """Return ``history`` extended with one agent step and its observation.

    The result is ``history``, a newline, the agent's reply ``new``, the
    literal ``"Observation: "`` and finally ``observation``. No separator is
    inserted between ``new`` and ``"Observation: "``.
    """
    step = new + "Observation: " + observation
    return history + "\n" + step
def run(query):
    """Run three agent / tool rounds for ``query`` and return the last result.

    Each round calls ``agent`` with the query and the transcript so far,
    passes its output through ``tool_parse`` and ``tool_use``, and then uses
    ``append`` to add the agent output and the new observation to the
    transcript. The value returned is the third round's ``tool_use`` result.
    """
    observations = []
    history = ""
    for _ in range(3):
        decision = agent(query, history)
        request = tool_parse(decision)
        observations.append(tool_use(request))
        # The observation just appended is the one recorded in the transcript.
        history = append(history, decision, observations[-1])
    return observations[-1]
# $
# Demo UI for `run`. `subprompts` repeats (agent, tool_use) three times, one
# pair per iteration of the loop in `run`.
gradio = show(
    run,
    subprompts=[agent, tool_use] * 3,
    examples=[
        "I would please like a photo of a dog riding a skateboard. "
        "Please caption this image and create a song for it.",
        'Use an image generator tool to draw a cat.',
        'Caption the image https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png from the internet',
    ],
    out_type="markdown",
    description=desc,
    show_advanced=False,
)
if __name__ == "__main__":
    # Only start the demo when this file is executed as a script.
    demo = gradio.queue()
    demo.launch()