furusu committed on
Commit
3b0f531
1 Parent(s): b6f60bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -5
app.py CHANGED
@@ -2,14 +2,35 @@ import torch
2
  import json
3
  import gradio as gr
4
 
 
 
 
 
 
 
 
 
5
  with open("num_to_token.json", "r") as f:
6
  num_to_token = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  token_to_num = {v:k for k,v in num_to_token.items()}
8
  token_embeddings = torch.load("token_embeddings.pt")
9
 
10
  tags = sorted(list(num_to_token.values()))
11
 
12
- def predict(target_tag, sort_by="descend"):
13
  if sort_by == "descending":
14
  multiplier = 1
15
  else:
@@ -17,16 +38,28 @@ def predict(target_tag, sort_by="descend"):
17
  target_embedding = token_embeddings[int(token_to_num[target_tag])].unsqueeze(0)
18
  sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1)
19
  results = {num_to_token[str(i)]:sims[i].item() * multiplier for i in range(len(num_to_token))}
20
-
21
- return results
 
 
 
 
 
 
 
 
22
 
23
  demo = gr.Interface(
24
  fn=predict,
25
  inputs=[
26
  gr.Textbox(label="Target tag", value="otoko no ko"),
27
- gr.Radio(choices=["descending", "ascending"], label="Sort by", value="descending")
 
 
28
  ],
29
  outputs=gr.Label(num_top_classes=50),
 
 
30
  )
31
 
32
- demo.launch()
 
2
  import json
3
  import gradio as gr
4
 
5
+ TITLE = "Danboru Tag Similarity"
6
+ DESCRIPTION = """
7
+ 与えられたダンボールタグの類似度を計算します。\n
8
+ 対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
9
+ Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
10
+ 学習後のトークン埋め込みを元に計算しています。
11
+ """
12
+
13
  with open("num_to_token.json", "r") as f:
14
  num_to_token = json.load(f)
15
+ with open("popular.txt", "r") as f:
16
+ populars = f.read().splitlines()
17
+ with open("character.txt", "r") as f:
18
+ characters = f.read().splitlines()
19
+ characters_populars = list(set(characters) & set(populars))
20
+ with open("copyright.txt", "r") as f:
21
+ copyrights = f.read().splitlines()
22
+ copyrights_populars = list(set(copyrights) & set(populars))
23
+ with open("general.txt", "r") as f:
24
+ generals = f.read().splitlines()
25
+ generals_populars = list(set(generals) & set(populars))
26
+
27
+
28
  token_to_num = {v:k for k,v in num_to_token.items()}
29
  token_embeddings = torch.load("token_embeddings.pt")
30
 
31
  tags = sorted(list(num_to_token.values()))
32
 
33
+ def predict(target_tag, sort_by, category, popular):
34
  if sort_by == "descending":
35
  multiplier = 1
36
  else:
 
38
  target_embedding = token_embeddings[int(token_to_num[target_tag])].unsqueeze(0)
39
  sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1)
40
  results = {num_to_token[str(i)]:sims[i].item() * multiplier for i in range(len(num_to_token))}
41
+ if category == "general":
42
+ tag_list = generals if popular == "all" else generals_populars
43
+ elif category == "character":
44
+ tag_list = characters if popular == "all" else characters_populars
45
+ elif category == "copyright":
46
+ tag_list = copyrights if popular == "all" else copyrights_populars
47
+ else:
48
+ tag_list = results.keys() if popular == "all" else populars
49
+
50
+ return {k:results[k] for k in tag_list}
51
 
52
  demo = gr.Interface(
53
  fn=predict,
54
  inputs=[
55
  gr.Textbox(label="Target tag", value="otoko no ko"),
56
+ gr.Radio(choices=["descending", "ascending"], label="Sort by", value="descending"),
57
+ gr.Dropdown(choices=["all", "general", "character", "copyright"], value="all", label="category"),
58
+ gr.Radio(choices=["all", "only_popular"], label="Only popular tag (count>=1000)", value="all"),
59
  ],
60
  outputs=gr.Label(num_top_classes=50),
61
+ title=TITLE,
62
+ description=DESCRIPTION
63
  )
64
 
65
+ demo.launch()