Spaces:
Runtime error
Runtime error
from gradio_tools import StableDiffusionTool, ImageCaptioningTool, StableDiffusionPromptGeneratorTool, TextToVideoTool | |
from langchain import OpenAI | |
from langchain.agents import initialize_agent, AgentType | |
from langchain.memory import ConversationBufferMemory | |
from langchain.output_parsers import ResponseSchema, StructuredOutputParser | |
from langchain.prompts import ChatPromptTemplate | |
import gradio as gr | |
def initialize_llm(openai_api_key):
    """Create a LangChain agent equipped with the Gradio image/video tools.

    Args:
        openai_api_key: Key passed straight to the LangChain ``OpenAI`` LLM.

    Returns:
        A two-item list ``[agent, ok]``. ``[agent, True]`` on success;
        ``[None, False]`` if constructing the ``OpenAI`` LLM raised, in which
        case the exception is printed. Errors raised while building the
        memory, the tools or the agent itself are NOT caught here.
    """
    try:
        model = OpenAI(temperature=0, openai_api_key=openai_api_key)
    except Exception as err:
        print(err)
        return [None, False]

    # NOTE(review): a fresh memory is built on every call, so nothing carries
    # over between requests; also confirm whether this agent type actually
    # reads "chat_history" in its prompt.
    chat_memory = ConversationBufferMemory(memory_key="chat_history")

    # Instantiate each Gradio tool and take its LangChain adapter, in order.
    tool_classes = (
        StableDiffusionTool,
        ImageCaptioningTool,
        StableDiffusionPromptGeneratorTool,
        TextToVideoTool,
    )
    toolbox = [tool_cls().langchain for tool_cls in tool_classes]

    built_agent = initialize_agent(
        toolbox,
        model,
        memory=chat_memory,
        agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    return [built_agent, True]
def diffuse(query, openai_api_key):
    """Ask the LangChain agent to produce an image and a video for *query*.

    The agent is instructed to turn *query* into a Stable Diffusion prompt,
    render an image, then make a video of that image, and to answer in the
    structured format produced by a ``StructuredOutputParser`` with the keys
    ``video_path`` and ``image_path``.

    Args:
        query: Free-text prompt entered by the user.
        openai_api_key: Key forwarded to ``initialize_llm``.

    Returns:
        A two-item list ``[video_path, image_path]`` parsed from the agent's
        answer (an item is None if its key is absent from the parsed dict), or
        ``[None, None]`` if the agent could not be created. It always has two
        items, one per Gradio output component.

    Raises:
        Whatever ``agent.run`` or ``output_parser.parse`` raise. The parser
        presumably fails when the agent's answer lacks the expected JSON block.
    """
    structured_prompt = """
You are a program that converts a prompt to a suitable stable diffusion prompt. You generate an image with this prompt using stable diffusion. Then you generate a video of the generated image.
The prompt is {input_prompt}
you have to structure it using {structure}
"""
    # Named *_schema so they are not confused with the string keys of the
    # parsed output dict.
    video_schema = ResponseSchema(name="video_path",
                                  description="Generate a video of the prompt and return its path.")
    image_schema = ResponseSchema(name="image_path",
                                  description="Generate an image of the prompt and return its path")
    output_parser = StructuredOutputParser.from_response_schemas(
        [
            video_schema,
            image_schema,
        ]
    )
    format_instructions = output_parser.get_format_instructions()

    structured_prompt_template = ChatPromptTemplate.from_template(structured_prompt)
    my_message = structured_prompt_template.format_messages(
        input_prompt=query,
        structure=format_instructions,
    )

    agent, success = initialize_llm(openai_api_key)
    if not success:
        print("Something went wrong")
        # Gradio needs one value per output component (video, image).
        return [None, None]

    # from_template yields a single human message. agent.run takes a plain
    # string and returns a plain string, which has no ``.content`` attribute.
    res = agent.run(my_message[0].content)
    output_dict = output_parser.parse(res)
    # Look values up by schema *name*, which is the dict's string key.
    vid_path = output_dict.get("video_path")
    img_path = output_dict.get("image_path")
    print(vid_path, img_path)
    return [vid_path, img_path]
# Gradio UI: API key and prompt in, generated video and image out.
with gr.Blocks() as demo:
    gr.Markdown(
        """
My stable diffusion demo
"""
    )
    with gr.Row():
        openai_api_key = gr.Textbox(label="OpenAI API Key", type='password')
    with gr.Row():
        # Named prompt_box rather than ``input`` to avoid shadowing the builtin.
        prompt_box = gr.Textbox(label="Input")
    with gr.Row():
        result_video = gr.Video(label='Result', show_label=False, elem_id='gallery')
        result_image = gr.Image(label='Result', show_label=False)
    with gr.Row():
        generate_btn = gr.Button("Generate")
    # The event listener must be registered inside the Blocks context.
    # Input order matches diffuse(query, openai_api_key); output order matches
    # its [video_path, image_path] return value.
    generate_btn.click(fn=diffuse, inputs=[prompt_box, openai_api_key],
                       outputs=[result_video, result_image],
                       api_name="stableDiffusion")
demo.launch()