Spaces:
Runtime error
Runtime error
File size: 3,914 Bytes
6c2dda3 7252e54 6c2dda3 7252e54 6c2dda3 7252e54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
#!/usr/bin/env python
import datetime
import json
import os
import pathlib
import shutil
import tempfile
import uuid
from typing import Any
import gradio as gr
from gradio_client import Client
from scheduler import ZipScheduler
# --- Runtime configuration, all read from environment variables -------------
HF_TOKEN = os.environ.get('HF_TOKEN')
UPLOAD_REPO_ID = os.environ.get('UPLOAD_REPO_ID')
# Passed to the scheduler as `every`; defaults to 5 when the variable is unset.
UPLOAD_FREQUENCY = int(os.environ.get('UPLOAD_FREQUENCY', '5'))
# Only the exact string '1' makes the upload repo public.
USE_PUBLIC_REPO = os.environ.get('USE_PUBLIC_REPO') == '1'

# Local directory that collects saved preferences; created at import time.
LOCAL_SAVE_DIR = pathlib.Path(os.environ.get('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# ZipScheduler is defined in the project-local `scheduler` module.
# NOTE(review): presumably it periodically archives `folder_path` and uploads
# it to the dataset repo -- confirm against scheduler.ZipScheduler.
scheduler = ZipScheduler(
    repo_id=UPLOAD_REPO_ID,
    repo_type='dataset',
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
    folder_path=LOCAL_SAVE_DIR,
)

# Remote Gradio app that performs the actual image generation.
client = Client('stabilityai/stable-diffusion')
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* and record the settings that were used.

    Calls the remote ``client`` endpoint (``fn_index=1``) with an empty
    negative prompt and a guidance scale of 9, stores those settings in a
    temporary JSON file, and reads ``captions.json`` from the directory path
    returned by the remote call.

    Args:
        prompt: Text prompt forwarded to the remote endpoint.

    Returns:
        A ``(config_path, paths)`` tuple: the path of the temporary JSON file
        holding the generation settings, and the keys of ``captions.json``
        (presumably the generated image file paths -- confirm against the
        remote app's output format).
    """
    negative_prompt = ''
    guidance_scale = 9
    out_dir = client.predict(prompt,
                             negative_prompt,
                             guidance_scale,
                             fn_index=1)

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False: the file must outlive this call because only its path is
    # returned. The `with` block closes the handle deterministically so the
    # JSON is fully flushed to disk before anyone reads it back.
    with tempfile.NamedTemporaryFile(mode='w',
                                     suffix='.json',
                                     delete=False) as config_file:
        json.dump(config, config_file)

    captions_file = pathlib.Path(out_dir) / 'captions.json'
    with open(captions_file, encoding='utf-8') as f:
        paths = list(json.load(f).keys())
    return config_file.name, paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the ``index`` carried by a Gradio select event.

    Args:
        evt: Select event data supplied by Gradio to the listener.

    Returns:
        The value of ``evt.index``.
    """
    chosen = evt.index
    return chosen
def save_preference(config_path: str, gallery: list[dict[str, Any]],
                    selected_index: int) -> None:
    """Persist the generated images and the user's choice under a fresh folder.

    A new ``LOCAL_SAVE_DIR/<uuid4>`` directory receives every gallery image
    (moved and renamed ``000<ext>``, ``001<ext>``, ...) plus a
    ``preferences.json`` containing the generation config merged with the
    selected index and a UTC timestamp.

    Args:
        config_path: Path to the JSON file holding the generation settings.
        gallery: Gallery items; each must have a ``'name'`` key holding the
            image's file path.
        selected_index: Index of the image the user picked.
    """
    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    image_paths = [item['name'] for item in gallery]

    # Read the config before touching any files so an unreadable config
    # fails early instead of leaving a half-populated record behind.
    with open(config_path) as f:
        config = json.load(f)

    # Naive UTC timestamp in the same format utcnow().isoformat() produced;
    # datetime.utcnow() itself is deprecated since Python 3.12.
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    preferences = config | {
        'selected_index': selected_index,
        'timestamp': now_utc.replace(tzinfo=None).isoformat(),
    }

    # NOTE(review): scheduler.lock is presumably held by the uploader while it
    # archives LOCAL_SAVE_DIR -- confirm against scheduler.ZipScheduler. The
    # whole directory is populated under the lock so a partially written
    # record (images without preferences.json) is never visible to it.
    with scheduler.lock:
        save_dir.mkdir(parents=True, exist_ok=True)
        for index, path in enumerate(image_paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')
        json_path = save_dir / 'preferences.json'
        with json_path.open('w') as f:
            json.dump(preferences, f)
def clear() -> tuple[dict, dict, dict]:
    """Build the component updates that reset the UI.

    Returns:
        Three ``gr.update`` results, in order: two that set ``value`` to
        ``None`` and one that sets ``interactive`` to ``False``.
    """
    first_reset = gr.update(value=None)
    second_reset = gr.update(value=None)
    disable = gr.update(interactive=False)
    return first_reset, second_reset, disable
# UI layout and event wiring.
# NOTE(review): the nesting below is reconstructed; confirm which components
# belong inside gr.Group() and that the button sits outside it.
with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        # Hidden field carrying the config file path between callbacks.
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(
            columns=2,
            rows=2,
            height='600px',
            object_fit='scale-down',
        )
        # Hidden field holding the index of the image picked in the gallery.
        selected_index = gr.Number(visible=False, precision=0)
    save_preference_button = gr.Button('Save preference', interactive=False)

    # Submitting a prompt generates images; on success the save button is
    # made interactive.
    generated_event = prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    )
    generated_event.success(
        fn=lambda: gr.update(interactive=True),
        outputs=save_preference_button,
        queue=False,
    )

    # Selecting a gallery image records its index in the hidden number field.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )

    # Saving stores the preference, then resets the UI via clear().
    saved_event = save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    )
    saved_event.then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )

demo.queue(concurrency_count=5).launch()
|