# Committed by radames (HF staff): "change order" (9a6a181)
import os
import uvicorn
import gradio as gr
from pathlib import Path
from huggingface_hub import Repository
import json
from db import Database
from fastapi import FastAPI
from datetime import datetime
import subprocess
# Hub token read from the environment. It is not referenced again in the code
# shown here; the Repository below is created with use_auth_token=True instead.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): S3_DATA_FOLDER is not used anywhere in the visible code.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout directory of the dataset repository; also handed to Database.
DB_FOLDER = Path("diffusers-gallery-data")
# Prefix prepended to each image file name to build its public URL.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"
# Clone the dataset repository into DB_FOLDER, then pull, before opening the DB.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()
# Project-local database wrapper (see db.Database); its get_db() is used as a
# context manager by the handlers below.
database = Database(DB_FOLDER)
# Choices offered by the style checkbox group and the NSFW radio in the UI.
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]
# Browser-side hook run on page load: copies the "model_id" URL query parameter
# into the first input value, resets the address bar to "/", and returns the
# inputs to the Python handler.
js_get_url_params = """
function (current_model, styles, nsfw) {
const params = new URLSearchParams(window.location.search);
current_model.model_id = params.get("model_id") || "";
window.history.replaceState({}, document.title, "/");
return [current_model, styles, nsfw]
}
"""
def next_model(query_params, styles=None, nsfw=None):
    """Select the model to annotate and build the values shown in the UI.

    Args:
        query_params: Mapping that may carry a ``"model_id"`` key. A non-empty
            value requests that specific model; otherwise a random model that
            has images and no flags yet is chosen.
        styles: Current style selection. Only logged; the returned value is
            read from the flags stored for the selected model.
        nsfw: Current NSFW selection. Only logged, same as ``styles``.

    Returns:
        A 5-tuple ``(images, title, styles, nsfw, state)``: the list of image
        URLs, a Markdown title, the stored styles (``[]`` if none), the stored
        NSFW label (``None`` if none) and ``{"model_id": <id>}``.

    Raises:
        gr.Error: If the requested model does not exist, or if no unflagged
            model with images is left.
    """
    # Tolerate a missing/None mapping and treat an empty id as "no id given".
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)

    with database.get_db() as db:
        if model_id:
            # A window aggregate (`SUM(...) OVER ()`) is evaluated after the
            # WHERE clause, so combined with `WHERE id = ?` it would only ever
            # see the single matched row. Count the remaining work with a
            # scalar subquery instead, using the same filter as the random
            # branch so both code paths report the same number.
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ) AS total_unflagged
                FROM models
                WHERE id = ?""", (model_id,))
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                """SELECT *,
                SUM(CASE WHEN flags IS NULL THEN 1 ELSE 0 END) OVER () AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1""")
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        # A model requested by id is not filtered on images, so the key may be
        # absent; show an empty gallery rather than failing.
        images = [
            ASSETS_URL + name
            for name in (data.get("images") or [])
            if name.endswith(".jpg")
        ]
        # `flags` is NULL until the model has been annotated.
        flags_data = json.loads(row["flags"] or "{}")
        styles = flags_data.get("styles", [])
        nsfw = flags_data.get("nsfw", None)

    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )
    return images, title, styles, nsfw, {"model_id": model_id}
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the displayed model, then load another one.

    Args:
        current_model: State mapping expected to hold ``"model_id"``.
        styles: Selected style labels, stored under the ``"styles"`` key.
        nsfw: Selected NSFW label, stored under the ``"nsfw"`` key.

    Returns:
        The same 5-tuple as ``next_model`` for a newly selected model.

    Raises:
        gr.Error: If ``current_model`` carries no model id, or (from
            ``next_model``) if no further model can be selected.
    """
    # The state starts out as an empty mapping, so Submit can be pressed
    # before any model was loaded; report that instead of a bare KeyError.
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model loaded to flag")

    print("Flagging model", model_id, styles, nsfw)
    flags = json.dumps({"styles": styles, "nsfw": nsfw})
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""", (flags, model_id))
    # Look up the next model only after the update's context has exited
    # (presumably get_db() commits on exit — confirm in db.Database).
    return next_model({}, styles, nsfw)
# Annotation UI: gallery of a model's images plus the style / NSFW controls.
with gr.Blocks() as blocks:
    gr.Markdown(
        "### Diffusers Gallery annotation tool\n"
        'Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.\n'
    )
    model_title = gr.Markdown()
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery")
    gallery = gallery.style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")

    # Hidden holder for the URL query parameters, and per-session state that
    # remembers which model is currently displayed.
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})

    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    # Every handler refreshes the same set of components.
    refreshed = [gallery, model_title, styles, nsfw, current_model]
    browse_inputs = [query_params, styles, nsfw]

    next_btn.click(next_model, inputs=browse_inputs, outputs=refreshed)
    submit_btn.click(
        flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=refreshed,
    )
    # On page load the JS hook runs first and feeds the URL's model_id in.
    blocks.load(
        next_model,
        inputs=browse_inputs,
        outputs=refreshed,
        _js=js_get_url_params,
    )
# FastAPI application exposing the /sync endpoint.
app = FastAPI()


@app.get("/sync")
def read_main():
    """Handle GET /sync: trigger ``sync_data`` and return a fixed message."""
    sync_data()
    reply = "Synced flagged"
    return reply
def sync_data():
    """Launch the git add / amend-commit / force-push chain in ``DB_FOLDER``.

    The shell command is started with ``subprocess.Popen`` and is not waited
    on, so this function returns without knowing whether the push succeeded.
    """
    print("Updating DB repository")
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Steps are chained with `&&`, so a failing step stops the ones after it.
    steps = [
        "git add .",
        f"git commit --amend -m 'update at flags {stamp}'",
        "git push --force",
    ]
    command = " && ".join(steps)
    subprocess.Popen(command, cwd=DB_FOLDER, shell=True)
# Serve the Gradio UI from the root path of the FastAPI application.
app = gr.mount_gradio_app(app, blocks, path="/")

if __name__ == "__main__":
    # Run the combined app directly: all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)