#!/usr/bin/env python

import datetime
import json
import os
import pathlib
import tempfile
from typing import Any

import gradio as gr
from gradio_client import Client

from scheduler import ParquetScheduler

HF_TOKEN = os.environ["HF_TOKEN"]
UPLOAD_REPO_ID = os.environ["UPLOAD_REPO_ID"]
UPLOAD_FREQUENCY = int(os.getenv("UPLOAD_FREQUENCY", "15"))
USE_PUBLIC_REPO = os.getenv("USE_PUBLIC_REPO") == "1"

ABOUT_THIS_SPACE = """
This is a sample Space that collects user preferences for images generated by a diffusion model.
This demo calls the [`runwayml/stable-diffusion-v1-5` Space](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) with the [`gradio_client`](https://pypi.org/project/gradio-client/) library (the [stable diffusion Space](https://huggingface.co/spaces/stabilityai/stable-diffusion) used originally is currently paused).
The user preference data is periodically archived in parquet format and uploaded to [this dataset repo](https://huggingface.co/datasets/hysts-samples/sample-user-preferences).
The periodic upload is done using [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler).
See [this Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) for a more general usage example.
"""

scheduler = ParquetScheduler(
    repo_id=UPLOAD_REPO_ID,
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
)
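
# `ParquetScheduler` (defined in scheduler.py, not shown here) is assumed to be a
# thin wrapper around `huggingface_hub.CommitScheduler`: each dict passed to its
# `append()` method is buffered locally, and the buffer is committed to the
# dataset repo as a parquet file roughly every `every` minutes.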

# client = Client("stabilityai/stable-diffusion")  # Space is paused
client = Client("runwayml/stable-diffusion-v1-5")


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images with the upstream Space and return the config file path and the image paths."""
    negative_prompt = ""
    guidance_scale = 9.0
    # `fn_index=1` targets the generation endpoint of the upstream Space.
    # out_dir = client.predict(prompt, negative_prompt, guidance_scale, fn_index=1)  # Space 'stabilityai/stable-diffusion' is paused
    out_dir = client.predict(prompt, fn_index=1)

    # Store the generation config in a temporary JSON file so that it can be
    # attached to the user's preference later.
    config = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
    }
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as config_file:
        json.dump(config, config_file)

    # The upstream Space returns a directory whose captions.json maps each
    # generated image path to its caption; only the paths are needed here.
    with (pathlib.Path(out_dir) / "captions.json").open() as f:
        paths = list(json.load(f).keys())
    return config_file.name, paths


def get_selected_index(evt: gr.SelectData) -> int:
    return evt.index


def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    # Load config
    with open(config_path) as f:
        data = json.load(f)

    # Add selected item + timestamp
    data["selected_index"] = selected_index
    data["timestamp"] = datetime.datetime.utcnow().isoformat()

    # Add images
    for index, path in enumerate(x["name"] for x in gallery):
        data[f"image_{index:03d}"] = path

    # Send to scheduler
    scheduler.append(data)
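
# Each appended row is a flat dict; the keys follow the code above, the values
# below are purely illustrative:
# {
#     "prompt": "an astronaut riding a horse",
#     "negative_prompt": "",
#     "guidance_scale": 9.0,
#     "selected_index": 2,
#     "timestamp": "2023-08-01T12:34:56.789012",
#     "image_000": "/path/to/image_0.png",
#     ...
# }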


def update_save_button(selected_index: int) -> dict:
    # Enable the save button only when an image has been selected.
    return gr.update(interactive=selected_index != -1)


def clear() -> tuple[dict, dict, dict]:
    # Reset the config path, the gallery, and the selected index.
    return (
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=-1),
    )


with gr.Blocks(css="style.css") as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        gallery = gr.Gallery(
            show_label=False,
            columns=2,
            rows=2,
            height="600px",
            object_fit="scale-down",
            allow_preview=False,
        )
    save_preference_button = gr.Button("Save preference", interactive=False)

    # Hidden state holding the config file path and the selected image index.
    config_path = gr.Text(visible=False)
    selected_index = gr.Number(visible=False, precision=0, value=-1)

    with gr.Accordion(label="About this Space", open=False):
        gr.Markdown(ABOUT_THIS_SPACE)

    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
        api_name=False,
    )
    selected_index.change(
        fn=update_save_button,
        inputs=selected_index,
        outputs=save_preference_button,
        queue=False,
        api_name=False,
    )
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    )


if __name__ == "__main__":
    demo.queue(api_open=False, concurrency_count=5).launch()
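
# To run this app locally, something like the following is expected
# (the token and repo id below are placeholders):
#   HF_TOKEN=hf_xxx UPLOAD_REPO_ID=your-username/your-dataset python app.py
# UPLOAD_FREQUENCY and USE_PUBLIC_REPO are optional (see the defaults above).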