|
import io
import os
import tempfile

import gradio as gr

from generate_examples import EXAMPLE_PROMPTS
from qdhf_things import run_qdhf, many_pictures
|
|
|
|
|
# Directory holding the precomputed example outputs (archive_<i>.mp4 and
# archive_pics_<i>.png) that show_example points the UI at. Resolved relative
# to this file rather than the current working directory, so the examples are
# found no matter which directory the app is launched from.
EXAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "examples")
|
|
|
def generate_images(prompt, init_pop, total_itrs):
    """Run QDHF for ``prompt`` and save the resulting plots as PNG files.

    Args:
        prompt: Text prompt from the UI. ``None``, empty, or whitespace-only
            input falls back to ``"a duck crossing the street"``.
        init_pop: Initial population size; cast with ``int`` because slider
            values may arrive as floats.
        total_itrs: Number of QDHF iterations; cast with ``int``.

    Returns:
        ``(archive_plot_path, images_path)``: paths of two freshly created PNG
        files holding, respectively, the last archive figure yielded by
        ``run_qdhf`` and the figure returned by ``many_pictures``.

    Raises:
        gr.Error: If ``run_qdhf`` yields nothing, so there is no archive to
            plot or sample pictures from.
    """
    init_pop = int(init_pop)
    total_itrs = int(total_itrs)

    if prompt is None or not prompt.strip():
        prompt = "a duck crossing the street"

    archive = None
    final_archive_png = None
    for archive, plt_fig in run_qdhf(prompt, init_pop, total_itrs):
        # Render each figure as soon as it is yielded (in case the generator
        # reuses or alters it once resumed), but keep only the latest bytes:
        # earlier frames are never used.
        buf = io.BytesIO()
        plt_fig.savefig(buf, format="png")
        final_archive_png = buf.getvalue()

    if final_archive_png is None:
        raise gr.Error("QDHF run produced no archive; nothing to display.")

    generated_images = many_pictures(archive, prompt)

    # One unique file per request: fixed names in the working directory let
    # concurrent users overwrite each other's results. delete=False because
    # Gradio reads the files after this function returns.
    # NOTE(review): these files are not cleaned up here; confirm whether
    # Gradio's cache handling is sufficient or add periodic cleanup.
    with tempfile.NamedTemporaryFile(
        prefix="qdhf_archive_", suffix=".png", delete=False
    ) as archive_file:
        archive_file.write(final_archive_png)

    with tempfile.NamedTemporaryFile(
        prefix="qdhf_images_", suffix=".png", delete=False
    ) as images_file:
        generated_images.savefig(images_file, format="png")

    # NOTE(review): the first value feeds a gr.Video component in the UI while
    # being a PNG; confirm it renders. Figures are also not closed here — if
    # they are pyplot-managed they will stay in memory across requests.
    return archive_file.name, images_file.name
|
|
|
def show_example(prompt):
    """Map a built-in example prompt to its precomputed output files.

    ``gr.Examples`` calls this with an entry of ``EXAMPLE_PROMPTS``; that
    entry's position in the list selects which pre-rendered files under
    ``EXAMPLES_DIR`` are shown.

    Args:
        prompt: One of the strings in ``EXAMPLE_PROMPTS``.

    Returns:
        ``(prompt, archive_video_path, pictures_path)`` where the paths are
        ``EXAMPLES_DIR/archive_<i>.mp4`` and ``EXAMPLES_DIR/archive_pics_<i>.png``.

    Raises:
        ValueError: If ``prompt`` is not in ``EXAMPLE_PROMPTS``.
    """
    index = EXAMPLE_PROMPTS.index(prompt)
    archive_video, archive_pics = (
        os.path.join(EXAMPLES_DIR, file_name)
        for file_name in (f"archive_{index}.mp4", f"archive_pics_{index}.png")
    )
    return prompt, archive_video, archive_pics
|
|
|
# Gradio UI: a title, an input column (prompt + QDHF parameters + Generate
# button), an output column (archive plot + generated pictures), and a set of
# example prompts with precomputed results.
with gr.Blocks() as demo:
    gr.Markdown("# Quality Diversity through Human Feedback")
    gr.Markdown("[Paper](https://arxiv.org/abs/2310.12103) | [Project Website](https://liding.info/qdhf/)")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=1):
            prompt_input = gr.Textbox(label="Enter your prompt here", placeholder="a duck crossing the street")
            # Both slider values go to generate_images, which casts them to int.
            init_pop = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Initial Population")
            total_itrs = gr.Slider(minimum=10, maximum=300, value=200, step=10, label="Total Iterations")
            generate_button = gr.Button("Generate", variant="primary")

        # Right column (scale=2, i.e. twice the left column's width): outputs.
        with gr.Column(scale=2):
            # NOTE(review): generate_images returns a PNG path for this
            # component, whereas the cached examples supply .mp4 files.
            # Confirm a gr.Video displays the PNG, or emit a video instead.
            archive_output = gr.Video(label="Archive Plot")
            images_output = gr.Image(label="Generated Pictures")

    # Generate runs the full QDHF pipeline on the current inputs.
    generate_button.click(generate_images,
                          inputs=[prompt_input, init_pop, total_itrs],
                          outputs=[archive_output, images_output])

    # Example prompts: selecting one fills the prompt box and shows the
    # precomputed files returned by show_example. cache_examples=True has
    # Gradio evaluate show_example for each example ahead of time.
    gr.Examples(
        examples=EXAMPLE_PROMPTS,
        inputs=prompt_input,
        outputs=[prompt_input, archive_output, images_output],
        fn=show_example,
        cache_examples=True,
        label="Example Prompts"
    )
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly, so
    # importing the module does not launch it.
    # NOTE(review): generate_images runs a long QDHF loop; on older Gradio
    # versions without default queuing, demo.queue() may be needed to avoid
    # request timeouts — confirm the Gradio version in use.
    demo.launch()