hysts's picture
hysts HF Staff
Update
7252e54
raw
history blame
3.91 kB
#!/usr/bin/env python
import datetime
import json
import os
import pathlib
import shutil
import tempfile
import uuid
from typing import Any
import gradio as gr
from gradio_client import Client
from scheduler import ZipScheduler
# --- Upload / storage configuration, all read from environment variables ---
HF_TOKEN = os.getenv('HF_TOKEN')
UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '5'))
# Only the exact string '1' makes the dataset repo public.
USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
# Local folder where save_preference() writes results; created eagerly.
LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Pushes the contents of LOCAL_SAVE_DIR to the dataset repo. `every` is
# presumably minutes, as in huggingface_hub.CommitScheduler — confirm in
# scheduler.ZipScheduler.
scheduler = ZipScheduler(
    folder_path=LOCAL_SAVE_DIR,
    repo_id=UPLOAD_REPO_ID,
    repo_type='dataset',
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
    every=UPLOAD_FREQUENCY,
)

# Client for the remote Stable Diffusion Space that generate() calls.
client = Client('stabilityai/stable-diffusion')
def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* via the remote Stable Diffusion Space.

    The generation settings are written to a temporary JSON file so that
    ``save_preference`` can later store them next to the chosen images.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        A ``(config_path, image_paths)`` tuple: the path of the temporary JSON
        file holding the generation settings, and the keys of
        ``captions.json`` in the directory returned by the remote call, which
        are used as the gallery's image paths.
    """
    negative_prompt = ''
    guidance_scale = 9
    out_dir = client.predict(prompt,
                             negative_prompt,
                             guidance_scale,
                             fn_index=1)

    # Read the results before creating the temp file, so a failure here does
    # not leave an orphaned settings file behind.
    with open(pathlib.Path(out_dir) / 'captions.json') as f:
        paths = list(json.load(f).keys())

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False keeps the file after it is closed, so its path can be
    # handed to the UI. Closing it here (via `with`) guarantees the JSON is
    # flushed to disk before save_preference() reads it, instead of relying
    # on garbage collection.
    with tempfile.NamedTemporaryFile(mode='w',
                                     suffix='.json',
                                     delete=False) as config_file:
        json.dump(config, config_file)
    return config_file.name, paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Report which gallery item was clicked.

    Args:
        evt: Selection event that Gradio passes to ``select`` listeners.

    Returns:
        The event's ``index`` attribute. The UI stores it in the hidden
        ``selected_index`` field.
    """
    clicked = evt.index
    return clicked
def save_preference(config_path: str, gallery: list[dict[str, Any]],
                    selected_index: int) -> None:
    """Persist the generated images and the user's choice for upload.

    Reads the generation settings from *config_path*. Moves every gallery
    image into a fresh ``LOCAL_SAVE_DIR/<uuid4>`` directory, renaming them
    ``000<ext>``, ``001<ext>``, and so on. Writes ``preferences.json`` there
    with the settings plus the selected index and a UTC timestamp.

    Args:
        config_path: Path of the temporary JSON settings file written by
            ``generate``. It is deleted once its contents have been saved.
        gallery: Gallery value; each item's ``'name'`` key is the path of an
            image file, which is moved (not copied).
        selected_index: Index of the image the user picked. It may be
            ``None`` if nothing was selected — TODO confirm how gr.Number
            reports an empty value.
    """
    # Build the record first: if reading the settings fails, no images have
    # been moved yet, so no incomplete directory is left to be uploaded.
    with open(config_path) as f:
        config = json.load(f)
    preferences = config | {
        'selected_index': selected_index,
        # Equivalent to datetime.utcnow() (deprecated since Python 3.12);
        # tzinfo is dropped so the stored format stays naive-UTC ISO, as
        # in previously saved records.
        'timestamp': datetime.datetime.now(
            datetime.timezone.utc).replace(tzinfo=None).isoformat(),
    }

    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    paths = [x['name'] for x in gallery]
    # Every write into the watched folder, including creating the directory,
    # happens under the scheduler's lock. This presumably stops an upload
    # from seeing a half-written result — confirm ZipScheduler takes the
    # same lock.
    with scheduler.lock:
        save_dir.mkdir(parents=True, exist_ok=True)
        for index, path in enumerate(paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')
        json_path = save_dir / 'preferences.json'
        with json_path.open('w') as f:
            json.dump(preferences, f)

    # The temporary settings file is now captured in preferences.json.
    pathlib.Path(config_path).unlink(missing_ok=True)
def clear() -> tuple[dict, dict, dict]:
    """Build the component updates that reset the UI after a save.

    Returns:
        Three ``gr.update`` payloads, in output order: empty the hidden
        config-path textbox, empty the gallery, and disable the save button.
    """
    blank_config = gr.update(value=None)
    blank_gallery = gr.update(value=None)
    disabled_button = gr.update(interactive=False)
    return blank_config, blank_gallery, disabled_button
with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        # Hidden: path of the temp settings file returned by generate().
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(columns=2,
                                                     rows=2,
                                                     height='600px',
                                                     object_fit='scale-down')
        # Hidden: index of the gallery item the user last clicked.
        selected_index = gr.Number(visible=False, precision=0)
        save_preference_button = gr.Button('Save preference',
                                           interactive=False)

    # After a successful generation, enable saving. A fresh gallery also
    # invalidates any earlier click, so the hidden index is reset. Otherwise,
    # saving without a new click would record the previous round's choice.
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    ).success(
        fn=lambda: (gr.update(interactive=True), gr.update(value=None)),
        outputs=[save_preference_button, selected_index],
        queue=False,
    )
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )
    # Reset the UI only after a successful save. `.then` would also run after
    # an exception and wipe the gallery the user was trying to save.
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).success(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )
demo.queue(concurrency_count=5).launch()