File size: 3,105 Bytes
c904a51
 
 
 
 
 
 
 
bd896c2
 
 
c904a51
bd896c2
c904a51
bd896c2
 
 
 
 
 
 
c904a51
bd896c2
 
 
 
 
 
 
 
 
 
 
c904a51
 
bd896c2
c904a51
bd896c2
 
 
 
 
 
 
 
 
 
c904a51
bd896c2
 
 
 
 
 
 
 
 
2fcebc3
 
 
 
 
 
 
c904a51
bd896c2
 
 
 
 
 
 
 
 
 
c904a51
bd896c2
 
 
 
 
 
 
c904a51
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from gradio_tools import StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool, TextToVideoTool
from langchain import OpenAI
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import ChatPromptTemplate
import gradio as gr

def initialize_llm(openai_api_key):
    """Build a LangChain agent equipped with the Gradio image/video tools.

    Args:
        openai_api_key: Key handed to the LangChain ``OpenAI`` wrapper.

    Returns:
        ``[agent, True]`` on success, or ``[None, False]`` if constructing the
        ``OpenAI`` client raised (the exception is printed, not re-raised).
        Errors raised while building the tools or the agent are NOT caught.
    """
    try:
        model = OpenAI(temperature=0, openai_api_key=openai_api_key)
    except Exception as err:
        print(err)
        return [None, False]

    # NOTE(review): a zero-shot ReAct agent may not reference chat_history in
    # its prompt, so this memory might go unused — confirm against LangChain.
    chat_memory = ConversationBufferMemory(memory_key="chat_history")

    # Each gradio_tools wrapper exposes a LangChain-compatible tool via .langchain.
    toolbox = [
        tool_cls().langchain
        for tool_cls in (
            StableDiffusionTool,
            ImageCaptioningTool,
            StableDiffusionPromptGeneratorTool,
            TextToVideoTool,
        )
    ]

    agent = initialize_agent(
        toolbox,
        model,
        memory=chat_memory,
        agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    return [agent, True]


def diffuse(query, openai_api_key):
    """Ask the agent to turn *query* into an image and a video.

    The agent is prompted to:
      1. rewrite *query* as a Stable Diffusion prompt,
      2. generate an image from it,
      3. generate a video from that image,
      4. answer in the JSON format described by a ``StructuredOutputParser``.

    Args:
        query: Free-text description entered by the user.
        openai_api_key: Forwarded to ``initialize_llm``.

    Returns:
        ``[video_path, image_path]`` as parsed from the agent's final answer,
        or ``[None, None]`` if the agent could not be created. Two values are
        always returned because the Gradio button feeds two output components.
        Errors from the agent run or from parsing its answer propagate.
    """
    structured_prompt = """
    You are a program that converts a prompt to a suitable stable diffusion prompt. You generate an image with this prompt using stable diffusion. Then you generate a video of the generated image. 
    The prompt is {input_prompt}

    you have to structure it using {structure}
    """

    video_schema = ResponseSchema(name="video_path",
                                  description="Generate a video of the prompt and return its path.")

    image_schema = ResponseSchema(name="image_path",
                                  description="Generate an image of the prompt and return its path")

    output_parser = StructuredOutputParser.from_response_schemas(
        [
            video_schema,
            image_schema,
        ]
    )

    format_instructions = output_parser.get_format_instructions()
    structured_prompt_template = ChatPromptTemplate.from_template(structured_prompt)

    messages = structured_prompt_template.format_messages(
        input_prompt=query,
        structure=format_instructions,
    )
    # agent.run() expects a plain string. from_template builds a single human
    # message, so pass that message's text rather than the message list.
    prompt_text = messages[0].content

    agent, success = initialize_llm(openai_api_key)

    if not success:
        print("Something went wrong")
        # Keep the arity the two-output click handler expects.
        return [None, None]

    # agent.run() returns the final answer as a str; it has no .content.
    res = agent.run(prompt_text)
    output_dict = output_parser.parse(res)
    # Look values up by schema *name*; the ResponseSchema objects are not keys.
    vid_path = output_dict.get(video_schema.name)
    img_path = output_dict.get(image_schema.name)
    print(vid_path, img_path)
    return [vid_path, img_path]


# Gradio UI: takes a prompt and an OpenAI key, and shows the generated video and image.
with gr.Blocks() as demo:
    gr.Markdown(
        """
        My stable diffusion demo
        """
    )

    with gr.Row():
        openai_api_key = gr.Textbox(label="OpenAI API Key", type='password')

    with gr.Row():
        # Named prompt_input rather than `input` to avoid shadowing the builtin.
        prompt_input = gr.Textbox(label="Input")
    with gr.Row():
        result_video = gr.Video(label='Result', show_label=False, elem_id='gallery')
        result_image = gr.Image(label='Result', show_label=False)
    with gr.Row():
        generate_btn = gr.Button("Generate")
        # diffuse() returns [video_path, image_path]; outputs follow that order.
        generate_btn.click(fn=diffuse, inputs=[prompt_input, openai_api_key],
                           outputs=[result_video, result_image],
                           api_name="stableDiffusion")

# Launch only when run as a script, so importing this module has no server side effect.
if __name__ == "__main__":
    demo.launch()