File size: 3,914 Bytes
6c2dda3
 
7252e54
 
 
 
 
 
 
 
 
6c2dda3
7252e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c2dda3
7252e54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python

import datetime
import json
import os
import pathlib
import shutil
import tempfile
import uuid
from typing import Any

import gradio as gr
from gradio_client import Client

from scheduler import ZipScheduler

# Runtime configuration, read from environment variables at import time.
# Token handed to ZipScheduler below (presumably a Hugging Face access token;
# None when unset).
HF_TOKEN = os.getenv('HF_TOKEN')
# Target repo for uploads; None when unset (NOTE(review): not validated here).
UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
# Passed to ZipScheduler as `every`; defaults to 5. The unit is not visible
# here — presumably minutes, TODO confirm against scheduler.py.
UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '5'))
# Only the exact string '1' makes the repo public; anything else (including
# unset) keeps it private.
USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
# Local staging directory that save_preference() writes into. It is created
# here, before the scheduler below is pointed at it.
LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Scheduler over LOCAL_SAVE_DIR targeting a dataset repo. Judging by its name
# and arguments it periodically zips and uploads the folder (TODO confirm in
# scheduler.py). save_preference() holds `scheduler.lock` while writing files.
scheduler = ZipScheduler(repo_id=UPLOAD_REPO_ID,
                         repo_type='dataset',
                         every=UPLOAD_FREQUENCY,
                         private=not USE_PUBLIC_REPO,
                         token=HF_TOKEN,
                         folder_path=LOCAL_SAVE_DIR)

# Client for the remote 'stabilityai/stable-diffusion' Space; generate()
# calls its predict() endpoint.
client = Client('stabilityai/stable-diffusion')


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* with the remote Stable Diffusion Space.

    The generation settings are written to a JSON file, so that
    ``save_preference`` can later store them next to the chosen images.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        A ``(config_path, image_paths)`` tuple. ``config_path`` is the path of
        the JSON file holding the settings. ``image_paths`` lists the
        generated images, as listed in the result's ``captions.json``.
    """
    negative_prompt = ''
    guidance_scale = 9
    # NOTE(review): the endpoint is pinned by position (fn_index=1), so this
    # breaks if the Space ever reorders its endpoints.
    out_dir = client.predict(prompt,
                             negative_prompt,
                             guidance_scale,
                             fn_index=1)

    # The result directory contains a captions.json whose keys are the
    # generated image paths. JSON text is UTF-8, so don't rely on the
    # platform default encoding.
    with open(pathlib.Path(out_dir) / 'captions.json',
              encoding='utf-8') as f:
        paths = list(json.load(f))

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False: the file must outlive this call because its path is
    # returned and read back by save_preference(). Closing it via `with`
    # guarantees the JSON is flushed to disk instead of relying on garbage
    # collection. It is written last so a failure above leaves no orphan file.
    with tempfile.NamedTemporaryFile(mode='w',
                                     suffix='.json',
                                     delete=False,
                                     encoding='utf-8') as config_file:
        json.dump(config, config_file)
        config_path = config_file.name
    return config_path, paths


def get_selected_index(evt: gr.SelectData) -> int:
    """Return the position of the gallery image the user just clicked.

    Used as the ``gallery.select`` handler. The handler declares no inputs,
    so Gradio supplies *evt* based on its ``gr.SelectData`` annotation.
    """
    clicked_position = evt.index
    return clicked_position


def save_preference(config_path: str, gallery: list[dict[str, Any]],
                    selected_index: int) -> None:
    """Store the generated images and the user's choice in LOCAL_SAVE_DIR.

    Every gallery image is moved into a fresh ``LOCAL_SAVE_DIR/<uuid4>``
    directory as ``000<ext>``, ``001<ext>``, and so on. A
    ``preferences.json`` is written alongside them, holding the generation
    settings from *config_path* plus ``selected_index`` and a UTC
    ``timestamp``.

    Args:
        config_path: Path of the JSON settings file written by ``generate``.
            It is deleted once its contents have been saved.
        gallery: Gallery value; each item's ``'name'`` is an image file path.
        selected_index: Index of the preferred image within *gallery*.
    """
    # The settings file lives outside LOCAL_SAVE_DIR, so it can be read
    # before taking the lock, which keeps the critical section short.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)

    paths = [item['name'] for item in gallery]
    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    # Create the directory under the scheduler's lock as well. Presumably the
    # scheduler holds this lock while processing LOCAL_SAVE_DIR, so it should
    # never observe an empty, half-populated entry.
    with scheduler.lock:
        save_dir.mkdir(parents=True, exist_ok=True)
        for index, path in enumerate(paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')

        preferences = config | {
            'selected_index': selected_index,
            # Same naive-UTC ISO format as datetime.utcnow().isoformat(),
            # which is deprecated since Python 3.12.
            'timestamp': datetime.datetime.now(
                datetime.timezone.utc).replace(tzinfo=None).isoformat(),
        }
        with (save_dir / 'preferences.json').open('w',
                                                  encoding='utf-8') as f:
            json.dump(preferences, f)

    # The settings are now stored in preferences.json. Remove the temp file
    # generate() created with delete=False so such files don't accumulate.
    pathlib.Path(config_path).unlink(missing_ok=True)


def clear() -> tuple[dict, dict, dict]:
    """Build the updates that reset the UI after a preference is saved.

    Returns:
        Updates for ``(config_path, gallery, save_preference_button)``: empty
        the hidden config path, empty the gallery, and disable the save
        button again.
    """
    reset_config = gr.update(value=None)
    reset_gallery = gr.update(value=None)
    disable_button = gr.update(interactive=False)
    return reset_config, reset_gallery, disable_button


# UI: enter a prompt, pick the preferred image from the generated batch, save.
with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        # Hidden state: path of the JSON settings file written by generate().
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(columns=2,
                                                     rows=2,
                                                     height='600px',
                                                     object_fit='scale-down')
        # Hidden state: index of the gallery image the user clicked.
        selected_index = gr.Number(visible=False, precision=0)
    save_preference_button = gr.Button('Save preference', interactive=False)

    # Once a new batch is shown, forget any selection made on the previous
    # batch. Otherwise that stale index would be saved against the new
    # images. Then allow saving.
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    ).success(
        fn=lambda: (gr.update(value=None), gr.update(interactive=True)),
        outputs=[selected_index, save_preference_button],
        queue=False,
    )

    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )
    # Reset the UI only if the save succeeded. A failed save keeps the
    # current batch on screen instead of silently discarding it.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).success(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )
demo.queue(concurrency_count=5).launch()