dk-davidekim's picture
Upload 7 files (#1)
2220c11
raw
history blame
7.39 kB
import requests
import streamlit as st
from streamlit_lottie import st_lottie
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
# Page Config: must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="노래 가사 n행시",
    page_icon="💌",
    layout="wide",
)
### Model
_MODEL_NAME = "wumusill/final_project_kogpt2"

# Streamlit re-executes this whole script on every widget interaction, so the
# tokenizer and the model are both loaded through a resource cache.
# `st.cache_resource` is used when this Streamlit version provides it; otherwise
# the legacy `st.cache` is used with allow_output_mutation=True so it does not
# hash the returned objects on every cache hit.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource(show_spinner=False)
else:
    _cache_resource = st.cache(show_spinner=False, allow_output_mutation=True)


@_cache_resource
def load_tokenizer():
    """Load (once per server process) the tokenizer for ``_MODEL_NAME``."""
    return AutoTokenizer.from_pretrained(_MODEL_NAME)


@_cache_resource
def load_model():
    """Load (once per server process) the causal LM used to write poem lines."""
    return AutoModelForCausalLM.from_pretrained(_MODEL_NAME)


tokenizer = load_tokenizer()
model = load_model()
# Class: lets the result dict hold the same letter as a key more than once.
class poem:
    """Wrapper around one input letter, used as a key of the result dict.

    No ``__eq__``/``__hash__`` is defined, so instances compare and hash by
    identity: two ``poem`` objects wrapping the same letter are different dict
    keys. This keeps one entry per position even when the input word repeats
    a syllable.
    """

    def __init__(self, letter):
        self.letter = letter

    def __str__(self):
        # Printing a key shows the bare letter.
        return self.letter

    def __repr__(self):
        # In a dict repr the key shows as the letter in single quotes.
        return "".join(("'", self.letter, "'"))
# Initial-sound law (두음 법칙): a word-initial syllable starting with ㄹ or ㄴ is
# written with a different initial consonant (e.g. "라" -> "나", "리" -> "이").
# Each poem line starts a new sentence, so the model is prompted with the
# converted syllable.
_DOOEUM = {
    "라": "나", "락": "낙", "란": "난", "랄": "날", "람": "남", "랍": "납", "랑": "낭",
    "래": "내", "랭": "냉", "냑": "약", "략": "약", "냥": "양", "량": "양", "녀": "여",
    "려": "여", "녁": "역", "력": "역", "년": "연", "련": "연", "녈": "열", "렬": "열",
    "념": "염", "렴": "염", "렵": "엽", "녕": "영", "령": "영", "녜": "예", "례": "예",
    "로": "노", "록": "녹", "론": "논", "롱": "농", "뢰": "뇌", "뇨": "요", "료": "요",
    "룡": "용", "루": "누", "뉴": "유", "류": "유", "뉵": "육", "륙": "육", "륜": "윤",
    "률": "율", "륭": "융", "륵": "늑", "름": "늠", "릉": "능", "니": "이", "리": "이",
    "린": "인", "림": "임", "립": "입",
}

# Sampling settings shared by every line of the poem.
_GENERATE_KWARGS = dict(
    do_sample=True,
    temperature=0.9,
    repetition_penalty=1.5,
    no_repeat_ngram_size=2,
)
# Per-line length limits in tokens, counted from the start of the line (its
# starting syllable included), not from the start of the whole prompt.
_LINE_MAX_LENGTH = 42
_LINE_MIN_LENGTH = 5


def _generate_line(prev_ids, syllable_ids):
    """Sample one line that starts with ``syllable_ids``.

    ``prev_ids`` (1-D token ids of the previous accepted line, possibly empty)
    is prepended as context so consecutive lines read more smoothly. The length
    limits are offset by the context length, so every line gets the same token
    budget however long the previous line was.

    Returns the 1-D token ids of the new line only (starting syllable plus the
    generated continuation), with any padding removed.
    """
    prev_len = prev_ids.shape[0]
    prompt = torch.cat((prev_ids, syllable_ids)).unsqueeze(0)
    output = model.generate(
        input_ids=prompt,
        max_length=prev_len + _LINE_MAX_LENGTH,
        min_length=prev_len + _LINE_MIN_LENGTH,
        **_GENERATE_KWARGS,
    )[0].tolist()
    # Padding may follow the generated text; cut the sequence at the first pad.
    if tokenizer.pad_token_id in output:
        output = output[:output.index(tokenizer.pad_token_id)]
    # Drop the previous-line context so only the new line remains.
    return torch.tensor(output[prev_len:], dtype=torch.long)


def n_line_poem(input_letter, max_attempts=10):
    """Generate an acrostic poem (n행시) for the Hangul string ``input_letter``.

    Line i starts with syllable i of ``input_letter``, after the initial-sound
    law conversion in ``_DOOEUM``. Each line is sampled with the previous
    accepted line as context. If the model produces nothing beyond the starting
    syllable, the line is sampled again, up to ``max_attempts`` times in total.
    After that, the last attempt is kept instead of retrying forever.

    Args:
        input_letter: word whose syllables start the lines, in order.
        max_attempts: maximum samples per line; values below 1 act as 1.

    Returns:
        dict mapping ``poem(letter)`` to the decoded line, in input order.
        Each key is a distinct ``poem`` object, so repeated syllables keep
        separate entries.
    """
    attempts = max(1, max_attempts)
    lines = []
    prev_ids = torch.empty(0, dtype=torch.long)
    for letter in input_letter:
        syllable = _DOOEUM.get(letter, letter)
        syllable_ids = tokenizer.encode(
            syllable, add_special_tokens=False, return_tensors="pt")[0]
        for _ in range(attempts):
            line_ids = _generate_line(prev_ids, syllable_ids)
            # Accept only if the model wrote something after the syllable.
            if len(line_ids) > len(syllable_ids):
                break
        lines.append(line_ids)
        # The accepted line, never a rejected attempt, is the next context.
        prev_ids = line_ids

    return {
        poem(letter): tokenizer.decode(
            ids, clean_up_tokenization_spaces=True, skip_special_tokens=True)
        for letter, ids in zip(input_letter, lines)
    }
###
# Header animation (Lottie JSON, rendered with streamlit_lottie)
@st.cache(show_spinner=False)
def load_lottieurl(url: str, timeout: float = 10):
    """Fetch and parse the Lottie animation JSON at ``url``.

    Returns the parsed JSON, or None if the request fails or times out, the
    status is not 200, or the body is not valid JSON. The result, including
    None, is cached by ``st.cache``.
    """
    try:
        r = requests.get(url, timeout=timeout)
    except requests.RequestException:
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # Body was not JSON (requests' JSON decode error subclasses ValueError).
        return None


lottie_url = "https://assets7.lottiefiles.com/private_files/lf30_fjln45y5.json"
lottie_json = load_lottieurl(lottie_url)
# Skip the animation instead of handing None to st_lottie when loading failed.
if lottie_json is not None:
    st_lottie(lottie_json, speed=1, height=200, key="initial")
# Title: wide title column on the left, narrow team-info column on the right.
row0_spacer1, row0_1, row0_spacer2, row0_2, row0_spacer3 = st.columns(
    (0.01, 2, 0.05, 0.5, 0.01)
)
with row0_1:
    st.markdown("# 한글 노래 가사 n행시✍")
    st.markdown("### 🦁멋쟁이사자처럼 AIS7🦁 - 파이널 프로젝트")
with row0_2:
    # Empty writes add vertical space above the team info (presumably to line
    # it up with the title).
    st.write("")
    st.write("")
    st.write("")
    st.subheader("1조 - 해파리")
    st.write("이지혜, 최지영, 권소희, 문종현, 구자현, 김영준")
st.write('---')
# Explanation: usage guide. The three equal width ratios put the guide in the
# middle third of the page.
row1_spacer1, row1_1, row1_spacer2 = st.columns((0.01, 0.01, 0.01))
with row1_1:
    st.markdown("### n행시 가이드라인")
    st.markdown("1. 하단에 있는 텍스트바에 5자 이하 한글 단어를 넣어주세요")
    st.markdown("2. 'n행시 제작하기' 버튼을 클릭해주세요")
st.write('---')
# Model & Input: word entry, validation, poem generation and display.
row2_spacer1, row2_1, row2_spacer2 = st.columns((0.01, 0.01, 0.01))
# Word Input
with row2_1:
    word_input = st.text_input(
        "n행시에 사용할 한글 단어를 적고 버튼을 눌러주세요.(최대 5자) 👇",
        placeholder='한글 단어를 입력해주세요',
        max_chars=5
    )
    # Keep only complete Hangul syllables (U+AC00 "가" .. U+D7A3 "힣"); bare
    # jamo, Latin letters, digits, spaces and punctuation are dropped.
    word_input = re.sub(r"[^가-힣]", "", word_input)
    if st.button('n행시 제작하기'):
        if word_input == "":
            st.error("온전한 한글 단어를 사용해주세요!")
        else:
            st.write("n행시 단어 : ", word_input)
            with st.spinner('잠시 기다려주세요...'):
                result = n_line_poem(word_input)
            st.success('완료됐습니다!')
            # One line per input syllable, in input order.
            for letter, line in result.items():
                st.write(f'{letter} : {line}')