# Hosting-page residue (not program code): Jonghyun Moon, "deploy streamlit", commit 26f786b, raw / history / blame links, 12.4 kB
import pandas as pd
import requests
import streamlit as st
from streamlit_lottie import st_lottie
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
# Page configuration: browser-tab title, tab icon and wide layout.
st.set_page_config(page_title="๋…ธ๋ž˜ ๊ฐ€์‚ฌ nํ–‰์‹œ Beta", page_icon="๐Ÿ’Œ", layout="wide")
### Model
@st.cache(show_spinner=False, allow_output_mutation=True)
def load_tokenizer():
    """Load the tokenizer for the fine-tuned KoGPT2 checkpoint once per process.

    Streamlit re-executes the script on every interaction, so the tokenizer
    is cached the same way the model is instead of being rebuilt each run.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    return AutoTokenizer.from_pretrained("wumusill/final_project_kogpt2")


@st.cache(show_spinner=False, allow_output_mutation=True)
def load_model():
    """Load the fine-tuned KoGPT2 causal language model once per process.

    ``allow_output_mutation=True`` stops ``st.cache`` from hashing the
    returned object on every cache hit to look for mutations.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    return AutoModelForCausalLM.from_pretrained("wumusill/final_project_kogpt2")


tokenizer = load_tokenizer()
model = load_model()
@st.cache(show_spinner=False)
def get_word():
    """Read ``ballad_word.csv`` (cp949-encoded) into a DataFrame.

    Returns:
        The DataFrame produced by ``pandas.read_csv``.
    """
    return pd.read_csv("ballad_word.csv", encoding="cp949")


word = get_word()
# Pick one random entry from column "0" that starts with the given syllable.
starts_with_syllable = word["0"].str.startswith("ํ•œ")
one = word[starts_with_syllable].sample(1).values[0][0]
# Key wrapper that lets a dict hold the same letter more than once.
class poem:
    """Wrap one letter so it can be used as a distinct dictionary key.

    The class defines neither ``__eq__`` nor ``__hash__``, so two instances
    wrapping the same letter are still different keys.
    """

    def __init__(self, letter):
        self.letter = letter

    def __str__(self):
        # Displayed as the bare letter.
        return self.letter

    def __repr__(self):
        # Displayed as the letter wrapped in single quotes.
        return "".join(("'", self.letter, "'"))
def beta_poem(input_letter):
    """Build an n-line poem whose lines start with words from the word table.

    For each character of ``input_letter`` a row of the module-level ``word``
    table that starts with that character is sampled (falling back to the
    bare character when no row can be sampled). The sampled word is appended
    to the lines generated so far and used as the prompt for
    ``model.generate``; only the part of the output that starts at the
    sampled word is kept as the new line.

    Args:
        input_letter: Word to write the poem for; iterated character by
            character.

    Returns:
        dict: ``{"Type": "beta"}`` followed by one ``poem(letter) -> line``
        entry per character, in input order. Keys are ``poem`` instances, so
        repeated characters get separate entries.
    """
    # Initial-sound-rule table: maps a syllable to the form used at the start of a word.
    dooeum = {"๋ผ":"๋‚˜", "๋ฝ":"๋‚™", "๋ž€":"๋‚œ", "๋ž„":"๋‚ ", "๋žŒ":"๋‚จ", "๋ž":"๋‚ฉ", "๋ž‘":"๋‚ญ",
              "๋ž˜":"๋‚ด", "๋žญ":"๋ƒ‰", "๋ƒ‘":"์•ฝ", "๋žต":"์•ฝ", "๋ƒฅ":"์–‘", "๋Ÿ‰":"์–‘", "๋…€":"์—ฌ",
              "๋ ค":"์—ฌ", "๋…":"์—ญ", "๋ ฅ":"์—ญ", "๋…„":"์—ฐ", "๋ จ":"์—ฐ", "๋…ˆ":"์—ด", "๋ ฌ":"์—ด",
              "๋…":"์—ผ", "๋ ด":"์—ผ", "๋ ต":"์—ฝ", "๋…•":"์˜", "๋ น":"์˜", "๋…œ":"์˜ˆ", "๋ก€":"์˜ˆ",
              "๋กœ":"๋…ธ", "๋ก":"๋…น", "๋ก ":"๋…ผ", "๋กฑ":"๋†", "๋ขฐ":"๋‡Œ", "๋‡จ":"์š”", "๋ฃŒ":"์š”",
              "๋ฃก":"์šฉ", "๋ฃจ":"๋ˆ„", "๋‰ด":"์œ ", "๋ฅ˜":"์œ ", "๋‰ต":"์œก", "๋ฅ™":"์œก", "๋ฅœ":"์œค",
              "๋ฅ ":"์œจ", "๋ฅญ":"์œต", "๋ฅต":"๋Š‘", "๋ฆ„":"๋Š ", "๋ฆ‰":"๋Šฅ", "๋‹ˆ":"์ด", "๋ฆฌ":"์ด",
              "๋ฆฐ":'์ธ', '๋ฆผ':'์ž„', '๋ฆฝ':'์ž…'}

    base_max_length = 42  # total-length budget used while the prompt is short
    min_room = 8          # tokens always left for generation after the prompt
    pad_id = tokenizer.pad_token_id
    special_ids = set(tokenizer.all_special_ids)  # hoisted: invariant across the loop

    # Decoded lines generated so far.
    res_l = []
    # Running count of non-special tokens in the lines generated so far.
    len_sequence = 0

    for idx, val in enumerate(input_letter):
        # Apply the initial-sound rule.
        val = dooeum.get(val, val)

        # Seed the line with a word from the ballad word table.
        try:
            one = word[word["0"].str.startswith(val)].sample(1).values[0][0]
        except (ValueError, IndexError):
            # Nothing to sample from (no row starts with this syllable):
            # fall back to the syllable itself.
            one = val

        # Prepend the previous lines so the new line reads smoothly after
        # them; their tokens are cut off again after generation.
        link_with_pre_sentence = (" ".join(res_l) + " " + one + " " if idx != 0 else one).strip()

        input_ids = tokenizer.encode(link_with_pre_sentence, add_special_tokens=False, return_tensors="pt")
        prompt_len = input_ids.shape[1]

        # max_length counts the prompt as well. Once the accumulated lines
        # reach the fixed budget there is no room left to generate (and newer
        # transformers releases reject such a prompt), so grow the budget
        # with the prompt.
        output_sequence = model.generate(
            input_ids=input_ids,
            do_sample=True,
            max_length=max(base_max_length, prompt_len + min_room),
            min_length=len_sequence + 2,
            temperature=0.9,
            repetition_penalty=1.5,
            no_repeat_ngram_size=2)

        # Token ids of the single returned sequence (prompt + continuation).
        generated_sequence = output_sequence.tolist()[0]

        # Cut at the first padding token, if any. When that would leave
        # almost nothing after the previous lines, keep a fixed-size window
        # instead.
        if pad_id in generated_sequence:
            check_index = generated_sequence.index(pad_id)
            if check_index - len_sequence <= 3:
                check_index = len_sequence + 8
            generated_sequence = generated_sequence[:check_index]

        # Locate where the seed word starts. Only the prompt region is
        # searched: the seed word is the end of the prompt, and a repetition
        # of its first token inside the generated text must not move the cut.
        one_ids = tokenizer.encode(one, add_special_tokens=False)
        word_encode = one_ids[0]
        prompt_region = generated_sequence[:prompt_len]
        try:
            split_index = len(prompt_region) - 1 - prompt_region[::-1].index(word_encode)
        except ValueError:
            # The token is not in the prompt region (e.g. it was tokenized
            # differently in context): estimate the boundary from lengths.
            split_index = max(prompt_len - len(one_ids), 0)

        # Drop the tokens of the previous lines so only the new line remains.
        generated_sequence = generated_sequence[split_index:]

        # Update the running length for the next character.
        len_sequence += sum(1 for token_id in generated_sequence if token_id not in special_ids)

        decoded_sequence = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        res_l.append(decoded_sequence)

    poem_dict = {"Type": "beta"}
    for letter, res in zip(input_letter, res_l):
        poem_dict[poem(letter)] = res
    return poem_dict
def alpha_poem(input_letter):
    """Build an n-line poem where the model continues each bare syllable.

    The first line is generated from the first character alone. Every later
    line is generated from the previous line's tokens followed by the current
    character, and the previous line's tokens are then cut off so that only
    the new line is kept. A result of one token or fewer is retried from the
    same prompt, a limited number of times.

    Args:
        input_letter: Word to write the poem for; iterated character by
            character.

    Returns:
        dict: ``{"Type": "alpha"}`` followed by one ``poem(letter) -> line``
        entry per character, in input order. Keys are ``poem`` instances, so
        repeated characters get separate entries.
    """
    # Initial-sound-rule table: maps a syllable to the form used at the start of a word.
    dooeum = {"๋ผ":"๋‚˜", "๋ฝ":"๋‚™", "๋ž€":"๋‚œ", "๋ž„":"๋‚ ", "๋žŒ":"๋‚จ", "๋ž":"๋‚ฉ", "๋ž‘":"๋‚ญ",
              "๋ž˜":"๋‚ด", "๋žญ":"๋ƒ‰", "๋ƒ‘":"์•ฝ", "๋žต":"์•ฝ", "๋ƒฅ":"์–‘", "๋Ÿ‰":"์–‘", "๋…€":"์—ฌ",
              "๋ ค":"์—ฌ", "๋…":"์—ญ", "๋ ฅ":"์—ญ", "๋…„":"์—ฐ", "๋ จ":"์—ฐ", "๋…ˆ":"์—ด", "๋ ฌ":"์—ด",
              "๋…":"์—ผ", "๋ ด":"์—ผ", "๋ ต":"์—ฝ", "๋…•":"์˜", "๋ น":"์˜", "๋…œ":"์˜ˆ", "๋ก€":"์˜ˆ",
              "๋กœ":"๋…ธ", "๋ก":"๋…น", "๋ก ":"๋…ผ", "๋กฑ":"๋†", "๋ขฐ":"๋‡Œ", "๋‡จ":"์š”", "๋ฃŒ":"์š”",
              "๋ฃก":"์šฉ", "๋ฃจ":"๋ˆ„", "๋‰ด":"์œ ", "๋ฅ˜":"์œ ", "๋‰ต":"์œก", "๋ฅ™":"์œก", "๋ฅœ":"์œค",
              "๋ฅ ":"์œจ", "๋ฅญ":"์œต", "๋ฅต":"๋Š‘", "๋ฆ„":"๋Š ", "๋ฆ‰":"๋Šฅ", "๋‹ˆ":"์ด", "๋ฆฌ":"์ด",
              "๋ฆฐ":'์ธ', '๋ฆผ':'์ž„', '๋ฆฝ':'์ž…'}

    base_max_length = 42  # total-length budget used while the prompt is short
    min_room = 8          # tokens always left for generation after the prompt
    max_attempts = 10     # upper bound on retries for a too-short line
    pad_id = tokenizer.pad_token_id

    # Token tensors of the generated lines, one per character.
    res_l = []
    # Tokens of the previous line; None while generating the first line.
    previous_line = None

    for val in input_letter:
        # Apply the initial-sound rule.
        val = dooeum.get(val, val)

        # Encode the current syllable (2-D tensor with one row).
        input_ids = tokenizer.encode(
            val, add_special_tokens=False, return_tensors="pt")

        if previous_line is None:
            # First character: the syllable alone is the prompt.
            prompt = input_ids
            context_len = 0
        else:
            # Later characters: previous line + syllable, for a smoother poem.
            prompt = torch.cat((previous_line, input_ids[0]), 0).reshape(1, -1)
            context_len = len(previous_line)
        prompt_len = prompt.shape[1]

        # Every attempt starts from the same prompt, so a rejected attempt
        # never replaces the previous line as context.
        for _ in range(max_attempts):
            # max_length counts the prompt as well; grow it when the prompt
            # would otherwise leave no room to generate (newer transformers
            # releases reject a prompt at or above max_length).
            output_sequence = model.generate(
                input_ids=prompt,
                do_sample=True,
                max_length=max(base_max_length, prompt_len + min_room),
                min_length=5,
                temperature=0.9,
                repetition_penalty=1.7,
                no_repeat_ngram_size=2)[0]

            generated_sequence = output_sequence.tolist()

            # Cut at the first padding token, if any.
            if pad_id in generated_sequence:
                generated_sequence = generated_sequence[:generated_sequence.index(pad_id)]

            # Remove the previous line's tokens so only the new line remains.
            generated_sequence = generated_sequence[context_len:]

            # Accept once the model produced more than the bare syllable.
            if len(generated_sequence) > 1:
                break

        # Explicit dtype keeps an empty result concatenable with token ids.
        previous_line = torch.tensor(generated_sequence, dtype=torch.long)
        res_l.append(previous_line)

    poem_dict = {"Type": "alpha"}
    for letter, res in zip(input_letter, res_l):
        decode_res = tokenizer.decode(res, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        poem_dict[poem(letter)] = decode_res
    return poem_dict
# Image(.gif)
@st.cache(show_spinner=False)
def load_lottieurl(url: str):
    """Download a Lottie animation and return its parsed JSON.

    Args:
        url: Address of the Lottie JSON file.

    Returns:
        The decoded JSON body, or ``None`` when the request fails, the
        response status is not 200, or the body is not valid JSON.
    """
    try:
        # A timeout keeps an unresponsive host from blocking the page load.
        r = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # JSON decoding errors are ValueError subclasses.
        return None
lottie_url = "https://assets7.lottiefiles.com/private_files/lf30_fjln45y5.json"
lottie_json = load_lottieurl(lottie_url)
# load_lottieurl returns None when the animation could not be fetched; skip
# the decoration in that case instead of handing None to st_lottie.
if lottie_json is not None:
    st_lottie(lottie_json, speed=1, height=200, key="initial")
# Title row: a wide column for the headings and a narrow one for the team
# credits, separated by thin spacer columns.
row0_spacer1, row0_1, row0_spacer2, row0_2, row0_spacer3 = st.columns(
    (0.01, 2, 0.05, 0.5, 0.01)
)

with row0_1:
    st.markdown("# ํ•œ๊ธ€ ๋…ธ๋ž˜ ๊ฐ€์‚ฌ nํ–‰์‹œโœ")
    st.markdown("### ๐Ÿฆ๋ฉ‹์Ÿ์ด์‚ฌ์ž์ฒ˜๋Ÿผ AIS7๐Ÿฆ - ํŒŒ์ด๋„ ํ”„๋กœ์ ํŠธ")

with row0_2:
    # Three empty writes push the credits down.
    for _ in range(3):
        st.write("")
    st.subheader("1์กฐ - ํ•ดํŒŒ๋ฆฌ")
    st.write("์ด์ง€ํ˜œ, ์ตœ์ง€์˜, ๊ถŒ์†Œํฌ, ๋ฌธ์ข…ํ˜„, ๊ตฌ์žํ˜„, ๊น€์˜์ค€")

# Horizontal rule under the title row.
st.write("---")
# Explanation: usage guidelines rendered in the middle of three equal columns.
row1_spacer1, row1_1, row1_spacer2 = st.columns((0.01, 0.01, 0.01))

with row1_1:
    st.markdown("### nํ–‰์‹œ ๊ฐ€์ด๋“œ๋ผ์ธ")
    st.markdown("1. ํ•˜๋‹จ์— ์žˆ๋Š” ํ…์ŠคํŠธ๋ฐ”์— 5์ž ์ดํ•˜ ๋‹จ์–ด๋ฅผ ๋„ฃ์–ด์ฃผ์„ธ์š”")
    st.markdown("2. 'nํ–‰์‹œ ์ œ์ž‘ํ•˜๊ธฐ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•ด์ฃผ์„ธ์š”")
    # Bullet list describing the two poem types, joined into one markdown string.
    type_notes = "\n".join((
        "* nํ–‰์‹œ ํƒ€์ž… ์„ค์ •",
        " * Alpha ver. : ๋ชจ๋ธ์ด ์ฒซ ์Œ์ ˆ๋ถ€ํ„ฐ ์ƒ์„ฑ",
        " * Beta ver. : ์ฒซ ์Œ์ ˆ์„ ๋ฐ์ดํ„ฐ์…‹์—์„œ ์ฐพ๊ณ , ๋‹ค์Œ ๋ถ€๋ถ„์„ ์ƒ์„ฑ",
    ))
    st.markdown(type_notes)

# Horizontal rule under the guidelines.
st.write("---")
# Model & Input
# NOTE(review): the nesting below is reconstructed from a source whose
# indentation was lost — confirm against the deployed layout.
row2_spacer1, row2_1, row2_spacer2= st.columns((0.01, 0.01, 0.01))
col1, col2 = st.columns(2)
# Word Input
with row2_1:
    with col1:
        # Radio button choosing which generator function handles the request.
        genre = st.radio(
            "nํ–‰์‹œ ํƒ€์ž… ์„ ํƒ",
            ('Alpha', 'Beta(test์ค‘)'))
        if genre == 'Alpha':
            n_line_poem = alpha_poem
        else:
            n_line_poem = beta_poem
    with col2:
        # Text box limited to 5 characters.
        word_input = st.text_input(
            "nํ–‰์‹œ์— ์‚ฌ์šฉํ•  ๋‹จ์–ด๋ฅผ ์ ๊ณ  ๋ฒ„ํŠผ์„ ๋ˆŒ๋Ÿฌ์ฃผ์„ธ์š”.(์ตœ๋Œ€ 5์ž) ๐Ÿ‘‡",
            placeholder='ํ•œ๊ธ€ ๋‹จ์–ด๋ฅผ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”',
            max_chars=5
        )
        # Strip every character outside the allowed range
        # (presumably keeps only Hangul syllables — confirm the pattern).
        word_input = re.sub("[^๊ฐ€-ํžฃ]", "", word_input)
        if st.button('nํ–‰์‹œ ์ œ์ž‘ํ•˜๊ธฐ'):
            # Nothing left after filtering: show an error instead of generating.
            if word_input == "":
                st.error("์˜จ์ „ํ•œ ํ•œ๊ธ€ ๋‹จ์–ด๋ฅผ ์‚ฌ์šฉํ•ด์ฃผ์„ธ์š”!")
            else:
                st.write("nํ–‰์‹œ ๋‹จ์–ด : ", word_input)
                # Show a spinner while the selected generator runs.
                with st.spinner('์ž ์‹œ ๊ธฐ๋‹ค๋ ค์ฃผ์„ธ์š”...'):
                    result = n_line_poem(word_input)
                st.success('์™„๋ฃŒ๋์Šต๋‹ˆ๋‹ค!')
                # Print every entry of the returned dict, including "Type".
                for r in result:
                    st.write(f'{r} : {result[r]}')