furusu committed on
Commit
029c999
1 Parent(s): 451da45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -18
app.py CHANGED
@@ -10,44 +10,42 @@ Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシ
10
  学習後のトークン埋め込みを元に計算しています。
11
  """
12
 
13
- with open("num_to_token.json", "r") as f:
14
- num_to_token = json.load(f)
 
 
15
  with open("popular.txt", "r") as f:
16
  populars = f.read().splitlines()
17
  with open("character.txt", "r") as f:
18
  characters = f.read().splitlines()
19
- characters_populars = list(set(characters) & set(populars))
20
  with open("copyright.txt", "r") as f:
21
  copyrights = f.read().splitlines()
22
- copyrights_populars = list(set(copyrights) & set(populars))
23
  with open("general.txt", "r") as f:
24
  generals = f.read().splitlines()
25
- generals_populars = list(set(generals) & set(populars))
26
-
27
-
28
- token_to_num = {v:k for k,v in num_to_token.items()}
29
  token_embeddings = torch.load("token_embeddings.pt")
30
 
31
- tags = sorted(list(num_to_token.values()))
32
-
33
  def predict(target_tag, sort_by, category, popular):
34
  if sort_by == "descending":
35
  multiplier = 1
36
  else:
37
  multiplier = -1
38
- target_embedding = token_embeddings[int(token_to_num[target_tag])].unsqueeze(0)
39
- sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1)
40
- results = {num_to_token[str(i)]:sims[i].item() * multiplier for i in range(len(num_to_token))}
41
  if category == "general":
42
- tag_list = generals if popular == "all" else generals_populars
43
  elif category == "character":
44
- tag_list = characters if popular == "all" else characters_populars
45
  elif category == "copyright":
46
- tag_list = copyrights if popular == "all" else copyrights_populars
47
  else:
48
- tag_list = results.keys() if popular == "all" else populars
 
 
 
49
 
50
- return {k:results[k] for k in tag_list}
51
 
52
  demo = gr.Interface(
53
  fn=predict,
 
10
  学習後のトークン埋め込みを元に計算しています。
11
  """
12
 
13
+ with open("id_to_token.json", "r") as f:
14
+ id_to_token = json.load(f)
15
+ token_to_id = {v:int(k) for k,v in id_to_token.items()}
16
+
17
  with open("popular.txt", "r") as f:
18
  populars = f.read().splitlines()
19
  with open("character.txt", "r") as f:
20
  characters = f.read().splitlines()
 
21
  with open("copyright.txt", "r") as f:
22
  copyrights = f.read().splitlines()
 
23
  with open("general.txt", "r") as f:
24
  generals = f.read().splitlines()
25
+ tags = characters + copyrights + generals
 
 
 
26
  token_embeddings = torch.load("token_embeddings.pt")
27
 
 
 
28
  def predict(target_tag, sort_by, category, popular):
29
  if sort_by == "descending":
30
  multiplier = 1
31
  else:
32
  multiplier = -1
33
+ target_embedding = token_embeddings[token_to_id[target_tag]].unsqueeze(0)
34
+ sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1) * multiplier
35
+
36
  if category == "general":
37
+ tag_list = generals
38
  elif category == "character":
39
+ tag_list = characters
40
  elif category == "copyright":
41
+ tag_list = copyrights
42
  else:
43
+ tag_list = tags
44
+
45
+ if popular=="only_popular":
46
+ tag_list = list(set(tag_list) & set(populars))
47
 
48
+ return {k:sims[token_to_id[k]].item() for k in tag_list}
49
 
50
  demo = gr.Interface(
51
  fn=predict,