File size: 2,312 Bytes
1f99e57
 
 
 
3b0f531
 
 
 
 
 
 
 
029c999
 
 
 
3b0f531
 
 
 
 
 
 
 
029c999
1f99e57
 
3b0f531
1f99e57
 
 
 
029c999
 
 
3b0f531
029c999
3b0f531
029c999
3b0f531
029c999
3b0f531
029c999
 
 
 
3b0f531
029c999
1f99e57
 
 
 
c954be3
3b0f531
 
 
1f99e57
 
3b0f531
 
1f99e57
 
3b0f531
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import json
import gradio as gr

# Heading text; passed to gr.Interface as `title=`.
TITLE = "Danboru Tag Similarity"
# Explanatory text; passed to gr.Interface as `description=`.
# English gloss of the Japanese text below:
#   Computes the similarity of the given Danbooru tag.
#   For the list of supported tags, see the respective text files under
#   "Files" (they are the same as Dart's).
#   With Dart as a reference, the model was trained for 2 epochs on the
#   isek-ai/danbooru-tags-2023 dataset with the tags shuffled.
#   Similarities are computed from the token embeddings after training.
DESCRIPTION = """
与えられたダンボールタグの類似度を計算します。\n
対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
学習後のトークン埋め込みを元に計算しています。
"""

# Vocabulary: the JSON object maps token ids (JSON keys are strings) to tag
# strings. Build the reverse lookup tag -> integer id, which predict() uses to
# index into the embedding tensor.
# All text files are opened with an explicit UTF-8 encoding so that decoding
# does not depend on the host's locale.
with open("id_to_token.json", "r", encoding="utf-8") as f:
    id_to_token = json.load(f)
token_to_id = {v: int(k) for k, v in id_to_token.items()}

# Tag lists, one tag per line. `populars` is used as an optional filter;
# the three category lists are concatenated into `tags` for the "all" option.
with open("popular.txt", "r", encoding="utf-8") as f:
    populars = f.read().splitlines()
with open("character.txt", "r", encoding="utf-8") as f:
    characters = f.read().splitlines()
with open("copyright.txt", "r", encoding="utf-8") as f:
    copyrights = f.read().splitlines()
with open("general.txt", "r", encoding="utf-8") as f:
    generals = f.read().splitlines()
tags = characters + copyrights + generals

# Embedding tensor indexed by token id along dim 0 (see predict()).
# map_location="cpu" lets the file load on hosts without CUDA even if the
# tensor was saved from a GPU.
token_embeddings = torch.load("token_embeddings.pt", map_location="cpu")

def predict(target_tag, sort_by, category, popular):
    """Score tags by cosine similarity to ``target_tag``.

    Args:
        target_tag: Tag to compare against. Surrounding whitespace is ignored.
        sort_by: ``"descending"`` keeps the similarities as they are; any other
            value negates them, so a consumer that ranks by highest score
            shows the least similar tags first.
        category: ``"general"``, ``"character"`` or ``"copyright"`` restricts
            the candidates to that tag list; any other value uses all three.
        popular: ``"only_popular"`` keeps only candidates that also appear in
            ``populars``; any other value applies no filter.

    Returns:
        A dict mapping each candidate tag to its (possibly negated) cosine
        similarity as a Python float. Candidates that are not present in the
        vocabulary are skipped.

    Raises:
        gr.Error: If ``target_tag`` is not in the vocabulary.
    """
    target_tag = (target_tag or "").strip()
    try:
        target_id = token_to_id[target_tag]
    except KeyError:
        # gr.Error surfaces a readable message in the UI instead of a bare
        # KeyError traceback.
        raise gr.Error(f"Unknown tag: {target_tag!r}") from None

    multiplier = 1 if sort_by == "descending" else -1
    target_embedding = token_embeddings[target_id].unsqueeze(0)
    sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1) * multiplier
    # Convert once to a plain list; indexing the tensor and calling .item()
    # per tag is far slower for large tag lists.
    scores = sims.tolist()

    candidates_by_category = {
        "general": generals,
        "character": characters,
        "copyright": copyrights,
    }
    tag_list = candidates_by_category.get(category, tags)

    if popular == "only_popular":
        popular_set = set(populars)
        tag_list = [tag for tag in tag_list if tag in popular_set]

    # Skip entries (e.g. blank lines in a tag file) that have no token id so
    # one bad line cannot break every request.
    return {
        tag: scores[token_to_id[tag]]
        for tag in tag_list
        if tag in token_to_id
    }

# Input widgets, listed in the positional order of predict's parameters:
# target_tag, sort_by, category, popular.
input_components = [
    gr.Textbox(value="otoko no ko", label="Target tag"),
    gr.Radio(
        label="Sort by",
        choices=["descending", "ascending"],
        value="descending",
    ),
    gr.Dropdown(
        label="category",
        choices=["all", "general", "character", "copyright"],
        value="all",
    ),
    gr.Radio(
        label="Only popular tag (count>=1000)",
        choices=["all", "only_popular"],
        value="all",
    ),
]

# predict returns a {tag: score} dict; the Label shows the 50 top-scoring tags.
output_component = gr.Label(num_top_classes=50)

demo = gr.Interface(
    fn=predict,
    inputs=input_components,
    outputs=output_component,
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch()