Unggi's picture
first commit
eb1ba05
raw history blame
No virus
2.31 kB
import numpy as np
import itertools
from konlpy.tag import Okt
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import gradio as gr
# Install torch and transformers at startup when the environment lacks them.
import importlib
import importlib.util
import subprocess
import sys

import pip  # not used below; kept so any existing reference to `pip` still resolves


def _ensure_installed(package):
    """Install *package* with pip unless it is already importable.

    Args:
        package: Name used both as the importable top-level module and as
            the pip requirement (true for ``torch`` and ``transformers``).

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
    """
    if importlib.util.find_spec(package) is not None:
        return
    # pip has no supported in-process API; running `python -m pip` with the
    # current interpreter installs into the environment this script runs in.
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    # Let the import system notice the freshly installed package.
    importlib.invalidate_caches()


_ensure_installed("torch")
_ensure_installed("transformers")

import torch
import transformers
from transformers import BertTokenizerFast
from transformers import AutoModel
def make_candiadte(prompt):
    """Build keyword candidates: 2- and 3-word n-grams over the nouns of *prompt*.

    Args:
        prompt: Text to analyse.

    Returns:
        Array of n-gram strings as returned by
        ``CountVectorizer.get_feature_names_out``; an empty array when the
        nouns found in *prompt* do not yield a single n-gram.
    """
    okt = Okt()
    # okt.pos yields (token, tag) pairs; keep only the tokens tagged 'Noun'.
    tokenized_nouns = " ".join(
        token for token, tag in okt.pos(prompt) if tag == "Noun"
    )
    n_gram_range = (2, 3)
    try:
        count = CountVectorizer(ngram_range=n_gram_range).fit([tokenized_nouns])
    except ValueError:
        # fit() raises ValueError when the vocabulary would be empty, e.g. a
        # prompt with fewer than two usable nouns. Report "no candidates"
        # instead of crashing the caller.
        return np.array([], dtype=object)
    return count.get_feature_names_out()
# Loaded (model, tokenizer) pairs keyed by checkpoint name, so each checkpoint
# is loaded at most once per process instead of once per call.
_MODEL_CACHE = {}


def load_model(pretrained_model_name="kykim/bert-kor-base"):
    """Return the pretrained encoder and its tokenizer, loading them on first use.

    Args:
        pretrained_model_name: Checkpoint identifier passed to
            ``from_pretrained``. Defaults to the Korean BERT base checkpoint
            this script was written for.

    Returns:
        A ``(model, tokenizer)`` tuple. Repeated calls with the same name
        return the same objects.
    """
    if pretrained_model_name not in _MODEL_CACHE:
        tokenizer = BertTokenizerFast.from_pretrained(pretrained_model_name)
        model = AutoModel.from_pretrained(pretrained_model_name)
        _MODEL_CACHE[pretrained_model_name] = (model, tokenizer)
    return _MODEL_CACHE[pretrained_model_name]
# main
def inference(prompt):
    """Return the best keyword candidates of *prompt* as a hashtag string.

    Each candidate from ``make_candiadte`` is embedded with the model from
    ``load_model`` and ranked by the cosine similarity between its
    ``pooler_output`` and the ``pooler_output`` of the whole prompt.

    Args:
        prompt: Text to extract keywords from.

    Returns:
        Up to five candidates, most similar first, each prefixed with ``#``
        and joined by single spaces. An empty string when there are no
        candidates.
    """
    top_n = 5
    candidates = make_candiadte(prompt)
    model, tokenizer = load_model()

    def embed(text):
        # encode() gives a flat list of token ids; unsqueeze(0) adds a
        # leading batch dimension before the forward pass.
        input_ids = torch.tensor(tokenizer.encode(text)).unsqueeze(0)
        return model(input_ids)["pooler_output"]

    # Pure inference: no gradients are needed, so skip autograd tracking.
    with torch.no_grad():
        doc_embedding = embed(prompt)
        scored = [
            (
                torch.cosine_similarity(doc_embedding, embed(word), dim=1).item(),
                word,
            )
            for word in candidates
        ]

    # Rank in plain Python: pandas is not imported in this file, so building
    # a DataFrame here would fail with NameError.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return " ".join("#" + word for _, word in scored[:top_n])
# Text-in / text-out web UI around `inference`.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs="text",  # the return value of `inference`
)
# Keep `demo` bound to the Interface and launch it exactly once; binding the
# result of .launch() and then calling .launch() on that result fails because
# the result is not an Interface.
# Passing share=True to launch() creates a link reachable from outside.
demo.launch()