Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
import datetime | |
import json | |
import os | |
import pathlib | |
import shutil | |
import tempfile | |
import uuid | |
from typing import Any | |
import gradio as gr | |
from gradio_client import Client | |
from scheduler import ZipScheduler | |
# --- Runtime configuration, read once from the environment at import time ---
HF_TOKEN = os.environ.get('HF_TOKEN')
UPLOAD_REPO_ID = os.environ.get('UPLOAD_REPO_ID')
# Forwarded to ZipScheduler as `every=`; the unit is defined by ZipScheduler
# (presumably minutes, as with huggingface_hub.CommitScheduler -- confirm).
UPLOAD_FREQUENCY = int(os.environ.get('UPLOAD_FREQUENCY', '5'))
# Only the exact string '1' makes the upload repo public; anything else
# (including an unset variable) keeps it private.
USE_PUBLIC_REPO = os.environ.get('USE_PUBLIC_REPO') == '1'
LOCAL_SAVE_DIR = pathlib.Path(os.environ.get('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Uploader for the per-preference folders written under LOCAL_SAVE_DIR.
# save_preference() writes into that folder while holding `scheduler.lock`.
scheduler = ZipScheduler(
    repo_id=UPLOAD_REPO_ID,
    repo_type='dataset',
    folder_path=LOCAL_SAVE_DIR,
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
)

# Remote Space used by generate() to produce the candidate images.
client = Client('stabilityai/stable-diffusion')
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate candidate images for *prompt* via the remote Space.

    The generation settings are written to a JSON file so that
    ``save_preference`` can later store them next to the user's choice.

    Args:
        prompt: Text prompt forwarded to the remote Space.

    Returns:
        A ``(config_path, image_paths)`` tuple. ``config_path`` is the path
        of the JSON file holding the generation settings. ``image_paths``
        are the keys of ``captions.json`` inside the directory returned by
        ``client.predict``.
    """
    negative_prompt = ''
    guidance_scale = 9
    out_dir = client.predict(prompt,
                             negative_prompt,
                             guidance_scale,
                             fn_index=1)

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False: the file must outlive this call because its path is sent
    # to the UI and read back later by save_preference. The `with` block
    # closes (and therefore flushes) the file before the path is returned,
    # instead of relying on garbage collection to close the handle.
    with tempfile.NamedTemporaryFile(mode='w',
                                     suffix='.json',
                                     delete=False,
                                     encoding='utf-8') as config_file:
        json.dump(config, config_file)

    with open(pathlib.Path(out_dir) / 'captions.json',
              encoding='utf-8') as f:
        paths = list(json.load(f))
    return config_file.name, paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index of the gallery item the user clicked.

    The event is wired with no explicit inputs, so Gradio supplies *evt*
    based on the ``gr.SelectData`` annotation.
    """
    clicked = evt.index
    return clicked
def save_preference(config_path: str, gallery: list[dict[str, Any]],
                    selected_index: int) -> None:
    """Persist a generated batch together with the user's selection.

    Creates a fresh ``LOCAL_SAVE_DIR/<uuid4>`` directory and moves every
    gallery image into it as ``000<ext>``, ``001<ext>``, ... in gallery
    order. It then writes ``preferences.json``, which holds the settings
    loaded from *config_path* plus ``selected_index`` and a UTC
    ``timestamp``.

    Args:
        config_path: Path of the JSON settings file produced by ``generate``.
        gallery: Gallery value; each item's ``'name'`` entry is an image path.
        selected_index: Index of the image the user picked, as set by
            ``get_selected_index`` (NOTE(review): may be None if nothing was
            clicked -- confirm how the UI handles that).
    """
    # Load the settings before touching any files. A bad config_path then
    # fails without leaving half-moved images behind.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    preferences = config | {
        'selected_index': selected_index,
        # Same naive-UTC ISO format that datetime.utcnow() produced, without
        # the API deprecated since Python 3.12.
        'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(
            tzinfo=None).isoformat(),
    }
    paths = [x['name'] for x in gallery]

    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    # Create the folder inside the lock as well. Presumably the scheduler
    # scans LOCAL_SAVE_DIR under this same lock; doing all writes here
    # means it never observes an empty or partially populated folder.
    with scheduler.lock:
        save_dir.mkdir(parents=True, exist_ok=True)
        for index, path in enumerate(paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')
        json_path = save_dir / 'preferences.json'
        with json_path.open('w', encoding='utf-8') as f:
            json.dump(preferences, f)
def clear() -> tuple[dict, dict, dict]:
    """Return the UI updates applied after a preference has been saved.

    The updates are, in output order: empty the hidden config-path field,
    empty the gallery, and disable the save button.
    """
    reset_config = gr.update(value=None)
    reset_gallery = gr.update(value=None)
    lock_button = gr.update(interactive=False)
    return reset_config, reset_gallery, lock_button
with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        # Hidden state: path of the settings JSON written by generate().
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(columns=2,
                                                     rows=2,
                                                     height='600px',
                                                     object_fit='scale-down')
        # Hidden state: index of the image the user last clicked.
        selected_index = gr.Number(visible=False, precision=0)
        save_preference_button = gr.Button('Save preference',
                                           interactive=False)

    # Generate a new batch. On success, enable saving and reset the previous
    # selection at the same time. Without the reset, an index clicked in the
    # previous batch would be saved against the new images if the user
    # pressed "Save preference" without clicking an image first.
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    ).success(
        fn=lambda: (gr.update(interactive=True), gr.update(value=None)),
        outputs=[save_preference_button, selected_index],
        queue=False,
    )
    # Remember which image the user clicked.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )
    # Persist the batch and selection, then reset the widgets via clear().
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )
demo.queue(concurrency_count=5).launch()