Spaces:
Runtime error
Runtime error
File size: 4,740 Bytes
9a5a85e 9a6a181 9a5a85e 9a6a181 9a5a85e 9a6a181 9a5a85e 9a6a181 9a5a85e 9a6a181 9a5a85e 9a6a181 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import json
import os
import subprocess
import threading
from datetime import datetime
from pathlib import Path

import gradio as gr
import uvicorn
from fastapi import FastAPI
from huggingface_hub import Repository

from db import Database
# Hub token from the environment (e.g. a Space secret); may be None.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): not referenced anywhere in this file; kept in case other
# modules import it.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the dataset repo that holds the annotation database.
DB_FOLDER = Path("diffusers-gallery-data")
# Prefix prepended to the relative image paths stored in each model's data.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# Set up the local clone of the dataset repo and pull its latest state before
# the database is opened.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    # HF_TOKEN was previously read but never used; depending on the
    # huggingface_hub version, `True` alone may not pick up the HF_TOKEN env
    # var. Pass it explicitly and fall back to the cached login otherwise.
    use_auth_token=HF_TOKEN or True,
)
repo.git_pull()

database = Database(DB_FOLDER)

# Choices offered in the UI; the selected values are written into the
# `flags` JSON column when a model is submitted.
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]
# Client-side JS passed as `_js` to `blocks.load`, with inputs
# (query_params, styles, nsfw). It copies the `model_id` URL query parameter
# (or "" if absent) onto the first input, resets the address bar to "/"
# (dropping the query string), and returns the inputs, which presumably are
# then handed to `next_model`.
# NOTE(review): the first JS parameter is named `current_model`, but it
# receives the hidden query-params JSON value, not the current-model state.
js_get_url_params = """
function (current_model, styles, nsfw) {
const params = new URLSearchParams(window.location.search);
current_model.model_id = params.get("model_id") || "";
window.history.replaceState({}, document.title, "/");
return [current_model, styles, nsfw]
}
"""
# Static SQL fragment (no user input): number of models still waiting for an
# annotation, i.e. unflagged models that have at least one image. A scalar
# subquery is used instead of `SUM(...) OVER ()` because a window function is
# evaluated only over rows that pass the WHERE clause, which made the counter
# 0 or 1 when a single model was looked up by id.
_UNFLAGGED_COUNT_SQL = """(SELECT COUNT(*) FROM models
    WHERE flags IS NULL AND json_array_length(data, '$.images') > 0)"""


def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and build the UI outputs for it.

    If ``query_params`` carries a non-empty ``"model_id"``, that model is
    loaded. Otherwise, a random unflagged model with at least one image is
    picked.

    Args:
        query_params: dict from the hidden query-params component; may be
            empty or None. Only the ``"model_id"`` key is read.
        styles: current style selection. Only logged; the returned styles
            come from the model's stored flags.
        nsfw: current NSFW selection. Only logged, as above.

    Returns:
        ``(images, title, styles, nsfw, current_model)``: image URLs (``.jpg``
        only), a Markdown title, the stored style flags (default ``[]``), the
        stored NSFW flag (default ``None``), and ``{"model_id": ...}``.

    Raises:
        gr.Error: if the requested model does not exist, or no unflagged model
            with images remains.
    """
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)
    with database.get_db() as db:
        if model_id:
            cursor = db.execute(
                f"""SELECT *, {_UNFLAGGED_COUNT_SQL} AS total_unflagged
                FROM models
                WHERE id = ?""",
                (model_id,),
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                f"""SELECT *, {_UNFLAGGED_COUNT_SQL} AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1"""
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

    total_unflagged = row["total_unflagged"]
    model_id = row["id"]
    data = json.loads(row["data"])
    # A model looked up by id is not filtered on having images, so tolerate a
    # missing or null "images" entry.
    images = [ASSETS_URL + x for x in (data.get("images") or []) if x.endswith(".jpg")]
    flags_data = json.loads(row["flags"] or "{}")
    styles = flags_data.get("styles", [])
    nsfw = flags_data.get("nsfw", None)
    # Built with an explicit "\n" so no source indentation leaks into the
    # Markdown (leading spaces would render as a code block).
    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the currently shown model, then load another.

    The selections are written as JSON ``{"styles": ..., "nsfw": ...}`` into
    the model's ``flags`` column.

    Args:
        current_model: state dict ``{"model_id": ...}`` produced by
            ``next_model``; may be empty if no model was loaded.
        styles: selected style classes.
        nsfw: selected NSFW class.

    Returns:
        The outputs of ``next_model({}, styles, nsfw)``, i.e. a random
        model that still needs annotating.

    Raises:
        gr.Error: if no model is currently loaded. Previously this case
            surfaced as a bare ``KeyError``.
    """
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model is loaded; press Next to load one first")
    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    return next_model({}, styles, nsfw)
# Annotation UI. Components are declared in display order; the event wiring
# at the bottom refers to them by name.
with gr.Blocks() as blocks:
    gr.Markdown(
        "### Diffusers Gallery annotation tool\n"
        "Please select multiple classes for each image. If you are unsure, "
        "select \"other\" and also check the model card for more information.\n"
    )
    model_title = gr.Markdown()
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery").style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")
    # Hidden holder for URL query params; js_get_url_params fills in
    # "model_id" on page load.
    query_params = gr.JSON(value={}, visible=False)
    # Holds {"model_id": ...} for the model currently shown.
    current_model = gr.State({})
    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # Every handler refreshes the same set of components.
    _annotation_outputs = [gallery, model_title, styles, nsfw, current_model]
    next_btn.click(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=_annotation_outputs,
    )
    submit_btn.click(
        flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=_annotation_outputs,
    )
    blocks.load(
        next_model,
        inputs=[query_params, styles, nsfw],
        outputs=_annotation_outputs,
        _js=js_get_url_params,
    )
app = FastAPI()


@app.get("/sync")
def read_main():
    """Trigger a push of the local annotation DB repository and acknowledge it.

    ``sync_data`` does not appear to wait for git to finish, so the
    response confirms that the sync was started, not that it completed.

    NOTE(review): this route is unauthenticated, so anyone who can reach the
    app can trigger a commit + force-push; consider protecting it.
    """
    sync_data()
    return "Synced flagged"
# Serializes syncs: two git pipelines running at once in the same working tree
# would race on the index and on the amended commit.
_SYNC_LOCK = threading.Lock()


def _run_git_sync(message):
    """Run add -> amend -> force-push in DB_FOLDER, stopping at the first failure."""
    commands = (
        ["git", "add", "."],
        ["git", "commit", "--amend", "-m", message],
        ["git", "push", "--force"],
    )
    with _SYNC_LOCK:
        for command in commands:
            # argv lists (no shell) avoid quoting problems; bailing out on a
            # non-zero exit mirrors the original `&&` chain.
            result = subprocess.run(command, cwd=DB_FOLDER)
            if result.returncode != 0:
                print(f"DB sync stopped: {' '.join(command)!r} exited with {result.returncode}")
                return
        print("DB repository synced")


def sync_data():
    """Commit and force-push the local annotation DB repository in the background.

    Everything in ``DB_FOLDER`` is staged. The previous commit is amended
    (rewritten rather than adding a new one) with a timestamped message, then
    force-pushed. The git commands run on a worker thread, so the caller
    returns immediately, as with the previous fire-and-forget ``Popen``. Unlike
    that version, the child processes are waited on, failures are logged, and
    overlapping calls queue up instead of racing.
    """
    print("Updating DB repository")
    time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    threading.Thread(
        target=_run_git_sync,
        args=(f"update at flags {time}",),
        name="db-sync",
    ).start()
# Mount the Gradio UI at the root after the /sync route has been registered.
# Presumably routes are matched in registration order, so keeping this mount
# last lets /sync take precedence over the app mounted at "/".
app = gr.mount_gradio_app(app, blocks, path="/")

if __name__ == "__main__":
    # Serve on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)
|