# Author: jennyzzt
# Revision: eabf1a3 ("update")
import io
import os
import tempfile

import gradio as gr

from qdhf_things import run_qdhf, many_pictures
from generate_examples import EXAMPLE_PROMPTS
# Absolute path to the bundled examples directory. Resolved relative to this
# file rather than the current working directory, so the pre-rendered example
# assets are found regardless of where the app is launched from.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
def generate_images(prompt, init_pop, total_itrs):
    """Run QDHF for a prompt and save the final archive plot and picture grid.

    Args:
        prompt: Text prompt from the UI. Empty, whitespace-only or ``None``
            falls back to the textbox's placeholder prompt.
        init_pop: Initial population size; coerced with ``int()`` since
            slider values may arrive as floats.
        total_itrs: Total number of QDHF iterations; coerced with ``int()``.

    Returns:
        A ``(archive_plot_path, images_path)`` tuple: paths to two newly
        created, uniquely named PNG files holding the last archive plot
        yielded by ``run_qdhf`` and the figure returned by ``many_pictures``.
        The files are not deleted here; they are handed to Gradio to serve.

    Raises:
        RuntimeError: If ``run_qdhf`` yields nothing, so there is no archive
            to plot.

    NOTE(review): the UI binds the first output to a ``gr.Video`` while this
    returns a PNG (the cached examples return an ``.mp4``). Confirm Gradio
    renders it, or switch the component / encode the archive frames as video.
    """
    init_pop = int(init_pop)
    total_itrs = int(total_itrs)
    # Use placeholder if prompt is empty (or missing altogether).
    if not prompt or not prompt.strip():
        prompt = "a duck crossing the street"

    archive = None
    final_archive_plot = None
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        # Render each figure while it is current (the generator may redraw or
        # reuse it once resumed), but keep only the latest PNG: only the
        # final plot is returned, so holding every frame wasted memory.
        buf = io.BytesIO()
        plt_fig.savefig(buf, format='png')
        final_archive_plot = buf.getvalue()
    if final_archive_plot is None:
        raise RuntimeError("run_qdhf did not yield any archive to plot")

    generated_images = many_pictures(archive, prompt)

    # Unique temp files instead of fixed names in the CWD: with fixed names,
    # concurrent requests overwrite (or read half-written) each other's output.
    fd, temp_archive_file = tempfile.mkstemp(prefix="qdhf_archive_", suffix=".png")
    with os.fdopen(fd, 'wb') as f:
        f.write(final_archive_plot)
    fd, temp_images_file = tempfile.mkstemp(prefix="qdhf_images_", suffix=".png")
    # Release our handle so savefig can reopen the path (required on Windows).
    os.close(fd)
    generated_images.savefig(temp_images_file)
    return temp_archive_file, temp_images_file
def show_example(prompt):
    """Return the pre-rendered outputs for one of the bundled example prompts.

    The prompt's position in ``EXAMPLE_PROMPTS`` selects the matching files
    inside ``EXAMPLES_DIR``.

    Returns:
        ``(prompt, video_path, pictures_path)``, where the paths are
        ``EXAMPLES_DIR/archive_<i>.mp4`` and ``EXAMPLES_DIR/archive_pics_<i>.png``
        for example number ``i``. File existence is not checked here.

    Raises:
        ValueError: If *prompt* is not in ``EXAMPLE_PROMPTS`` (raised by
            ``.index``; assumes it is a list or tuple).
    """
    idx = EXAMPLE_PROMPTS.index(prompt)
    video_path, pictures_path = (
        os.path.join(EXAMPLES_DIR, file_name)
        for file_name in (f"archive_{idx}.mp4", f"archive_pics_{idx}.png")
    )
    return prompt, video_path, pictures_path
# Build the Gradio UI: a header, an input column, an output column, and a set
# of example prompts with pre-rendered results.
with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")
    gr.Markdown("[Paper](https://arxiv.org/abs/2310.12103) | [Project Website](https://liding.info/qdhf/)")
    with gr.Row():
        # Left column: prompt and run parameters.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="a duck crossing the street")
            # Both sliders span 10-300 in steps of 10 and default to 200.
            init_pop = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Initial Population")
            total_itrs = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Total Iterations")
            generate_button = gr.Button("Generate", variant="primary")
        # Right column (relative width 2 vs. 1): results.
        with gr.Column(scale=2):
            # NOTE(review): show_example feeds this component an .mp4, but
            # generate_images returns a .png path. Confirm gr.Video displays
            # it, or use separate components / encode the frames as video.
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")
    # Run QDHF on demand with the current prompt and slider values.
    generate_button.click(generate_images,
                          inputs=[prompt_input, init_pop, total_itrs],
                          outputs=[archive_output, images_output])
    # Selecting an example fills the prompt box and shows that example's
    # stored outputs; cache_examples=True has Gradio precompute
    # show_example's results instead of calling it on every click.
    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        outputs=[prompt_input, archive_output, images_output],
        fn=show_example,
        cache_examples=True,
        label="Example Prompts"
    )

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()