# NOTE(review): the three lines that were here ("Spaces:" / "Runtime error" x2)
# are a captured Hugging Face Spaces status banner, not source code; kept as a
# comment so the module remains parseable.
import re

import pandas as pd
import requests
import streamlit as st
import torch
from streamlit_lottie import st_lottie
from transformers import AutoTokenizer, AutoModelForCausalLM
# Page Config | |
st.set_page_config( | |
page_title="๋ ธ๋ ๊ฐ์ฌ nํ์ Beta", | |
page_icon="๐", | |
layout="wide" | |
) | |
# st.text(os.listdir(os.curdir)) | |
### Model | |
tokenizer = AutoTokenizer.from_pretrained("wumusill/final_project_kogpt2") | |
def load_model(): | |
model = AutoModelForCausalLM.from_pretrained("wumusill/final_project_kogpt2") | |
return model | |
model = load_model() | |
def get_word(): | |
word = pd.read_csv("ballad_word.csv", encoding="cp949") | |
return word | |
word = get_word() | |
one = word[word["0"].str.startswith("ํ")].sample(1).values[0][0] | |
# st.header(type(one)) | |
# st.header(one) | |
# Class : Dict ์ค๋ณต ํค ์ถ๋ ฅ | |
class poem(object): | |
def __init__(self,letter): | |
self.letter = letter | |
def __str__(self): | |
return self.letter | |
def __repr__(self): | |
return "'"+self.letter+"'" | |
def beta_poem(input_letter): | |
# ๋์ ๋ฒ์น ์ฌ์ | |
dooeum = {"๋ผ":"๋", "๋ฝ":"๋", "๋":"๋", "๋":"๋ ", "๋":"๋จ", "๋":"๋ฉ", "๋":"๋ญ", | |
"๋":"๋ด", "๋ญ":"๋", "๋":"์ฝ", "๋ต":"์ฝ", "๋ฅ":"์", "๋":"์", "๋ ":"์ฌ", | |
"๋ ค":"์ฌ", "๋ ":"์ญ", "๋ ฅ":"์ญ", "๋ ":"์ฐ", "๋ จ":"์ฐ", "๋ ":"์ด", "๋ ฌ":"์ด", | |
"๋ ":"์ผ", "๋ ด":"์ผ", "๋ ต":"์ฝ", "๋ ":"์", "๋ น":"์", "๋ ":"์", "๋ก":"์", | |
"๋ก":"๋ ธ", "๋ก":"๋ น", "๋ก ":"๋ ผ", "๋กฑ":"๋", "๋ขฐ":"๋", "๋จ":"์", "๋ฃ":"์", | |
"๋ฃก":"์ฉ", "๋ฃจ":"๋", "๋ด":"์ ", "๋ฅ":"์ ", "๋ต":"์ก", "๋ฅ":"์ก", "๋ฅ":"์ค", | |
"๋ฅ ":"์จ", "๋ฅญ":"์ต", "๋ฅต":"๋", "๋ฆ":"๋ ", "๋ฆ":"๋ฅ", "๋":"์ด", "๋ฆฌ":"์ด", | |
"๋ฆฐ":'์ธ', '๋ฆผ':'์', '๋ฆฝ':'์ '} | |
# ๊ฒฐ๊ณผ๋ฌผ์ ๋ด์ list | |
res_l = [] | |
len_sequence = 0 | |
# ํ ๊ธ์์ฉ ์ธ๋ฑ์ค์ ํจ๊ป ๊ฐ์ ธ์ด | |
for idx, val in enumerate(input_letter): | |
# ๋์ ๋ฒ์น ์ ์ฉ | |
if val in dooeum.keys(): | |
val = dooeum[val] | |
# ๋ฐ๋ผ๋์ ์๋ ๋จ์ด ์ ์ฉ | |
try: | |
one = word[word["0"].str.startswith(val)].sample(1).values[0][0] | |
# st.text(one) | |
except: | |
one = val | |
# ์ข๋ ๋งค๋๋ฌ์ด ์ผํ์๋ฅผ ์ํด ์ด์ ๋ฌธ์ฅ์ด๋ ํ์ฌ ์์ ์ฐ๊ฒฐ | |
# ์ดํ generate ๋ ๋ฌธ์ฅ์์ ์ด์ ๋ฌธ์ฅ์ ๋ํ ๋ฐ์ดํฐ ์ ๊ฑฐ | |
link_with_pre_sentence = (" ".join(res_l)+ " " + one + " " if idx != 0 else one).strip() | |
# print(link_with_pre_sentence) | |
# ์ฐ๊ฒฐ๋ ๋ฌธ์ฅ์ ์ธ์ฝ๋ฉ | |
input_ids = tokenizer.encode(link_with_pre_sentence, add_special_tokens=False, return_tensors="pt") | |
# ์ธ์ฝ๋ฉ ๊ฐ์ผ๋ก ๋ฌธ์ฅ ์์ฑ | |
output_sequence = model.generate( | |
input_ids=input_ids, | |
do_sample=True, | |
max_length=42, | |
min_length=len_sequence + 2, | |
temperature=0.9, | |
repetition_penalty=1.5, | |
no_repeat_ngram_size=2) | |
# ์์ฑ๋ ๋ฌธ์ฅ ๋ฆฌ์คํธ๋ก ๋ณํ (์ธ์ฝ๋ฉ ๋์ด์๊ณ , ์์ฑ๋ ๋ฌธ์ฅ ๋ค๋ก padding ์ด ์๋ ์ํ) | |
generated_sequence = output_sequence.tolist()[0] | |
# padding index ์๊น์ง slicing ํจ์ผ๋ก์จ padding ์ ๊ฑฐ, padding์ด ์์ ์๋ ์๊ธฐ ๋๋ฌธ์ ์กฐ๊ฑด๋ฌธ ํ์ธ ํ ์ ๊ฑฐ | |
# ์ฌ์ฉํ generated_sequence ๊ฐ 5๋ณด๋ค ์งง์ผ๋ฉด ๊ฐ์ ์ ์ผ๋ก ๊ธธ์ด๋ฅผ 8๋ก ํด์ค๋ค... | |
if tokenizer.pad_token_id in generated_sequence: | |
check_index = generated_sequence.index(tokenizer.pad_token_id) | |
check_index = check_index if check_index-len_sequence > 3 else len_sequence + 8 | |
generated_sequence = generated_sequence[:check_index] | |
word_encode = tokenizer.encode(one, add_special_tokens=False, return_tensors="pt").tolist()[0][0] | |
split_index = len(generated_sequence) - 1 - generated_sequence[::-1].index(word_encode) | |
# ์ฒซ ๊ธ์๊ฐ ์๋๋ผ๋ฉด, generate ๋ ์์ ๋ง ๊ฒฐ๊ณผ๋ฌผ list์ ๋ค์ด๊ฐ ์ ์๊ฒ ์ ๋ฌธ์ฅ์ ๋ํ ์ธ์ฝ๋ฉ ๊ฐ ์ ๊ฑฐ | |
generated_sequence = generated_sequence[split_index:] | |
# print(tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)) | |
# ๋ค์ ์์ ์ ์ํด ๊ธธ์ด ๊ฐฑ์ | |
len_sequence += len([elem for elem in generated_sequence if elem not in(tokenizer.all_special_ids)]) | |
# ๊ฒฐ๊ณผ๋ฌผ ๋์ฝ๋ฉ | |
decoded_sequence = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True) | |
# ๊ฒฐ๊ณผ๋ฌผ ๋ฆฌ์คํธ์ ๋ด๊ธฐ | |
res_l.append(decoded_sequence) | |
poem_dict = {"Type":"beta"} | |
for letter, res in zip(input_letter, res_l): | |
# decode_res = tokenizer.decode(res, clean_up_tokenization_spaces=True, skip_special_tokens=True) | |
poem_dict[poem(letter)] = res | |
return poem_dict | |
def alpha_poem(input_letter): | |
# ๋์ ๋ฒ์น ์ฌ์ | |
dooeum = {"๋ผ":"๋", "๋ฝ":"๋", "๋":"๋", "๋":"๋ ", "๋":"๋จ", "๋":"๋ฉ", "๋":"๋ญ", | |
"๋":"๋ด", "๋ญ":"๋", "๋":"์ฝ", "๋ต":"์ฝ", "๋ฅ":"์", "๋":"์", "๋ ":"์ฌ", | |
"๋ ค":"์ฌ", "๋ ":"์ญ", "๋ ฅ":"์ญ", "๋ ":"์ฐ", "๋ จ":"์ฐ", "๋ ":"์ด", "๋ ฌ":"์ด", | |
"๋ ":"์ผ", "๋ ด":"์ผ", "๋ ต":"์ฝ", "๋ ":"์", "๋ น":"์", "๋ ":"์", "๋ก":"์", | |
"๋ก":"๋ ธ", "๋ก":"๋ น", "๋ก ":"๋ ผ", "๋กฑ":"๋", "๋ขฐ":"๋", "๋จ":"์", "๋ฃ":"์", | |
"๋ฃก":"์ฉ", "๋ฃจ":"๋", "๋ด":"์ ", "๋ฅ":"์ ", "๋ต":"์ก", "๋ฅ":"์ก", "๋ฅ":"์ค", | |
"๋ฅ ":"์จ", "๋ฅญ":"์ต", "๋ฅต":"๋", "๋ฆ":"๋ ", "๋ฆ":"๋ฅ", "๋":"์ด", "๋ฆฌ":"์ด", | |
"๋ฆฐ":'์ธ', '๋ฆผ':'์', '๋ฆฝ':'์ '} | |
# ๊ฒฐ๊ณผ๋ฌผ์ ๋ด์ list | |
res_l = [] | |
# ํ ๊ธ์์ฉ ์ธ๋ฑ์ค์ ํจ๊ป ๊ฐ์ ธ์ด | |
for idx, val in enumerate(input_letter): | |
# ๋์ ๋ฒ์น ์ ์ฉ | |
if val in dooeum.keys(): | |
val = dooeum[val] | |
while True: | |
# ๋ง์ฝ idx ๊ฐ 0 ์ด๋ผ๋ฉด == ์ฒซ ๊ธ์ | |
if idx == 0: | |
# ์ฒซ ๊ธ์ ์ธ์ฝ๋ฉ | |
input_ids = tokenizer.encode( | |
val, add_special_tokens=False, return_tensors="pt") | |
# print(f"{idx}๋ฒ ์ธ์ฝ๋ฉ : {input_ids}\n") # 2์ฐจ์ ํ ์ | |
# ์ฒซ ๊ธ์ ์ธ์ฝ๋ฉ ๊ฐ์ผ๋ก ๋ฌธ์ฅ ์์ฑ | |
output_sequence = model.generate( | |
input_ids=input_ids, | |
do_sample=True, | |
max_length=42, | |
min_length=5, | |
temperature=0.9, | |
repetition_penalty=1.7, | |
no_repeat_ngram_size=2)[0] | |
# print("์ฒซ ๊ธ์ ์ธ์ฝ๋ฉ ํ generate ๊ฒฐ๊ณผ:", output_sequence, "\n") # tensor | |
# ์ฒซ ๊ธ์๊ฐ ์๋๋ผ๋ฉด | |
else: | |
# ํ ์์ | |
input_ids = tokenizer.encode( | |
val, add_special_tokens=False, return_tensors="pt") | |
# print(f"{idx}๋ฒ ์งธ ๊ธ์ ์ธ์ฝ๋ฉ : {input_ids} \n") | |
# ์ข๋ ๋งค๋๋ฌ์ด ์ผํ์๋ฅผ ์ํด ์ด์ ์ธ์ฝ๋ฉ๊ณผ ์ง๊ธ ์ธ์ฝ๋ฉ ์ฐ๊ฒฐ | |
link_with_pre_sentence = torch.cat((generated_sequence, input_ids[0]), 0) | |
link_with_pre_sentence = torch.reshape(link_with_pre_sentence, (1, len(link_with_pre_sentence))) | |
# print(f"์ด์ ํ ์์ ์ฐ๊ฒฐ๋ ํ ์ {link_with_pre_sentence} \n") | |
# ์ธ์ฝ๋ฉ ๊ฐ์ผ๋ก ๋ฌธ์ฅ ์์ฑ | |
output_sequence = model.generate( | |
input_ids=link_with_pre_sentence, | |
do_sample=True, | |
max_length=42, | |
min_length=5, | |
temperature=0.9, | |
repetition_penalty=1.7, | |
no_repeat_ngram_size=2)[0] | |
# print(f"{idx}๋ฒ ์ธ์ฝ๋ฉ ํ generate : {output_sequence}") | |
# ์์ฑ๋ ๋ฌธ์ฅ ๋ฆฌ์คํธ๋ก ๋ณํ (์ธ์ฝ๋ฉ ๋์ด์๊ณ , ์์ฑ๋ ๋ฌธ์ฅ ๋ค๋ก padding ์ด ์๋ ์ํ) | |
generated_sequence = output_sequence.tolist() | |
# print(f"{idx}๋ฒ ์ธ์ฝ๋ฉ ๋ฆฌ์คํธ : {generated_sequence} \n") | |
# padding index ์๊น์ง slicing ํจ์ผ๋ก์จ padding ์ ๊ฑฐ, padding์ด ์์ ์๋ ์๊ธฐ ๋๋ฌธ์ ์กฐ๊ฑด๋ฌธ ํ์ธ ํ ์ ๊ฑฐ | |
if tokenizer.pad_token_id in generated_sequence: | |
generated_sequence = generated_sequence[:generated_sequence.index(tokenizer.pad_token_id)] | |
generated_sequence = torch.tensor(generated_sequence) | |
# print(f"{idx}๋ฒ ์ธ์ฝ๋ฉ ๋ฆฌ์คํธ ํจ๋ฉ ์ ๊ฑฐ ํ ๋ค์ ํ ์ : {generated_sequence} \n") | |
# ์ฒซ ๊ธ์๊ฐ ์๋๋ผ๋ฉด, generate ๋ ์์ ๋ง ๊ฒฐ๊ณผ๋ฌผ list์ ๋ค์ด๊ฐ ์ ์๊ฒ ์ ๋ฌธ์ฅ์ ๋ํ ์ธ์ฝ๋ฉ ๊ฐ ์ ๊ฑฐ | |
# print(generated_sequence) | |
if idx != 0: | |
# ์ด์ ๋ฌธ์ฅ์ ๊ธธ์ด ์ดํ๋ก ์ฌ๋ผ์ด์ฑํด์ ์ ๋ฌธ์ฅ ์ ๊ฑฐ | |
generated_sequence = generated_sequence[len_sequence:] | |
len_sequence = len(generated_sequence) | |
# print("len_seq", len_sequence) | |
# ์์ ๊ทธ๋๋ก ๋ฑ์ผ๋ฉด ๋ค์ ํด์, ์๋๋ฉด while๋ฌธ ํ์ถ | |
if len_sequence > 1: | |
break | |
# ๊ฒฐ๊ณผ๋ฌผ ๋ฆฌ์คํธ์ ๋ด๊ธฐ | |
res_l.append(generated_sequence) | |
poem_dict = {"Type":"alpha"} | |
for letter, res in zip(input_letter, res_l): | |
decode_res = tokenizer.decode(res, clean_up_tokenization_spaces=True, skip_special_tokens=True) | |
poem_dict[poem(letter)] = decode_res | |
return poem_dict | |
# Image(.gif) | |
def load_lottieurl(url: str): | |
r = requests.get(url) | |
if r.status_code != 200: | |
return None | |
return r.json() | |
lottie_url = "https://assets7.lottiefiles.com/private_files/lf30_fjln45y5.json" | |
lottie_json = load_lottieurl(lottie_url) | |
st_lottie(lottie_json, speed=1, height=200, key="initial") | |
# Title | |
row0_spacer1, row0_1, row0_spacer2, row0_2, row0_spacer3 = st.columns( | |
(0.01, 2, 0.05, 0.5, 0.01) | |
) | |
with row0_1: | |
st.markdown("# ํ๊ธ ๋ ธ๋ ๊ฐ์ฌ nํ์โ") | |
st.markdown("### ๐ฆ๋ฉ์์ด์ฌ์์ฒ๋ผ AIS7๐ฆ - ํ์ด๋ ํ๋ก์ ํธ") | |
with row0_2: | |
st.write("") | |
st.write("") | |
st.write("") | |
st.subheader("1์กฐ - ํดํ๋ฆฌ") | |
st.write("์ด์งํ, ์ต์ง์, ๊ถ์ํฌ, ๋ฌธ์ข ํ, ๊ตฌ์ํ, ๊น์์ค") | |
st.write('---') | |
# Explanation | |
row1_spacer1, row1_1, row1_spacer2 = st.columns((0.01, 0.01, 0.01)) | |
with row1_1: | |
st.markdown("### nํ์ ๊ฐ์ด๋๋ผ์ธ") | |
st.markdown("1. ํ๋จ์ ์๋ ํ ์คํธ๋ฐ์ 5์ ์ดํ ๋จ์ด๋ฅผ ๋ฃ์ด์ฃผ์ธ์") | |
st.markdown("2. 'nํ์ ์ ์ํ๊ธฐ' ๋ฒํผ์ ํด๋ฆญํด์ฃผ์ธ์") | |
st.markdown("* nํ์ ํ์ ์ค์ \n" | |
" * Alpha ver. : ๋ชจ๋ธ์ด ์ฒซ ์์ ๋ถํฐ ์์ฑ\n" | |
" * Beta ver. : ์ฒซ ์์ ์ ๋ฐ์ดํฐ์ ์์ ์ฐพ๊ณ , ๋ค์ ๋ถ๋ถ์ ์์ฑ") | |
st.write('---') | |
# Model & Input | |
row2_spacer1, row2_1, row2_spacer2= st.columns((0.01, 0.01, 0.01)) | |
col1, col2 = st.columns(2) | |
# Word Input | |
with row2_1: | |
with col1: | |
genre = st.radio( | |
"nํ์ ํ์ ์ ํ", | |
('Alpha', 'Beta(test์ค)')) | |
if genre == 'Alpha': | |
n_line_poem = alpha_poem | |
else: | |
n_line_poem = beta_poem | |
with col2: | |
word_input = st.text_input( | |
"nํ์์ ์ฌ์ฉํ ๋จ์ด๋ฅผ ์ ๊ณ ๋ฒํผ์ ๋๋ฌ์ฃผ์ธ์.(์ต๋ 5์) ๐", | |
placeholder='ํ๊ธ ๋จ์ด๋ฅผ ์ ๋ ฅํด์ฃผ์ธ์', | |
max_chars=5 | |
) | |
word_input = re.sub("[^๊ฐ-ํฃ]", "", word_input) | |
if st.button('nํ์ ์ ์ํ๊ธฐ'): | |
if word_input == "": | |
st.error("์จ์ ํ ํ๊ธ ๋จ์ด๋ฅผ ์ฌ์ฉํด์ฃผ์ธ์!") | |
else: | |
st.write("nํ์ ๋จ์ด : ", word_input) | |
with st.spinner('์ ์ ๊ธฐ๋ค๋ ค์ฃผ์ธ์...'): | |
result = n_line_poem(word_input) | |
st.success('์๋ฃ๋์ต๋๋ค!') | |
for r in result: | |
st.write(f'{r} : {result[r]}') | |