# Author: Gauthierdebo
# Commit c904a51: "Update app.py"
from gradio_tools import StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool, TextToVideoTool
from langchain import OpenAI
from langchain.agents import initialize_agent, AgentType
from langchain.memory import ConversationBufferMemory
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import ChatPromptTemplate
import gradio as gr
def initialize_llm(openai_api_key):
    """Build a LangChain agent that can call the gradio_tools image/video tools.

    The agent uses an OpenAI LLM (temperature 0) and four tools:
    StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool
    and TextToVideoTool.

    Args:
        openai_api_key: OpenAI API key passed to the ``OpenAI`` LLM wrapper.

    Returns:
        ``[agent, True]`` on success, or ``[None, False]`` if constructing the
        LLM, any tool, or the agent raised (the exception is printed).
    """
    try:
        llm = OpenAI(temperature=0, openai_api_key=openai_api_key)
        # NOTE(review): a fresh memory is created on every call, so nothing
        # carries over between requests; also confirm that this agent type's
        # prompt actually reads "chat_history".
        memory = ConversationBufferMemory(memory_key="chat_history")
        # Tool construction is kept inside the try. These wrappers presumably
        # connect to remote Gradio apps and can fail, and such failures should
        # map to [None, False] like an LLM failure instead of escaping.
        tools = [
            StableDiffusionTool().langchain,
            ImageCaptioningTool().langchain,
            StableDiffusionPromptGeneratorTool().langchain,
            TextToVideoTool().langchain,
        ]
        agent = initialize_agent(tools, llm, memory=memory,
                                 agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
                                 verbose=True)
    except Exception as e:
        print(e)
        return [None, False]
    return [agent, True]
def diffuse(query, openai_api_key):
    """Ask the agent to turn *query* into an image and a video, and return their paths.

    The agent is instructed to answer in the structured format described by a
    ``StructuredOutputParser`` with the keys ``video_path`` and ``image_path``.

    Args:
        query: Free-text prompt from the user.
        openai_api_key: OpenAI API key forwarded to ``initialize_llm``.

    Returns:
        ``[video_path, image_path]``. A key missing from the agent's answer
        yields ``None`` in its slot. On any failure (agent could not be built,
        agent run or output parsing raised ``ValueError``) the result is
        ``[None, None]``, so the two Gradio outputs always receive a value.
    """
    structured_prompt = """
You are a program that converts a prompt to a suitable stable diffusion prompt. You generate an image with this prompt using stable diffusion. Then you generate a video of the generated image.
The prompt is {input_prompt}
you have to structure it using {structure}
"""
    video_path = ResponseSchema(name="video_path",
                                description="Generate a video of the prompt and return its path.")
    image_path = ResponseSchema(name="image_path",
                                description="Generate an image of the prompt and return its path")
    output_parser = StructuredOutputParser.from_response_schemas([video_path, image_path])
    format_instructions = output_parser.get_format_instructions()
    structured_prompt_template = ChatPromptTemplate.from_template(structured_prompt)
    my_message = structured_prompt_template.format_messages(
        input_prompt=query,
        structure=format_instructions,
    )
    # from_template yields a single human message; agent.run() takes a plain
    # string input, not a list of message objects.
    prompt_text = my_message[0].content

    agent, success = initialize_llm(openai_api_key)
    if not success:
        print("Something went wrong")
        # One value per Gradio output component (video, image).
        return [None, None]

    try:
        # agent.run returns the final answer as a str (it has no .content).
        res = agent.run(prompt_text)
        output_dict = output_parser.parse(res)
    except ValueError as e:
        # langchain's OutputParserException subclasses ValueError; it is raised
        # when the LLM/agent output does not match the expected format.
        print(e)
        return [None, None]

    # The parsed dict is keyed by the schema *names*, not the schema objects.
    vid_path = output_dict.get(video_path.name)
    img_path = output_dict.get(image_path.name)
    print(vid_path, img_path)
    return [vid_path, img_path]
# Gradio UI: an API-key box and a prompt box feed diffuse(); its
# [video_path, image_path] result fills the video and image components.
with gr.Blocks() as demo:
    gr.Markdown(
        """
My stable diffusion demo
"""
    )
    with gr.Row():
        api_key_box = gr.Textbox(label="OpenAI API Key", type='password')
    with gr.Row():
        # Named prompt_box rather than `input` to avoid shadowing the builtin.
        prompt_box = gr.Textbox(label="Input")
    with gr.Row():
        result_video = gr.Video(label='Result', show_label=False, elem_id='gallery')
        result_image = gr.Image(label='Result', show_label=False)
    with gr.Row():
        generate_btn = gr.Button("Generate")
    # `inputs` order must match diffuse(query, openai_api_key); `outputs` order
    # must match its [video_path, image_path] return value.
    generate_btn.click(fn=diffuse, inputs=[prompt_box, api_key_box],
                       outputs=[result_video, result_image],
                       api_name="stableDiffusion")
demo.launch()