|
import streamlit as lit |
|
import torch |
|
from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast |
|
|
|
@lit.cache(allow_output_mutation=True)
def loadModels():
    """Download and return the fine-tuned BART model and its tokenizer.

    Decorated with Streamlit's cache so the weights are fetched from the
    Hugging Face hub only once per server process and reused across reruns.

    Returns:
        tuple: (BartForConditionalGeneration, PreTrainedTokenizerFast)
    """
    repo_id = "rycont/biblify"
    bart = BartForConditionalGeneration.from_pretrained(repo_id)
    tok = PreTrainedTokenizerFast.from_pretrained(repo_id)

    print("Loaded :)")
    return bart, tok
|
|
|
# Page header (strings preserved byte-for-byte from the original source).
lit.title("์ฑ๊ฒฝ๋งํฌ ์์ฑ๊ธฐ")
lit.caption("ํ ๋ฌธ์ฅ์ ๊ฐ์ฅ ์ ๋ณํํฉ๋๋ค. ์ ๋๋ก ๋์ํ์ง ์๋ค๋ฉด ์๋ ๋งํฌ๋ก ์ด๋ํด์ฃผ์ธ์")
lit.caption("https://main-biblify-space-rycont.endpoint.ainize.ai/")

# BUG FIX: the original called loadModels() and discarded the return value,
# so the module-level names `model` and `tokenizer` used further below were
# never bound and the first form submission crashed with NameError.
model, tokenizer = loadModels()

# Fixed sequence length: inputs are padded (and masks extended) up to this.
MAX_LENGTH = 128
|
|
|
def biblifyWithBeams(beam, tokens, attention_mask):
    """Generate one Bible-styled paraphrase using beam search.

    Args:
        beam: number of beams for beam search.
        tokens: token ids, already padded to MAX_LENGTH.
        attention_mask: 0/1 mask aligned with `tokens`.

    Returns:
        The decoded string with the <s>/</s> markers stripped.
    """
    # BUG FIX: the original body read the module-level global `attentionMasks`
    # instead of the `attention_mask` parameter it was handed; the parameter
    # was silently ignored. (The only caller passes that same global, so the
    # fix does not change current behavior — it makes the function honest.)
    generated = model.generate(
        input_ids=torch.Tensor([tokens]).to(torch.int64),
        attention_mask=torch.Tensor([attention_mask]).to(torch.int64),
        num_beams=beam,
        max_length=MAX_LENGTH,
        eos_token_id=tokenizer.eos_token_id,
        # Forbid the unknown token so it never appears in the output.
        bad_words_ids=[[tokenizer.unk_token_id]],
    )[0]

    return tokenizer.decode(
        generated,
    ).replace('<s>', '').replace('</s>', '')
|
|
|
with lit.form("gen"):
    # NOTE(review): the label below is preserved byte-for-byte from the
    # original source (it appears mojibake-garbled; \x85 reproduces the raw
    # control character that was embedded inside the literal — confirm the
    # intended Korean label against the deployed app).
    text_input = lit.text_input("๋ฌธ์ฅ ์\x85๋ ฅ")
    submitted = lit.form_submit_button("์์ฑ")

if len(text_input.strip()) > 0:
    print(text_input)

    # Wrap the sentence in the BOS/EOS markers the model was trained with.
    text_input = "<s>" + text_input + "</s>"

    tokens = tokenizer.encode(text_input)
    # BUG FIX: inputs longer than MAX_LENGTH previously produced negative
    # padding counts below, yielding a tokens/mask length mismatch (and a
    # generate() failure). Truncate to the model's fixed length instead.
    tokens = tokens[:MAX_LENGTH]
    tokenLength = len(tokens)

    # Right-pad tokens with the pad id and mark real positions in the mask.
    attentionMasks = [1] * tokenLength + [0] * (MAX_LENGTH - tokenLength)
    tokens = tokens + [tokenizer.pad_token_id] * (MAX_LENGTH - tokenLength)

    results = []

    # Try beam widths 6..10 (the original spelled this range(10)[5:]) and
    # keep only the distinct outputs.
    for i in range(5, 10):
        generated = biblifyWithBeams(
            i + 1,
            tokens,
            attentionMasks
        )
        if generated in results:
            print("์ค๋ณต๋จ")
            continue

        results.append(generated)

        with lit.expander(str(len(results)) + "๋ฒ์งธ ๊ฒฐ๊ณผ (" + str(i + 1) + ")", True):
            lit.write(generated)
            print(generated)

    # Report how many of the 5 beam settings collapsed to duplicate outputs.
    lit.caption("๋ฐ " + str(5 - len(results)) + " ๊ฐ์ ์ค๋ณต๋ ๊ฒฐ๊ณผ")