"""Gradio demo for Quality Diversity through Human Feedback (QDHF)."""

import io
import os
import tempfile

import gradio as gr

from qdhf_things import run_qdhf, many_pictures
from generate_examples import EXAMPLE_PROMPTS

# Get the absolute path to the examples directory
EXAMPLES_DIR = os.path.abspath("./examples")

# Prompt shown as the textbox placeholder and used when the user submits nothing.
DEFAULT_PROMPT = "a duck crossing the street"


def _reserve_temp_png(prefix):
    """Create an empty, uniquely named ``.png`` temp file and return its path.

    The file is intentionally not deleted here: the path is handed back to
    Gradio, which needs to read it after the callback has returned.
    """
    with tempfile.NamedTemporaryFile(
        prefix=prefix, suffix=".png", delete=False
    ) as handle:
        return handle.name


def generate_images(prompt, init_pop, total_itrs):
    """Run QDHF for ``prompt`` and return paths to the resulting PNG files.

    Args:
        prompt: Text prompt; ``None`` or a blank string falls back to
            ``DEFAULT_PROMPT``.
        init_pop: Initial population size (coerced with ``int``).
        total_itrs: Total number of iterations (coerced with ``int``).

    Returns:
        A ``(archive_plot_path, images_path)`` tuple of PNG file paths: the
        rendering of the last figure yielded by ``run_qdhf`` and the figure
        returned by ``many_pictures`` for the last yielded archive.

    Raises:
        gr.Error: If ``run_qdhf`` yields nothing, so there is no archive to
            render.
    """
    init_pop = int(init_pop)
    total_itrs = int(total_itrs)

    # Use placeholder if prompt is empty (or missing altogether).
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT

    archive = None
    final_archive_plot = None
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        # Render every yielded figure, but keep only the most recent PNG:
        # nothing downstream uses the intermediate frames.
        buf = io.BytesIO()
        plt_fig.savefig(buf, format="png")
        final_archive_plot = buf.getvalue()

    if final_archive_plot is None:
        raise gr.Error("QDHF produced no results; try increasing the number of iterations.")

    generated_images = many_pictures(archive, prompt)

    # Save the final archive plot and generated images as temporary files.
    # Unique names keep overlapping requests from overwriting each other's
    # output and avoid depending on a writable working directory.
    temp_archive_file = _reserve_temp_png("temp_archive_plot_")
    temp_images_file = _reserve_temp_png("temp_generated_images_")

    with open(temp_archive_file, "wb") as f:
        f.write(final_archive_plot)
    generated_images.savefig(temp_images_file)

    # NOTE(review): the first path is a PNG but is wired to `archive_output`,
    # a gr.Video component (the bundled examples supply .mp4 files there).
    # Confirm whether this output should be encoded as a video instead.
    return temp_archive_file, temp_images_file


def show_example(prompt):
    """Return the pre-generated example assets for one of ``EXAMPLE_PROMPTS``.

    Args:
        prompt: A prompt that must be present in ``EXAMPLE_PROMPTS``; its
            position there selects the files inside ``EXAMPLES_DIR``.

    Returns:
        A ``(prompt, archive_video_path, images_path)`` tuple pointing at
        ``archive_<index>.mp4`` and ``archive_pics_<index>.png``.

    Raises:
        ValueError: If ``prompt`` is not in ``EXAMPLE_PROMPTS`` (assuming it
            is a list-like sequence whose ``index`` raises on a miss).
    """
    index = EXAMPLE_PROMPTS.index(prompt)
    archive_plot_path = os.path.join(EXAMPLES_DIR, f"archive_{index}.mp4")
    images_path = os.path.join(EXAMPLES_DIR, f"archive_pics_{index}.png")
    return prompt, archive_plot_path, images_path


with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")
    gr.Markdown(
        "[Paper](https://arxiv.org/abs/2310.12103) | [Project Website](https://liding.info/qdhf/)"
    )

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(
                label="Enter your prompt here",
                placeholder=DEFAULT_PROMPT,
            )
            init_pop = gr.Slider(
                minimum=10,
                maximum=300,
                value=200,
                step=10,
                label="Initial Population",
            )
            total_itrs = gr.Slider(
                minimum=10,
                maximum=300,
                value=200,
                step=10,
                label="Total Iterations",
            )
            generate_button = gr.Button("Generate", variant="primary")

        # Right column: results.
        with gr.Column(scale=2):
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")

    generate_button.click(
        generate_images,
        inputs=[prompt_input, init_pop, total_itrs],
        outputs=[archive_output, images_output],
    )

    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        outputs=[prompt_input, archive_output, images_output],
        fn=show_example,
        cache_examples=True,
        label="Example Prompts",
    )

if __name__ == "__main__":
    demo.launch()