"""Gradio demo that ranks Danbooru tags by embedding cosine similarity.

At import time the script loads a token vocabulary, several tag lists and a
token-embedding tensor from files in the working directory, then builds and
launches a Gradio interface around :func:`predict`.
"""

import json

import gradio as gr
import torch

TITLE = "Danboru Tag Similarity"
DESCRIPTION = """
与えられたダンボールタグの類似度を計算します。\n
対応するタグのリストはFilesからそれぞれのテキストファイルを参照してください。(Dartと同じです)。\n
Dartを参考に、isek-ai/danbooru-tags-2023データセットでタグをシャッフルして2エポック学習しました。\n
学習後のトークン埋め込みを元に計算しています。
"""


def _read_lines(path):
    """Return the lines of the UTF-8 text file at *path* without line endings."""
    # Explicit encoding: tag names may contain non-ASCII characters and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(path, "r", encoding="utf-8") as f:
        return f.read().splitlines()


# Vocabulary: the JSON maps stringified token ids to token strings; the
# reverse mapping (token string -> int id) is what the lookups below need.
with open("id_to_token.json", "r", encoding="utf-8") as f:
    id_to_token = json.load(f)
token_to_id = {v: int(k) for k, v in id_to_token.items()}

populars = _read_lines("popular.txt")
characters = _read_lines("character.txt")
copyrights = _read_lines("copyright.txt")
generals = _read_lines("general.txt")
tags = characters + copyrights + generals

# Built once so the "only_popular" filter does not rebuild a set per request.
_POPULAR_TAGS = frozenset(populars)

# Category name (as offered by the dropdown) -> candidate tag list. Any other
# value, including "all", falls back to the combined `tags` list.
_TAGS_BY_CATEGORY = {
    "general": generals,
    "character": characters,
    "copyright": copyrights,
}

# NOTE: torch.load unpickles its input; only load files from a trusted source.
token_embeddings = torch.load("token_embeddings.pt")


def predict(target_tag, sort_by, category, popular):
    """Score candidate tags by cosine similarity to ``target_tag``.

    Args:
        target_tag: Tag to compare against. It must be a key of
            ``token_to_id``; if the exact string is not found, the
            whitespace-trimmed string is tried before giving up.
        sort_by: ``"descending"`` keeps the similarities as they are; any
            other value negates them, so the least similar tags receive the
            highest scores.
        category: ``"general"``, ``"character"`` or ``"copyright"`` restricts
            the candidates to that list; any other value uses all tags.
        popular: ``"only_popular"`` keeps only candidates that also appear in
            ``populars``; any other value applies no filter.

    Returns:
        A dict mapping each candidate tag to its (possibly negated) cosine
        similarity as a Python float.

    Raises:
        gr.Error: If ``target_tag`` is not in the vocabulary.
    """
    token_id = token_to_id.get(target_tag)
    if token_id is None:
        # Tolerate stray leading/trailing whitespace from the text box.
        token_id = token_to_id.get((target_tag or "").strip())
    if token_id is None:
        raise gr.Error(f"Unknown tag: {target_tag!r}")

    sign = 1 if sort_by == "descending" else -1

    # Inference only: no autograd graph is needed for the similarity scores.
    with torch.no_grad():
        target_embedding = token_embeddings[token_id].unsqueeze(0)
        sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1) * sign

    # One bulk conversion to Python floats instead of one `.item()` per tag.
    scores = sims.tolist()

    tag_list = _TAGS_BY_CATEGORY.get(category, tags)
    if popular == "only_popular":
        tag_list = [tag for tag in tag_list if tag in _POPULAR_TAGS]

    return {tag: scores[token_to_id[tag]] for tag in tag_list}


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Target tag", value="otoko no ko"),
        gr.Radio(choices=["descending", "ascending"], label="Sort by", value="descending"),
        gr.Dropdown(choices=["all", "general", "character", "copyright"], value="all", label="category"),
        gr.Radio(choices=["all", "only_popular"], label="Only popular tag (count>=1000)", value="all"),
    ],
    outputs=gr.Label(num_top_classes=50),
    title=TITLE,
    description=DESCRIPTION,
)

demo.launch()