# + tags=["hide_inp"]

desc = """
### Gradio Tool

Agent chain that picks a Gradio tool (Stable Diffusion or image captioning) for a query and runs it. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/srush/MiniChain/blob/master/examples/bash.ipynb)

(Adapted from LangChain [BashChain](https://langchain.readthedocs.io/en/latest/modules/chains/examples/llm_bash.html))
"""
# -

# $

from minichain import Id, prompt, show, OpenAIStream
from gradio_tools.tools import StableDiffusionTool, ImageCaptioningTool

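
# Each Gradio tool is wrapped as a MiniChain prompt: `gen` passes its query to
# Stable Diffusion and `caption` passes an image path to the captioning model.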
@prompt(StableDiffusionTool())
def gen(model, query):
    return model(query)

@prompt(ImageCaptioningTool())
def caption(model, img_src):
    return model(img_src)

tools = [gen, caption]

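# The agent renders agent.pmpt.tpl with each tool's name and description plus
# the user query. With the Id backend it prints the rendered prompt and returns
# a fixed (tool name, tool input) pair; the commented-out block below is the
# streamed OpenAI variant that parses the model's reply for a tool choice.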
@prompt(Id(),
        #OpenAIStream(), stream=True,
        template_file="agent.pmpt.tpl")
def agent(model, query):
    print(model(dict(tools=[(str(tool.backend.__class__), tool.backend.description)
                            for tool in tools],
                     input=query
                     )))
    return ("StableDiffusionTool", "Draw a flower")
    # out = ""
    # for t in model.stream(dict(tools=[(str(tool.backend.__class__), tool.backend.description)
    #                                   for tool in tools],
    #                       input=query
    #                       )):
    #     out += t
    #     yield out
    # lines = out.split("\n")
    # response = lines[0].split("?")[1].strip()
    # if response == "Yes":
    #     tool = lines[1].split(":")[1].strip()
    #     yield tool

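# Dispatch on the agent's choice: `dynamic=tools` lets this prompt run one of
# the wrapped Gradio tools, selected here by `tool_num`.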
@prompt(dynamic=tools)
def selector(model, input):
    selector, input = input
    if selector == "StableDiffusionTool":
        return model.tool(input, tool_num=0)
    else:
        return model.tool(input, tool_num=1)


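# Full chain: the agent picks a (tool, input) pair and the selector runs it.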
def run(query):
    select_input = agent(query)
    return selector(select_input)

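# Build the chain for an example query and execute it once.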
run("make a pic").run()
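
# A minimal sketch of calling the wrapped tools directly (an illustration, not
# part of the original example; assumes the gradio_tools clients can reach
# their hosted demos):
#
#   img_path = gen("a watercolor painting of a flower").run()
#   print(caption(img_path).run())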
# $

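# Expose the chain as a Gradio demo, showing the agent and selector sub-prompts.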
gradio = show(run,
              subprompts=[agent, selector],
              examples=['Draw me a flower'],
              out_type="markdown",
              description=desc
              )
if __name__ == "__main__":
    gradio.launch()