furusu commited on
Commit
bed5a04
·
verified ·
1 Parent(s): 39476de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import timm
3
- import numpy as np
4
- import faiss
5
  import pandas as pd
6
 
7
 
@@ -11,28 +10,29 @@ DESCRIPTION = """
11
  """
12
 
13
  model = timm.create_model(f"hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=True)
14
- head = model.head.weight.data.cpu().numpy()
15
  del model
16
  df = pd.read_csv(f"https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3/resolve/main/selected_tags.csv")
17
  id2label = df["name"].to_dict()
18
  label2id = {v:k for k,v in id2label.items()}
 
 
 
19
 
20
- faiss.normalize_L2(head)
21
- index = faiss.IndexFlatIP(head.shape[1])
22
- index.add(head)
 
23
 
24
- def predict(target_tag):
25
- target_id = label2id[target_tag]
26
- query = head[target_id:target_id+1]
27
- k = 50
28
- target_id = label2id[target_tag]
29
- distances, indices = index.search(query, k)
30
- return {id2label[indice]:distance for indice, distance in zip(indices[0], distances[0])}
31
 
32
  demo = gr.Interface(
33
  fn=predict,
34
  inputs=[
35
- gr.Dropdown(list(label2id.keys()), label="Target tag", value="otoko_no_ko"),
 
36
  ],
37
  outputs=gr.Label(num_top_classes=50),
38
  title=TITLE,
 
1
  import gradio as gr
2
  import timm
3
+ import torch
 
4
  import pandas as pd
5
 
6
 
 
10
  """
11
 
12
  model = timm.create_model(f"hf_hub:SmilingWolf/wd-eva02-large-tagger-v3", pretrained=True)
13
+ head = model.head.weight.data
14
  del model
15
  df = pd.read_csv(f"https://huggingface.co/SmilingWolf/wd-eva02-large-tagger-v3/resolve/main/selected_tags.csv")
16
  id2label = df["name"].to_dict()
17
  label2id = {v:k for k,v in id2label.items()}
18
+ general_tags = df[df["category"] == 0].index
19
+ character_tags = df[df["category"] == 4].index
20
+ all_tags = df.index
21
 
22
+ def predict(target_tags, search_in):
23
+ target_tags = [tag.strip().replace(" ", "_") for tag in target_tags.split(",")]
24
+ target_ids = [label2id[tag] for tag in target_tags]
25
+ query = head[target_ids].unsqueeze(1)
26
 
27
+ sim = torch.cosine_similarity(query, head.unsqueeze(0), dim=2).mean(dim=0)
28
+ tags = general_tags if search_in == "general" else character_tags if search_in == "character" else all_tags
29
+ return {id2label[i]: sim[i].item() for i in tags}
 
 
 
 
30
 
31
  demo = gr.Interface(
32
  fn=predict,
33
  inputs=[
34
+ gr.Text(value="pink hair, braid", label="Target tags"),
35
+ gr.Dropdown(["all", "general", "character"], label="Search in", value="all")
36
  ],
37
  outputs=gr.Label(num_top_classes=50),
38
  title=TITLE,