# Scraped from a Hugging Face Spaces page; the Space's status was shown as "Runtime error".
import json

import gradio as gr
import torch
# Vocabulary: a JSON object mapping index strings to tag strings (JSON object
# keys are always strings). The keys are assumed to be row indices into
# token_embeddings; predict() converts them with int() before indexing.
with open("num_to_token.json", "r", encoding="utf-8") as f:
    num_to_token = json.load(f)

# Reverse lookup: tag string -> index string. On duplicate tags the last one wins.
token_to_num = {v: k for k, v in num_to_token.items()}

# map_location="cpu" lets a tensor that was saved on a GPU machine load on
# CPU-only hardware; it is a no-op for a tensor that was saved on the CPU.
token_embeddings = torch.load("token_embeddings.pt", map_location="cpu")

# Alphabetically sorted tag list, used as the dropdown choices.
tags = sorted(num_to_token.values())
def predict(target_tag, sort_by="descend"):
    """Score every known tag by cosine similarity to ``target_tag``.

    Args:
        target_tag: A tag string present in ``token_to_num``.
        sort_by: ``"descending"`` (most similar first) or ``"ascending"``
            (least similar first). Any string starting with ``"desc"``
            (case-insensitive) selects descending order, including the
            default ``"descend"``. Any other value selects ascending order.

    Returns:
        dict mapping each tag in ``num_to_token`` to a float score.
        - For descending order, the score is the cosine similarity.
        - For ascending order, the score is the negated similarity. The
          output widget ranks entries by value, so negating the scores is
          what reverses the displayed order.

    Raises:
        KeyError: if ``target_tag`` is not a known tag.
    """
    # The old check `sort_by == "descending"` never matched the default
    # "descend", so a call without sort_by silently sorted ascending.
    descending = isinstance(sort_by, str) and sort_by.lower().startswith("desc")
    multiplier = 1 if descending else -1

    with torch.no_grad():  # Pure inference: no autograd graph needed.
        target_embedding = token_embeddings[int(token_to_num[target_tag])].unsqueeze(0)
        # (1, d) against (n, d) broadcasts to one similarity per row.
        sims = torch.cosine_similarity(target_embedding, token_embeddings, dim=1)
        # Convert once in bulk rather than calling .item() once per row.
        scores = (sims * multiplier).tolist()

    # Iterate the mapping itself, so index keys need not be contiguous "0".."n-1".
    return {token: scores[int(num)] for num, token in num_to_token.items()}
# Preferred initial selection for the tag dropdown. Fall back to the first
# available tag when the loaded vocabulary does not contain it, so the
# dropdown never starts on a value outside its own choices.
_PREFERRED_TAG = "otoko no ko"
default_tag = _PREFERRED_TAG if _PREFERRED_TAG in tags else (tags[0] if tags else None)

# The two dropdowns feed predict(target_tag, sort_by). The returned
# {tag: score} dict is rendered as a label list limited to the top 50 entries.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Dropdown(choices=tags, label="Target tag", value=default_tag),
        gr.Dropdown(
            choices=["ascending", "descending"],
            label="Sort by",
            value="descending",
        ),
    ],
    outputs=gr.Label(num_top_classes=50),
)
demo.launch()