"""Gradio tool for annotating Diffusers Gallery models with style and NSFW flags.

The model database lives in a cloned Hugging Face dataset repository; flags are
written to it locally and pushed back on demand through the ``/sync`` endpoint.
"""
import json
import os
import subprocess
from datetime import datetime
from pathlib import Path

import gradio as gr
import uvicorn
from fastapi import FastAPI
from huggingface_hub import Repository

from db import Database

HF_TOKEN = os.environ.get("HF_TOKEN")

S3_DATA_FOLDER = Path("sd-multiplayer-data")
DB_FOLDER = Path("diffusers-gallery-data")

# Prefix prepended to every image path stored in the database. Defaults to ""
# (paths used as-is); override with the ASSETS_URL environment variable.
ASSETS_URL = os.environ.get("ASSETS_URL", "")

# Clone (or reuse) the dataset repo holding the database, then update it.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()

database = Database(DB_FOLDER)

styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]

# Runs in the browser on page load. The first argument receives the value of
# the `query_params` component (see `blocks.load` below); the function copies
# the `model_id` URL parameter into it and then strips the query string from
# the address bar.
js_get_url_params = """
function (current_model, styles, nsfw) {
    const params = new URLSearchParams(window.location.search);
    current_model.model_id = params.get("model_id") || "";
    window.history.replaceState({}, document.title, "/");
    return [current_model, styles, nsfw]
}
"""

# Scalar subquery counting models still awaiting annotation (have images and
# no flags yet). A window function such as `SUM(...) OVER ()` only sees the
# rows left after the outer WHERE clause, so for a lookup by id it reported
# 0 or 1 instead of the remaining total.
_REMAINING_COUNT_SQL = (
    "(SELECT COUNT(*) FROM models "
    "WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL)"
)


def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and return the values for the UI outputs.

    Args:
        query_params: dict that may hold a "model_id" key (filled from the page
            URL by `js_get_url_params`). When it holds a non-empty id, that model
            is loaded; otherwise a random model with images and no flags is picked.
        styles: currently selected style classes; only logged.
        nsfw: currently selected NSFW class; only logged.

    Returns:
        Tuple of (image URLs, Markdown title, stored style flags, stored NSFW
        flag, {"model_id": id}) matching the outputs wired in the Blocks UI.

    Raises:
        gr.Error: if the requested model does not exist, or no unflagged model
            with images remains.
    """
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)
    with database.get_db() as db:
        if model_id:
            cursor = db.execute(
                f"SELECT *, {_REMAINING_COUNT_SQL} AS total_unflagged "
                "FROM models WHERE id = ?",
                (model_id,),
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                f"SELECT *, {_REMAINING_COUNT_SQL} AS total_unflagged "
                "FROM models "
                "WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL "
                "ORDER BY RANDOM() LIMIT 1"
            )
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        # Only .jpg assets are shown; a row without images yields an empty gallery.
        images = [ASSETS_URL + x for x in data.get("images", []) if x.endswith(".jpg")]
        # Pre-fill the inputs with any flags already stored for this model.
        flags_data = json.loads(row["flags"] or "{}")
        styles = flags_data.get("styles", [])
        nsfw = flags_data.get("nsfw", None)

    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )
    return images, title, styles, nsfw, {"model_id": model_id}


def flag_model(current_model, styles=None, nsfw=None):
    """Store the selected flags for the current model, then load the next one.

    Args:
        current_model: state dict holding the "model_id" being annotated.
        styles: selected style classes.
        nsfw: selected NSFW class.

    Returns:
        The same output tuple as `next_model`, for a new random model.

    Raises:
        gr.Error: if no model is loaded yet, or no unflagged model remains.
    """
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model loaded to flag")
    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    return next_model({}, styles, nsfw)


blocks = gr.Blocks()

with blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
    Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
    ''')

    model_title = gr.Markdown()
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(grid=[3])

    styles = gr.CheckboxGroup(
        styles_cls, info="Classify the image as one or more of the following classes")
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")

    # invisible inputs to store the query params
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})
    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")

    ui_outputs = [gallery, model_title, styles, nsfw, current_model]
    next_btn.click(next_model, inputs=[query_params, styles, nsfw],
                   outputs=ui_outputs)
    submit_btn.click(flag_model, inputs=[current_model, styles, nsfw],
                     outputs=ui_outputs)

    blocks.load(next_model, inputs=[query_params, styles, nsfw],
                outputs=ui_outputs,
                _js=js_get_url_params)


app = FastAPI()


def sync_data():
    """Commit the local database changes and force-push them to the dataset repo."""
    print("Updating DB repository")
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # `--amend` + `--force` rewrites the latest commit instead of adding a new
    # one per sync. The Popen handle is not waited on, so the command runs in
    # the background and callers return before the push has finished.
    subprocess.Popen(
        f"git add . && git commit --amend -m 'update at flags {timestamp}' && git push --force",
        cwd=DB_FOLDER,
        shell=True,
    )


@app.get("/sync")
def read_main():
    """Start a background push of the flagged database."""
    sync_data()
    return "Synced flagged"


app = gr.mount_gradio_app(app, blocks, "/")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)