Spaces:
Runtime error
Runtime error
Commit
•
56d2f69
1
Parent(s):
3397cd1
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,90 @@
|
|
1 |
import streamlit as st
|
2 |
-
import numpy as np
|
3 |
-
import pandas as pd
|
4 |
-
import tensorflow as tf
|
5 |
-
import matplotlib.pyplot as plt
|
6 |
-
import requests
|
7 |
-
import re
|
8 |
from streamlit_lottie import st_lottie
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
st.set_page_config(
|
11 |
page_title="๋
ธ๋ ๊ฐ์ฌ nํ์",
|
12 |
page_icon="๐",
|
@@ -60,11 +138,6 @@ st.write('---')
|
|
60 |
# Model & Input
|
61 |
row2_spacer1, row2_1, row2_spacer2, row2_2, row2_spacer3 = st.columns((0.01, 1.5, 0.05, 1.5, 0.01))
|
62 |
|
63 |
-
# def load_model():
|
64 |
-
# return tf.keras.models.load_model('')
|
65 |
-
|
66 |
-
# model = load_model()
|
67 |
-
|
68 |
# Genre Selector
|
69 |
if "genre" not in st.session_state:
|
70 |
st.session_state.genre = "์ ์ฒด"
|
@@ -90,7 +163,7 @@ with row2_2:
|
|
90 |
st.write("nํ์ ๋จ์ด : ", word_input)
|
91 |
|
92 |
if st.button('nํ์ ์ ์ํ๊ธฐ'):
|
93 |
-
st.write(
|
94 |
|
95 |
|
96 |
|
|
|
1 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from streamlit_lottie import st_lottie
|
3 |
|
4 |
+
### Model
|
5 |
+
import torch
|
6 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
7 |
+
|
8 |
+
# Streamlit re-executes this script on every user interaction, so objects
# loaded from the Hub are cached as process-wide resources.
#
# `st.cache_resource` is the API meant for unserialisable singletons such as
# ML models. On Streamlit versions that predate it, fall back to the legacy
# `st.cache` with `allow_output_mutation=True`, which stops Streamlit from
# hashing the whole model on every call to detect mutation.
#
# The spinner is disabled because these loaders run before
# `st.set_page_config(...)` further down the script, and that call is
# documented as having to be the first Streamlit command on the page.
# NOTE(review): confirm the spinner counts as a command on the deployed version.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource(show_spinner=False)
else:
    _cache_resource = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_resource
def load_tokenizer():
    """Return the tokenizer for "wumusill/final_20man", loaded once per process."""
    return AutoTokenizer.from_pretrained("wumusill/final_20man")


@_cache_resource
def load_model():
    """Return the causal LM "wumusill/final_20man", loaded once per process."""
    return AutoModelForCausalLM.from_pretrained("wumusill/final_20man")


tokenizer = load_tokenizer()
model = load_model()
|
16 |
+
|
17 |
+
def mind(input_letter, max_new_tokens=42):
    """Generate an acrostic: one sampled sentence per character of the input.

    Each character is used as the start of a new sentence. From the second
    character on, the previously generated sentences are prepended to the
    prompt so the new sentence can follow on from them; that context is
    removed again before the sentence is stored.

    Args:
        input_letter: The word to build the acrostic from (iterated
            character by character).
        max_new_tokens: Generation budget per character, counted on top of
            the prompt length.

    Returns:
        A single string with one "<letter> : <sentence>" line per character,
        joined with Markdown hard line breaks so that ``st.write`` renders
        each on its own line. An empty input yields an empty string.
    """
    # Sentences generated so far, in input order.
    res_l = []

    for val in input_letter:
        # The first character is its own prompt; later characters are
        # appended to everything generated so far.
        previous_text = " ".join(res_l)
        prompt = previous_text + " " + val if res_l else val

        input_ids = tokenizer.encode(
            prompt, add_special_tokens=False, return_tensors="pt")

        # `max_length` counts the prompt tokens too. A fixed limit would leave
        # later characters with little or no room once the accumulated
        # context grows, so the limit is expressed relative to the prompt.
        output_sequence = model.generate(
            input_ids,
            do_sample=True,
            max_length=input_ids.shape[-1] + max_new_tokens)

        # Token ids of prompt + continuation, possibly followed by padding.
        generated_sequence = output_sequence.tolist()[0]

        # Drop trailing padding only when it is actually there: generation
        # may stop at the length limit without emitting a pad token, and the
        # tokenizer may not define one at all.
        pad_id = tokenizer.pad_token_id
        if pad_id is not None and pad_id in generated_sequence:
            generated_sequence = generated_sequence[:generated_sequence.index(pad_id)]

        if res_l:
            # Strip the context by its token length instead of searching for
            # a single token id, which can match too early or not at all.
            # Assumes the space before `val` is a token boundary, so the
            # context tokenizes the same alone and as a prompt prefix
            # -- TODO confirm for this tokenizer.
            prefix_len = len(
                tokenizer.encode(previous_text, add_special_tokens=False))
            generated_sequence = generated_sequence[prefix_len:]

        decoded_sequence = tokenizer.decode(
            generated_sequence, clean_up_tokenization_spaces=True)
        res_l.append(decoded_sequence)

    # Build every line before returning; returning from inside the loop
    # would only ever yield the first character's sentence.
    return "  \n".join(
        f"{letter} : {res}" for letter, res in zip(input_letter, res_l))
|
84 |
+
|
85 |
+
###
|
86 |
+
|
87 |
+
|
88 |
st.set_page_config(
|
89 |
page_title="๋
ธ๋ ๊ฐ์ฌ nํ์",
|
90 |
page_icon="๐",
|
|
|
138 |
# Model & Input
|
139 |
row2_spacer1, row2_1, row2_spacer2, row2_2, row2_spacer3 = st.columns((0.01, 1.5, 0.05, 1.5, 0.01))
|
140 |
|
|
|
|
|
|
|
|
|
|
|
141 |
# Genre Selector
|
142 |
if "genre" not in st.session_state:
|
143 |
st.session_state.genre = "์ ์ฒด"
|
|
|
163 |
st.write("nํ์ ๋จ์ด : ", word_input)
|
164 |
|
165 |
if st.button('nํ์ ์ ์ํ๊ธฐ'):
|
166 |
+
st.write(mind(word_input))
|
167 |
|
168 |
|
169 |
|