File size: 4,437 Bytes
6c2dda3
 
7252e54
 
 
 
 
 
 
6c2dda3
7252e54
 
e51b1cb
7252e54
e97aac1
 
 
 
7252e54
e97aac1
72ac7b4
 
 
e51b1cb
72ac7b4
 
 
e97aac1
72ac7b4
e97aac1
306753d
 
 
 
e97aac1
7252e54
5796dbb
 
7252e54
 
 
e97aac1
e820b78
5796dbb
 
7252e54
 
e97aac1
 
 
7252e54
e97aac1
e51b1cb
7252e54
e97aac1
7252e54
 
 
 
 
 
 
 
e97aac1
e51b1cb
 
 
 
 
e97aac1
 
e51b1cb
1bb0264
e97aac1
 
e51b1cb
 
 
7252e54
 
bb7103b
 
 
 
7252e54
 
 
 
bb7103b
7252e54
 
 
e97aac1
7252e54
e97aac1
 
306753d
 
 
 
 
 
e97aac1
 
7252e54
4578dbd
bb7103b
4578dbd
e97aac1
72ac7b4
dc8c538
7252e54
 
 
 
e6ddf0c
bb7103b
 
 
 
7252e54
 
e6ddf0c
7252e54
 
 
 
 
e6ddf0c
7252e54
 
 
 
 
e6ddf0c
7252e54
 
bb7103b
7252e54
e6ddf0c
7252e54
306753d
 
e6ddf0c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#!/usr/bin/env python

import datetime
import json
import os
import pathlib
import tempfile
from typing import Any

import gradio as gr
from gradio_client import Client

from scheduler import ParquetScheduler

# Required settings: indexing os.environ raises KeyError at import time if unset.
HF_TOKEN = os.environ["HF_TOKEN"]
UPLOAD_REPO_ID = os.environ["UPLOAD_REPO_ID"]
# Optional settings.
# Passed to the scheduler as ``every=``; defaults to 15.
UPLOAD_FREQUENCY = int(os.environ.get("UPLOAD_FREQUENCY", "15"))
# Only the exact string "1" enables it; the scheduler receives private=not USE_PUBLIC_REPO.
USE_PUBLIC_REPO = os.environ.get("USE_PUBLIC_REPO") == "1"

# Markdown shown in the "About this Space" accordion.
# The Space link must match the Space id passed to ``Client(...)`` below; it
# previously pointed at the paused "stabilityai/stable-diffusion" Space.
# NOTE(review): the dataset link is hard-coded and may not match UPLOAD_REPO_ID.
ABOUT_THIS_SPACE = """
This Space is a sample Space that collects user preferences for the results generated by a diffusion model.
This demo calls the [Stable Diffusion v1.5 Space](https://huggingface.co/spaces/runwayml/stable-diffusion-v1-5) with the [`gradio_client`](https://pypi.org/project/gradio-client/) library.

The user preference data is periodically archived in parquet format and uploaded to [this dataset repo](https://huggingface.co/datasets/hysts-samples/sample-user-preferences).

The periodic upload is done using [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler).
See [this Space](https://huggingface.co/spaces/Wauplin/space_to_dataset_saver) for more general usage.
"""

# Uploader that receives preference records via ``scheduler.append(...)``.
# NOTE(review): presumably ``every`` is in minutes, as with
# huggingface_hub.CommitScheduler; confirm against ParquetScheduler.
scheduler = ParquetScheduler(
    token=HF_TOKEN,
    repo_id=UPLOAD_REPO_ID,
    private=not USE_PUBLIC_REPO,
    every=UPLOAD_FREQUENCY,
)

# Remote Space used by `generate`. The original "stabilityai/stable-diffusion"
# Space is paused, so the Stable Diffusion v1.5 Space is called instead.
client = Client("runwayml/stable-diffusion-v1-5")


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for ``prompt`` via the remote Space and save the settings.

    The generation settings are written to a temporary JSON file so that
    `save_preference` can later merge them into a preference record.

    Args:
        prompt: Text prompt forwarded to the remote Space.

    Returns:
        ``(config_path, image_paths)``: the path of the temporary JSON settings
        file, and the keys of ``captions.json`` in the directory returned by
        ``client.predict`` (the generated image paths).
    """
    # NOTE(review): the current backend only receives the prompt, yet these two
    # values are still recorded in the saved config (the old
    # "stabilityai/stable-diffusion" Space accepted them as extra arguments).
    settings = {
        "prompt": prompt,
        "negative_prompt": "",
        "guidance_scale": 9.0,
    }
    result_dir = pathlib.Path(client.predict(prompt, fn_index=1))

    # delete=False keeps the file after closing so its path can be handed to
    # the UI. NOTE(review): nothing in this file ever removes these files.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
        json.dump(settings, tmp)
        config_path = tmp.name

    captions = json.loads((result_dir / "captions.json").read_text())
    image_paths = list(captions)
    return config_path, image_paths


def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index carried by a gallery ``select`` event."""
    clicked = evt.index
    return clicked


def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    """Build a preference record and hand it to the upload scheduler.

    Args:
        config_path: Path of the JSON settings file written by `generate`.
        gallery: Gallery items; each item's ``"name"`` entry is stored.
        selected_index: Index of the image the user picked.

    The record contains the settings from ``config_path``, the selected index,
    a naive UTC ISO-8601 timestamp, and one ``image_NNN`` key per gallery item.
    """
    with open(config_path) as f:
        data = json.load(f)

    data["selected_index"] = selected_index
    # ``datetime.utcnow()`` is deprecated since Python 3.12. Take an aware UTC
    # time instead and strip the tzinfo, so the stored string keeps the original
    # naive format (no "+00:00" suffix) and matches previously uploaded rows.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    data["timestamp"] = now_utc.isoformat()

    # Zero-padded keys keep the image columns in lexical order.
    for index, item in enumerate(gallery):
        data[f"image_{index:03d}"] = item["name"]

    scheduler.append(data)


def update_save_button(selected_index: int) -> dict:
    """Make the save button clickable only when ``selected_index`` is not -1."""
    has_selection = selected_index != -1
    return gr.update(interactive=has_selection)


def clear() -> tuple[dict, dict, dict]:
    """Return reset updates for (config path, gallery, selected index).

    The first two are cleared to ``None``; the index goes back to ``-1``.
    A separate update dict is created for each output.
    """
    return tuple(gr.update(value=reset) for reset in (None, None, -1))


with gr.Blocks(css="style.css") as demo:
    # Prompt box and result grid are grouped together.
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder="Prompt")
        gallery = gr.Gallery(
            show_label=False,
            columns=2,
            rows=2,
            height="600px",
            object_fit="scale-down",
            allow_preview=False,
        )
    # Starts disabled; `update_save_button` enables it once an image is selected.
    save_preference_button = gr.Button("Save preference", interactive=False)

    # Hidden state:
    # - config_path: path of the temporary JSON file returned by `generate`.
    # - selected_index: index of the clicked gallery image; -1 means no selection.
    config_path = gr.Text(visible=False)
    selected_index = gr.Number(visible=False, precision=0, value=-1)

    with gr.Accordion(label="About this Space", open=False):
        gr.Markdown(ABOUT_THIS_SPACE)

    # Reset the selection before generating. Otherwise an index picked on the
    # previous gallery survives into the new one: the save button stays enabled
    # and clicking it would record an image the user never chose.
    prompt.submit(
        fn=lambda: -1,
        outputs=selected_index,
        queue=False,
        api_name=False,
    ).then(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
        api_name=False,
    )
    # Toggle the save button whenever the selection changes.
    selected_index.change(
        fn=update_save_button,
        inputs=selected_index,
        outputs=save_preference_button,
        queue=False,
        api_name=False,
    )
    # Clicking a gallery image stores its index.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
        api_name=False,
    )
    # Save the preference, then reset config path, gallery and selection.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, selected_index],
        queue=False,
        api_name=False,
    )

if __name__ == "__main__":
    # Enable the request queue on the Blocks app, then start the server.
    app = demo.queue(api_open=False, concurrency_count=5)
    app.launch()