Spaces:
Runtime error
Runtime error
File size: 4,437 Bytes
6c2dda3 7252e54 6c2dda3 7252e54 e51b1cb 7252e54 e97aac1 7252e54 e97aac1 72ac7b4 e51b1cb 72ac7b4 e97aac1 72ac7b4 e97aac1 306753d e97aac1 7252e54 5796dbb 7252e54 e97aac1 e820b78 5796dbb 7252e54 e97aac1 7252e54 e97aac1 e51b1cb 7252e54 e97aac1 7252e54 e97aac1 e51b1cb e97aac1 e51b1cb 1bb0264 e97aac1 e51b1cb 7252e54 bb7103b 7252e54 bb7103b 7252e54 e97aac1 7252e54 e97aac1 306753d e97aac1 7252e54 4578dbd bb7103b 4578dbd e97aac1 72ac7b4 dc8c538 7252e54 e6ddf0c bb7103b 7252e54 e6ddf0c 7252e54 e6ddf0c 7252e54 e6ddf0c 7252e54 bb7103b 7252e54 e6ddf0c 7252e54 306753d e6ddf0c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
#!/usr/bin/env python
import datetime
import json
import os
import pathlib
import tempfile
from typing import Any
import gradio as gr
from gradio_client import Client
from scheduler import ParquetScheduler
# Required settings: indexing os.environ raises KeyError at import time when missing.
HF_TOKEN = os.environ["HF_TOKEN"]
UPLOAD_REPO_ID = os.environ["UPLOAD_REPO_ID"]
# Optional settings.
# Passed to the scheduler as `every` (default "15").
UPLOAD_FREQUENCY = int(os.environ.get("UPLOAD_FREQUENCY", "15"))
# The dataset repo is created private unless USE_PUBLIC_REPO is exactly "1".
USE_PUBLIC_REPO = os.environ.get("USE_PUBLIC_REPO") == "1"
# Markdown rendered in the "About this Space" accordion.
# The linked generation Space must match the one `client` connects to.
# The dataset link is built from UPLOAD_REPO_ID, so it always names the repo
# the scheduler uploads to (also when this Space is duplicated).
ABOUT_THIS_SPACE = f"""
This Space is a sample Space that collects user preferences for the results generated by a diffusion model.
This demo calls the [Stable Diffusion v1.5 Space](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) with the [`gradio_client`](https://pypi.org/project/gradio-client/) library.
The user preference data is periodically archived in parquet format and uploaded to [this dataset repo](https://huggingface.co/datasets/{UPLOAD_REPO_ID}).
The periodic upload is done using [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler).
See [this Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) for more general usage.
"""
# Uploader for the collected preference rows (see `scheduler.append` in
# save_preference). `every` is the upload interval taken from UPLOAD_FREQUENCY
# — presumably minutes, as in huggingface_hub.CommitScheduler; confirm in scheduler.py.
scheduler = ParquetScheduler(
    token=HF_TOKEN,
    private=not USE_PUBLIC_REPO,
    every=UPLOAD_FREQUENCY,
    repo_id=UPLOAD_REPO_ID,
)

# Remote Space that generates the images. The previously used
# "stabilityai/stable-diffusion" Space is paused, so v1.5 is used instead.
client = Client("runwayml/stable-diffusion-v1-5")
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for ``prompt`` with the remote Space and record the settings.

    Args:
        prompt: Text prompt forwarded to the remote Space.

    Returns:
        A pair ``(config_path, paths)``:

        - ``config_path`` is the path of a temporary JSON file holding the
          prompt and generation settings.
        - ``paths`` holds the keys of ``captions.json`` in the directory
          returned by the client (presumably the generated image file paths).
    """
    negative_prompt = ""
    guidance_scale = 9.0
    # NOTE(review): only `prompt` is sent to the Space. The negative_prompt /
    # guidance_scale recorded below were arguments of the paused
    # 'stabilityai/stable-diffusion' Space, which was called as
    # client.predict(prompt, negative_prompt, guidance_scale, fn_index=1).
    # They may not reflect what the current Space actually used.
    out_dir = client.predict(prompt, fn_index=1)

    config = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
    }
    # delete=False: the file must outlive this call, because its path is
    # later passed to save_preference, which opens and reads it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False, encoding="utf-8"
    ) as config_file:
        json.dump(config, config_file)

    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with (pathlib.Path(out_dir) / "captions.json").open(encoding="utf-8") as f:
        paths = list(json.load(f))
    return config_file.name, paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index carried by the gallery select event."""
    chosen = evt.index
    return chosen
def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    """Build one preference record and append it to the upload scheduler.

    The record contains:

    - the generation settings loaded from ``config_path``;
    - the selected index;
    - a UTC timestamp;
    - one ``image_NNN`` entry per gallery item.

    Args:
        config_path: Path of the JSON settings file written by ``generate``.
        gallery: Current gallery value. Each item is read through its
            ``"name"`` key — assumed to hold the image path; confirm against
            the installed Gradio version.
        selected_index: Index of the image the user selected.
    """
    # generate() writes this file as JSON, which is UTF-8 by spec.
    with open(config_path, encoding="utf-8") as f:
        data = json.load(f)

    data["selected_index"] = selected_index
    # Naive UTC ISO-8601 string. This gives exactly the same output as
    # datetime.datetime.utcnow().isoformat(), which is deprecated since Python 3.12.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    data["timestamp"] = now_utc.isoformat()

    for index, item in enumerate(gallery):
        data[f"image_{index:03d}"] = item["name"]

    scheduler.append(data)
def update_save_button(selected_index: int) -> dict:
    """Make the save button interactive only when an image is selected.

    ``-1`` is the "nothing selected" value used by the hidden index field.
    """
    has_selection = not selected_index == -1
    return gr.update(interactive=has_selection)
def clear() -> tuple[dict, dict, dict]:
    """Reset the config path, gallery and selected index to their initial values.

    Returns one separate update dict per output: ``None`` for the config path
    and the gallery, and ``-1`` ("nothing selected") for the index.
    """
    return tuple(gr.update(value=reset_value) for reset_value in (None, None, -1))
# UI layout and event wiring.
with gr.Blocks(css="style.css") as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        gallery = gr.Gallery(
            show_label=False,
            columns=2,
            rows=2,
            height="600px",
            object_fit="scale-down",
            allow_preview=False,
        )
    save_preference_button = gr.Button("Save preference", interactive=False)
    # Hidden state: path of the settings JSON from generate(), and the
    # selected gallery index (-1 means nothing is selected).
    config_path = gr.Text(visible=False)
    selected_index = gr.Number(visible=False, precision=0, value=-1)
    with gr.Accordion(label="About this Space", open=False):
        gr.Markdown(ABOUT_THIS_SPACE)

    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
        api_name=False,
    )
    # Enable or disable the save button whenever the selection changes.
    selected_index.change(
        fn=update_save_button,
        inputs=selected_index,
        outputs=save_preference_button,
        queue=False,
        api_name=False,
    )
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    # `.success` rather than `.then`: clear the UI only if the save succeeded.
    # With `.then`, a failing save_preference would still wipe the gallery and
    # selection, losing the user's choice; this way they remain for a retry.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    ).success(
        fn=clear,
        outputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    )
if __name__ == "__main__":
    # Enable the request queue (concurrency_count=5, API page hidden), then serve.
    queued_demo = demo.queue(concurrency_count=5, api_open=False)
    queued_demo.launch()
|