File size: 4,740 Bytes
9a5a85e
9a6a181
9a5a85e
 
 
 
 
9a6a181
 
 
9a5a85e
 
 
 
9a6a181
9a5a85e
 
9a6a181
9a5a85e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a6a181
9a5a85e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a6a181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import uvicorn
import gradio as gr
from pathlib import Path
from huggingface_hub import Repository
import json
from db import Database
from fastapi import FastAPI
from datetime import datetime
import subprocess

# Optional Hugging Face token read from the environment.
# NOTE(review): HF_TOKEN is not referenced anywhere else in this file; the
# Repository below is created with use_auth_token=True instead. Confirm whether
# the token should be passed explicitly.
HF_TOKEN = os.environ.get("HF_TOKEN")
# NOTE(review): not referenced anywhere else in this file — possibly a leftover.
S3_DATA_FOLDER = Path("sd-multiplayer-data")
# Local checkout of the annotation dataset repo; also the directory handed to
# Database(...) below and the cwd for the git commands in sync_data().
DB_FOLDER = Path("diffusers-gallery-data")

# Base URL prepended to each image path from a model's "images" list.
ASSETS_URL = "https://d26smi9133w0oo.cloudfront.net/diffusers-gallery/"


# Import-time side effect: clone the dataset repo into DB_FOLDER and pull the
# latest changes before the database is opened.
# NOTE(review): huggingface_hub's Repository class is deprecated in newer
# releases — check the pinned version before upgrading.
repo = Repository(
    local_dir=DB_FOLDER,
    repo_type="dataset",
    clone_from="huggingface-projects/diffusers-gallery-data",
    use_auth_token=True,
)
repo.git_pull()

# Project-local DB wrapper rooted at the checkout; used via database.get_db().
database = Database(DB_FOLDER)


# UI choices: style classes (multi-select checkbox group) and NSFW level
# (single-choice radio).
styles_cls = ["anime", "3D", "realistic", "other"]
nsfw_cls = ["safe", "suggestive", "explicit"]


js_get_url_params = """
    function (current_model, styles, nsfw) {
        const params = new URLSearchParams(window.location.search);
        current_model.model_id = params.get("model_id") || "";
        window.history.replaceState({}, document.title, "/");
        return [current_model, styles, nsfw]
        }
"""


def _fetch_model_row(db, model_id):
    """Fetch the row to annotate: ``model_id`` if given, else a random candidate.

    The candidate pool is models that have at least one image and no flags
    yet. Every returned row carries an extra ``total_unflagged`` column with
    the size of that pool. Returns ``None`` when no row matches.
    """
    # The count is a scalar subquery rather than ``SUM(...) OVER ()``: window
    # functions are evaluated after WHERE, so with ``WHERE id = ?`` the window
    # only saw the single matched row and the count was always 0 or 1.
    if model_id:
        cursor = db.execute(
            """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL) AS total_unflagged
            FROM models
            WHERE id = ?""", (model_id,))
    else:
        cursor = db.execute(
            """SELECT *,
                (SELECT COUNT(*) FROM models
                 WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL) AS total_unflagged
            FROM models
            WHERE json_array_length(data, '$.images') > 0 AND flags IS NULL
            ORDER BY RANDOM()
            LIMIT 1""")
    return cursor.fetchone()


def _format_title(model_id, total_unflagged):
    """Build the Markdown header: a link to the model page plus the unflagged count."""
    return (
        f"#### [Model {model_id}](https://huggingface.co/{model_id})\n"
        f"**Unflagged** {total_unflagged}"
    )


def next_model(query_params, styles=None, nsfw=None):
    """Load a model to annotate and return the values for the UI outputs.

    Args:
        query_params: Value of the hidden JSON component. When it holds a
            truthy ``"model_id"``, that exact model is loaded. Otherwise a
            random model that has images and no flags yet is picked.
        styles: Current value of the styles checkbox group. Only logged; the
            returned styles come from the model's stored flags.
        nsfw: Current value of the NSFW radio. Only logged, like ``styles``.

    Returns:
        Tuple ``(images, title, styles, nsfw, current_model)`` matching the
        ``outputs`` lists wired up in the Blocks UI:

        - image URLs (``.jpg`` only);
        - a Markdown title;
        - previously stored styles (or ``[]``);
        - the previously stored NSFW class (or ``None``);
        - the new ``current_model`` state ``{"model_id": ...}``.

    Raises:
        gr.Error: If the requested model does not exist, or no unflagged
            model with images remains.
    """
    # Treat a missing/None JSON value or an empty/absent "model_id" as "no id".
    model_id = (query_params or {}).get("model_id") or None
    print(model_id, styles, nsfw)

    with database.get_db() as db:
        row = _fetch_model_row(db, model_id)
        if row is None:
            if model_id:
                raise gr.Error("Cannot find model to annotate")
            raise gr.Error("Cannot find any more models to annotate")

        model_id = row["id"]
        data = json.loads(row["data"])
        # A model requested by id may have no "images" entry at all.
        images = [ASSETS_URL + x for x in data.get("images", []) if x.endswith(".jpg")]
        # Unflagged models have flags = NULL; show empty selections for them.
        flags_data = json.loads(row["flags"] or "{}")

        return (
            images,
            _format_title(model_id, row["total_unflagged"]),
            flags_data.get("styles", []),
            flags_data.get("nsfw", None),
            {"model_id": model_id},
        )


def flag_model(current_model, styles=None, nsfw=None):
    """Store the annotation for the current model, then load the next one.

    Args:
        current_model: Session state dict. It is ``{"model_id": ...}`` once
            ``next_model`` has loaded a model, and ``{}`` before that.
        styles: Selected style classes from the checkbox group.
        nsfw: Selected NSFW class from the radio; may be ``None``.

    Returns:
        The UI output tuple produced by ``next_model`` for a random
        unflagged model.

    Raises:
        gr.Error: If no model has been loaded yet, so there is nothing to flag.
    """
    # The session state starts out as {}, so a Submit click before any model
    # has loaded would otherwise fail with a bare KeyError.
    model_id = (current_model or {}).get("model_id")
    if not model_id:
        raise gr.Error("No model loaded to flag")

    print("Flagging model", model_id, styles, nsfw)
    with database.get_db() as db:
        db.execute(
            """UPDATE models SET flags = ? WHERE id = ?""",
            (json.dumps({"styles": styles, "nsfw": nsfw}), model_id),
        )
    # An empty query dict makes next_model pick a random unflagged model.
    return next_model({}, styles, nsfw)


# Annotation UI. Components render in the order they are created inside the
# Blocks context, so statement order here defines the page layout.
blocks = gr.Blocks()
with blocks:
    # Static instructions shown at the top of the page.
    gr.Markdown('''### Diffusers Gallery annotation tool
    Please select multiple classes for each image. If you are unsure, select "other" and also check the model card for more information.
    ''')
    # Header filled in by next_model: model link + unflagged count.
    model_title = gr.Markdown()
    # Sample images of the current model, three per row.
    # NOTE(review): Component.style() was removed in later Gradio releases —
    # this relies on the pinned Gradio version.
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(grid=[3])
    styles = gr.CheckboxGroup(
        styles_cls, info="Classify the image as one or more of the following classes")
    nsfw = gr.Radio(nsfw_cls, info="Is the image NSFW?")
    # Hidden inputs:
    # - query_params carries the ?model_id= value that js_get_url_params
    #   injects on page load.
    # - current_model is per-session state holding {"model_id": ...} for the
    #   model currently displayed.
    query_params = gr.JSON(value={}, visible=False)
    current_model = gr.State({})
    next_btn = gr.Button("Next")
    submit_btn = gr.Button("Submit")
    # "Next" loads another model without writing anything to the DB.
    next_btn.click(next_model, inputs=[query_params, styles, nsfw],
                   outputs=[gallery, model_title,  styles, nsfw, current_model])

    # "Submit" saves the selected flags for current_model, then loads a new model.
    submit_btn.click(flag_model, inputs=[current_model, styles, nsfw], outputs=[
                     gallery, model_title, styles, nsfw, current_model])

    # On page load, the JS snippet first copies ?model_id= from the URL into the
    # query_params input, so a direct link opens that specific model.
    blocks.load(next_model, inputs=[query_params, styles, nsfw],
                outputs=[gallery, model_title, styles, nsfw, current_model], _js=js_get_url_params)


app = FastAPI()


@app.get("/sync")
def read_main():
    """Push the local annotation DB checkout back to the dataset repo, then confirm."""
    sync_data()
    message = "Synced flagged"
    return message


def sync_data():
    """Fold all local DB changes into the last commit and force-push it.

    Runs ``git add .``, ``git commit --amend`` and ``git push --force`` in
    ``DB_FOLDER``, in that order, stopping at the first failing command
    (like chaining with ``&&``). Amending and force-pushing replaces the tip
    commit rather than adding a new commit on every sync.

    Raises:
        subprocess.CalledProcessError: If any git command exits non-zero.
    """
    print("Updating DB repository")
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Argument lists avoid shell=True. Running synchronously with check=True
    # surfaces failures (e.g. a rejected push) to the caller instead of losing
    # them in a detached background process.
    for args in (
        ["git", "add", "."],
        ["git", "commit", "--amend", "-m", f"update at flags {timestamp}"],
        ["git", "push", "--force"],
    ):
        subprocess.run(args, cwd=DB_FOLDER, check=True)


# Serve the Gradio UI at the root of the FastAPI app. Starlette tries routes
# in registration order, so /sync (registered on `app` earlier) still takes
# precedence over this "/" mount.
app = gr.mount_gradio_app(app, blocks, "/")


# Direct execution: serve the combined app on all interfaces at port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=7860)