radames committed on
Commit
9a5a85e
1 Parent(s): 8b53ae6
Files changed (4) hide show
  1. .gitignore +22 -0
  2. app.py +118 -0
  3. db.py +34 -0
  4. requirements.txt +5 -0
.gitignore ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+
10
+ # Ignore files for PNPM, NPM and YARN
11
+ pnpm-lock.yaml
12
+ package-lock.json
13
+ yarn.lock
14
+ venv/
15
+ __pycache__/
16
+ flagged/
17
+ data
18
+ data.db
19
+ data.json
20
+ rooms_data.db
21
+ sd-multiplayer-data/
22
+ diffusers-gallery-data/
app.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ import os
3
+ import re
4
+ from io import BytesIO
5
+ import uuid
6
+ import gradio as gr
7
+ from pathlib import Path
8
+ from huggingface_hub import Repository
9
+ import json
10
+ from db import Database
11
+
12
# Access token for the Hugging Face Hub, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the dataset repository; Database looks for "models.db" inside it.
DB_FOLDER = Path("diffusers-gallery-data")
# Base URL that is prepended to the image file names stored in the database.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"

# Clone the dataset repository into DB_FOLDER and pull the latest revision.
# HF_TOKEN is passed explicitly when it is set, so authentication does not
# depend on a cached login; when it is unset the previous behaviour
# (use_auth_token=True) is kept.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=HF_TOKEN or True,
)
repo.git_pull()

database = Database(DB_FOLDER)


blocks = gr.Blocks()

# Choices offered by the annotation form.
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]


# Client-side hook: copies the "model_id" URL query parameter into the first
# input value, resets the address bar to "/", and forwards the inputs.
js_get_url_params = """
function (current_model, styles, nsfw) {
    const params = new URLSearchParams(window.location.search);
    current_model.model_id = params.get("model_id") || "";
    window.history.replaceState({}, document.title, "/");
    return [current_model, styles, nsfw]
}
"""
42
+
43
+
44
def next_model(query_params, styles=None, nsfw=None):
    """Pick a model to annotate and build the values for the UI outputs.

    Args:
        query_params: Mapping that may carry a ``"model_id"`` key. When the
            key is present and non-empty that model is loaded; otherwise a
            random model that has images and no flags yet is selected.
        styles: Current value of the styles input. Only logged; the returned
            value comes from the flags stored for the selected model.
        nsfw: Current value of the NSFW input. Only logged, like ``styles``.

    Returns:
        A 5-tuple ``(images, title, styles, nsfw, current_model)`` where
        ``images`` is a list of ``.jpg`` asset URLs, ``title`` is a Markdown
        string, ``styles`` / ``nsfw`` are the stored flags (``[]`` / ``None``
        when the model is unflagged) and ``current_model`` is
        ``{"model_id": <id>}``.

    Raises:
        gr.Error: If the requested model does not exist, or if no unflagged
            model with images is left.
    """
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)

    db = database.get_db()
    try:
        if model_id:
            # The count is a scalar subquery: a window function here would
            # only see the single row matched by "WHERE id = ?".
            cursor = db.execute(
                """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL) AS total_unflagged
                FROM models
                WHERE id = ?""", (model_id,))
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find model to annotate")
        else:
            cursor = db.execute(
                """SELECT *,
                SUM(CASE WHEN flags IS NULL THEN 1 ELSE 0 END) OVER () AS total_unflagged
                FROM models
                WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
                ORDER BY RANDOM()
                LIMIT 1""")
            row = cursor.fetchone()
            if row is None:
                raise gr.Error("Cannot find any more models to annotate")

        total_unflagged = row["total_unflagged"]
        model_id = row["id"]
        data = json.loads(row["data"])
        flags_data = json.loads(row["flags"] or "{}")
    finally:
        # The sqlite3 connection context manager does not close the
        # connection, so release it explicitly.
        db.close()

    # A model requested by id is not guaranteed to have an "images" entry.
    images = [ASSETS_URL + x for x in (data.get("images") or []) if x.endswith(".jpg")]
    styles = flags_data.get("styles", [])
    nsfw = flags_data.get("nsfw", None)

    title = (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )

    return images, title, styles, nsfw, {"model_id": model_id}
82
+
83
+
84
def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the current model and load the next one.

    Args:
        current_model: Mapping with the ``"model_id"`` of the model that is
            being annotated.
        styles: Selected style classes; stored under the ``"styles"`` key.
        nsfw: Selected NSFW class; stored under the ``"nsfw"`` key.

    Returns:
        The result of ``next_model({}, styles, nsfw)``.

    Raises:
        gr.Error: If ``current_model`` carries no model id (for example when
            no model could be loaded before), or if ``next_model`` raises.
    """
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model is loaded, nothing to flag")
    print("Flagging model", model_id, styles, nsfw)

    flags = json.dumps({"styles": styles, "nsfw": nsfw})
    db = database.get_db()
    try:
        # "with db" commits on success and rolls back on error; it does not
        # close the connection, hence the explicit close below.
        with db:
            db.execute(
                """UPDATE models SET flags = ? WHERE id = ?""", (flags, model_id))
    finally:
        db.close()
    return next_model({}, styles, nsfw)
91
+
92
+
93
# Build the annotation page: intro text, model title, image gallery, the two
# classification inputs, and the Next / Submit buttons.
with blocks:
    gr.Markdown('''### Diffusers Gallery annotation tool
    Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
    ''')
    model_title = gr.Markdown()
    gallery = gr.Gallery(
        label="Images",
        show_label=False,
        elem_id="gallery",
    ).style(grid=[3])
    styles = gr.CheckboxGroup(
        choices=styles_cls,
        info="Classify the image as one or more of the following classes",
    )
    nsfw = gr.Radio(choices=nsfw_cls, info="Is the image NSFW?")
    # invisible inputs to store the query params
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})
    next_btn = gr.Button(value="Next")
    submit_btn = gr.Button(value="Submit")

    # Every handler refreshes the same set of components.
    annotation_outputs = [gallery, model_title, styles, nsfw, current_model]
    load_inputs = [query_params, styles, nsfw]

    next_btn.click(fn=next_model, inputs=load_inputs, outputs=annotation_outputs)

    submit_btn.click(
        fn=flag_model,
        inputs=[current_model, styles, nsfw],
        outputs=annotation_outputs,
    )

    # On page load the JS hook fills query_params from the URL first.
    blocks.load(
        fn=next_model,
        inputs=load_inputs,
        outputs=annotation_outputs,
        _js=js_get_url_params,
    )

blocks.launch(enable_queue=False)
db.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from pathlib import Path
3
+ import math
4
+
5
+
6
def power(x, y):
    """Return ``x`` raised to the power ``y``, computed with ``math.pow``."""
    result = math.pow(x, y)
    return result
8
+
9
+
10
class Database:
    """Access to the ``models.db`` SQLite file stored inside a data folder.

    On construction the ``models`` table is extended with a ``flags`` TEXT
    column when that column is missing.

    The instance can be used as a context manager: ``with database as db``
    opens a fresh connection and closes it on exit. Closing does not commit,
    so callers must commit their own writes.
    """

    def __init__(self, db_path=None):
        """Remember the database location and make sure ``flags`` exists.

        Args:
            db_path: Folder that contains ``models.db`` (``str`` or ``Path``).

        Raises:
            ValueError: If ``db_path`` is ``None``.
        """
        if db_path is None:
            raise ValueError("db_path must be provided")
        # Accept plain strings as well as Path objects.
        self.db_path = Path(db_path)
        self.db_file = self.db_path / "models.db"
        db = self.get_db()
        try:
            # Inspect the schema instead of treating every OperationalError
            # as "column already exists".
            columns = {row["name"] for row in db.execute("PRAGMA table_info(models)")}
            if "flags" in columns:
                print("Column Flags already exists")
            else:
                db.execute("ALTER TABLE models ADD COLUMN flags TEXT")
                db.commit()
        except sqlite3.OperationalError as err:
            # Best effort, as before: report the real problem (for example a
            # missing "models" table) without aborting start-up.
            print(f"Could not add column flags: {err}")
        finally:
            db.close()

    def get_db(self):
        """Open and return a new connection to the database file.

        The connection returns ``sqlite3.Row`` rows and provides the SQL
        function ``MYPOWER(x, y)``. The caller is responsible for closing it;
        using the connection in a ``with`` statement only commits or rolls
        back.
        """
        db = sqlite3.connect(self.db_file, check_same_thread=False)
        db.create_function("MYPOWER", 2, math.pow)
        db.row_factory = sqlite3.Row
        return db

    def __enter__(self):
        """Open a connection, keep it on ``self.db`` and return it."""
        self.db = self.get_db()
        return self.db

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection opened by ``__enter__`` (no commit)."""
        self.db.close()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ huggingface-hub
2
+ uvicorn
3
+ fastapi
4
+ Pillow
5
+ gradio