#!/usr/bin/env python
"""Gradio app that collects user preferences over images from a Stable Diffusion Space.

The user enters a prompt, the remote Space generates images that are shown in a
gallery, and when the user selects one and clicks "Save preference" the prompt,
generation settings, selected index, timestamp and image paths are appended to
a ``ParquetScheduler``, which uploads them to ``UPLOAD_REPO_ID``.
"""

import datetime
import json
import os
import pathlib
import tempfile
from typing import Any

import gradio as gr
from gradio_client import Client

from scheduler import ParquetScheduler

# Required settings: a missing variable raises KeyError at startup.
HF_TOKEN = os.environ["HF_TOKEN"]
UPLOAD_REPO_ID = os.environ["UPLOAD_REPO_ID"]
# Passed to ParquetScheduler as `every`; the unit is defined by the scheduler
# (presumably minutes, as with huggingface_hub.CommitScheduler -- confirm).
UPLOAD_FREQUENCY = int(os.getenv("UPLOAD_FREQUENCY", "15"))
# The dataset repo is private unless USE_PUBLIC_REPO is exactly "1".
USE_PUBLIC_REPO = os.getenv("USE_PUBLIC_REPO") == "1"

ABOUT_THIS_SPACE = """
This Space is a sample Space that collects user preferences for the results generated by a diffusion model.
This demo calls the [Stable Diffusion v1.5 Space](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) with the [`gradio_client`](https://pypi.org/project/gradio-client/) library.

The user preference data is periodically archived in parquet format and uploaded to [this dataset repo](https://huggingface.co/datasets/hysts-samples/sample-user-preferences).

The periodic upload is done using [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler).
See [this Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) for more general usage.
"""

scheduler = ParquetScheduler(
    repo_id=UPLOAD_REPO_ID,
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
)

# The original backend, "stabilityai/stable-diffusion", is paused, so the
# runwayml v1.5 Space is used instead.
client = Client("runwayml/stable-diffusion-v1-5")


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for ``prompt`` with the remote Space and record the settings.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        ``(config_path, image_paths)``: ``config_path`` is a temporary JSON file
        holding the prompt and generation settings (read back later by
        ``save_preference``); ``image_paths`` are the keys of ``captions.json``
        in the output directory returned by the Space, displayed in the gallery.
    """
    negative_prompt = ""
    guidance_scale = 9.0
    # NOTE: the current Space is called with the prompt only, so
    # `negative_prompt` and `guidance_scale` are NOT sent to it. They are still
    # recorded below, presumably to keep the saved rows' columns consistent.
    # With the paused "stabilityai/stable-diffusion" Space the call was:
    #   client.predict(prompt, negative_prompt, guidance_scale, fn_index=1)
    out_dir = client.predict(prompt, fn_index=1)

    config = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
    }
    # delete=False: the file must outlive this call because its path is kept
    # in a hidden textbox and reopened by `save_preference`.
    # NOTE(review): nothing in this file ever removes these temp files.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as config_file:
        json.dump(config, config_file)

    # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
    captions_path = pathlib.Path(out_dir) / "captions.json"
    with captions_path.open(encoding="utf-8") as f:
        paths = list(json.load(f).keys())
    return config_file.name, paths


def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index of the gallery item the user just selected."""
    return evt.index


def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    """Append one preference record to the upload scheduler.

    The record is the generation config loaded from ``config_path``, plus
    ``selected_index``, a UTC ``timestamp`` and one ``image_NNN`` entry per
    gallery item (the item's ``"name"`` value).

    Args:
        config_path: Path of the JSON config written by ``generate``.
        gallery: Current gallery value; each item is expected to carry a
            ``"name"`` key with the image path.
        selected_index: Index of the image the user preferred.
    """
    with open(config_path, encoding="utf-8") as f:
        data = json.load(f)

    data["selected_index"] = selected_index
    # Naive UTC ISO-8601 string -- the exact format `datetime.utcnow()` gave
    # (that call is deprecated since Python 3.12), so existing rows stay consistent.
    data["timestamp"] = (
        datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
    )

    for index, item in enumerate(gallery):
        data[f"image_{index:03d}"] = item["name"]

    scheduler.append(data)


def update_save_button(selected_index: int) -> dict:
    """Enable the save button only while an image is selected (index != -1)."""
    return gr.update(interactive=selected_index != -1)


def clear() -> tuple[dict, dict, dict]:
    """Reset the config path, gallery and selection after a preference is saved."""
    return (
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=-1),
    )


with gr.Blocks(css="style.css") as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        gallery = gr.Gallery(
            show_label=False,
            columns=2,
            rows=2,
            height="600px",
            object_fit="scale-down",
            allow_preview=False,
        )
    save_preference_button = gr.Button("Save preference", interactive=False)

    # Hidden state: the temp config path from `generate` and the selected
    # gallery index (-1 means "nothing selected").
    config_path = gr.Text(visible=False)
    selected_index = gr.Number(visible=False, precision=0, value=-1)

    with gr.Accordion(label="About this Space", open=False):
        gr.Markdown(ABOUT_THIS_SPACE)

    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
        api_name=False,
    )
    selected_index.change(
        fn=update_save_button,
        inputs=selected_index,
        outputs=save_preference_button,
        queue=False,
        api_name=False,
    )
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    # After saving, reset the UI; resetting selected_index to -1 also disables
    # the save button through the `selected_index.change` handler above.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    )

if __name__ == "__main__":
    demo.queue(api_open=False, concurrency_count=5).launch()