dk-davidekim commited on
Commit
56d2f69
โ€ข
1 Parent(s): 3397cd1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -12
app.py CHANGED
@@ -1,12 +1,90 @@
1
  import streamlit as st
2
- import numpy as np
3
- import pandas as pd
4
- import tensorflow as tf
5
- import matplotlib.pyplot as plt
6
- import requests
7
- import re
8
  from streamlit_lottie import st_lottie
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  st.set_page_config(
11
  page_title="๋…ธ๋ž˜ ๊ฐ€์‚ฌ nํ–‰์‹œ",
12
  page_icon="๐Ÿ’Œ",
@@ -60,11 +138,6 @@ st.write('---')
60
  # Model & Input
61
  row2_spacer1, row2_1, row2_spacer2, row2_2, row2_spacer3 = st.columns((0.01, 1.5, 0.05, 1.5, 0.01))
62
 
63
- # def load_model():
64
- # return tf.keras.models.load_model('')
65
-
66
- # model = load_model()
67
-
68
  # Genre Selector
69
  if "genre" not in st.session_state:
70
  st.session_state.genre = "์ „์ฒด"
@@ -90,7 +163,7 @@ with row2_2:
90
  st.write("nํ–‰์‹œ ๋‹จ์–ด : ", word_input)
91
 
92
  if st.button('nํ–‰์‹œ ์ œ์ž‘ํ•˜๊ธฐ'):
93
- st.write("์ œ์ž‘์ค‘...")
94
 
95
 
96
 
 
1
  import streamlit as st
 
 
 
 
 
 
2
  from streamlit_lottie import st_lottie
3
 
4
+ ### Model
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+
8
+ tokenizer = AutoTokenizer.from_pretrained("wumusill/final_20man")
9
+
10
+ @st.cache
11
+ def load_model():
12
+ model = AutoModelForCausalLM.from_pretrained("wumusill/final_20man")
13
+ return model
14
+
15
+ model = load_model()
16
+
17
+ def mind(input_letter):
18
+ # ๊ฒฐ๊ณผ๋ฌผ์„ ๋‹ด์„ list
19
+ res_l = []
20
+
21
+ # ํ•œ ๊ธ€์ž์”ฉ ์ธ๋ฑ์Šค์™€ ํ•จ๊ป˜ ๊ฐ€์ ธ์˜ด
22
+ for idx, val in enumerate(input_letter):
23
+
24
+ # ๋งŒ์•ฝ idx ๊ฐ€ 0 ์ด๋ผ๋ฉด == ์ฒซ ๊ธ€์ž
25
+ if idx == 0:
26
+ # ์ฒซ ๊ธ€์ž ์ธ์ฝ”๋”ฉ
27
+ input_ids = tokenizer.encode(
28
+ val, add_special_tokens=False, return_tensors="pt")
29
+
30
+ # ์ฒซ ๊ธ€์ž ์ธ์ฝ”๋”ฉ ๊ฐ’์œผ๋กœ ๋ฌธ์žฅ ์ƒ์„ฑ
31
+ output_sequence = model.generate(
32
+ input_ids,
33
+ do_sample=True, max_length=42)
34
+
35
+ # ์ฒซ ๊ธ€์ž๊ฐ€ ์•„๋‹ˆ๋ผ๋ฉด
36
+ else:
37
+ # ์ข€๋” ๋งค๋„๋Ÿฌ์šด ์‚ผํ–‰์‹œ๋ฅผ ์œ„ํ•ด ์ด์ „ ๋ฌธ์žฅ์ด๋ž‘ ํ˜„์žฌ ์Œ์ ˆ ์—ฐ๊ฒฐ
38
+ # ์ดํ›„ generate ๋œ ๋ฌธ์žฅ์—์„œ ์ด์ „ ๋ฌธ์žฅ์— ๋Œ€ํ•œ ๋ฐ์ดํ„ฐ ์ œ๊ฑฐ
39
+ link_with_pre_sentence = " ".join(res_l) + " " + val
40
+ # print(link_with_pre_sentence)
41
+
42
+ # ์—ฐ๊ฒฐ๋œ ๋ฌธ์žฅ์„ ์ธ์ฝ”๋”ฉ
43
+ input_ids = tokenizer.encode(
44
+ link_with_pre_sentence, add_special_tokens=False, return_tensors="pt")
45
+
46
+ # ์ธ์ฝ”๋”ฉ ๊ฐ’์œผ๋กœ ๋ฌธ์žฅ ์ƒ์„ฑ
47
+ output_sequence = model.generate(
48
+ input_ids,
49
+ do_sample=True, max_length=42)
50
+
51
+ # ์ƒ์„ฑ๋œ ๋ฌธ์žฅ ๋ฆฌ์ŠคํŠธ๋กœ ๋ณ€ํ™˜ (์ธ์ฝ”๋”ฉ ๋˜์–ด์žˆ๊ณ , ์ƒ์„ฑ๋œ ๋ฌธ์žฅ ๋’ค๋กœ padding ์ด ์žˆ๋Š” ์ƒํƒœ)
52
+ generated_sequence = output_sequence.tolist()[0]
53
+
54
+ # padding index ์•ž๊นŒ์ง€ slicing ํ•จ์œผ๋กœ์จ padding ์ œ๊ฑฐ
55
+ generated_sequence = generated_sequence[:generated_sequence.index(tokenizer.pad_token_id)]
56
+
57
+ # ์ฒซ ๊ธ€์ž๊ฐ€ ์•„๋‹ˆ๋ผ๋ฉด, generate ๋œ ์Œ์ ˆ๋งŒ ๊ฒฐ๊ณผ๋ฌผ list์— ๋“ค์–ด๊ฐˆ ์ˆ˜ ์žˆ๊ฒŒ ์•ž ๋ฌธ์žฅ์— ๋Œ€ํ•œ ์ธ์ฝ”๋”ฉ ๊ฐ’ ์ œ๊ฑฐ
58
+ # print(generated_sequence)
59
+ if idx != 0:
60
+ # ์ด์ „ ๋ฌธ์žฅ์˜ ๋งˆ์ง€๋ง‰ ์‹œํ€€์Šค ์ดํ›„๋กœ ์Šฌ๋ผ์ด์‹ฑํ•ด์„œ ์•ž ๋ฌธ์žฅ ์ œ๊ฑฐ
61
+ generated_sequence = generated_sequence[generated_sequence.index(last_sequence) + 1:]
62
+
63
+ # ๋‹ค์Œ ์Œ์ ˆ์„ ์œ„ํ•ด ๋งˆ์ง€๋ง‰ ์‹œํ€€์Šค ๊ฐฑ์‹ 
64
+ last_sequence = generated_sequence[-1]
65
+
66
+ # ์ฒซ ๊ธ€์ž๋ผ๋ฉด
67
+ else:
68
+ # ๋งˆ์ง€๋ง‰ ์‹œํ€€์Šค ์ €์žฅ
69
+ last_sequence = generated_sequence[-1]
70
+
71
+ # print(last_sequence)
72
+
73
+ # ๊ฒฐ๊ณผ๋ฌผ ๋””์ฝ”๋”ฉ
74
+ decoded_sequence = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)
75
+
76
+ # ๊ฒฐ๊ณผ๋ฌผ ๋ฆฌ์ŠคํŠธ์— ๋‹ด๊ธฐ
77
+ res_l.append(decoded_sequence)
78
+
79
+ # print(res_l)
80
+
81
+ # ๊ฒฐ๊ณผ๋ฌผ list์—์„œ ํ•œ ์ค„์”ฉ ์ถœ๋ ฅ
82
+ for letter, res in zip(input_letter, res_l):
83
+ return(f"{letter} :", res)
84
+
85
+ ###
86
+
87
+
88
  st.set_page_config(
89
  page_title="๋…ธ๋ž˜ ๊ฐ€์‚ฌ nํ–‰์‹œ",
90
  page_icon="๐Ÿ’Œ",
 
138
  # Model & Input
139
  row2_spacer1, row2_1, row2_spacer2, row2_2, row2_spacer3 = st.columns((0.01, 1.5, 0.05, 1.5, 0.01))
140
 
 
 
 
 
 
141
  # Genre Selector
142
  if "genre" not in st.session_state:
143
  st.session_state.genre = "์ „์ฒด"
 
163
  st.write("nํ–‰์‹œ ๋‹จ์–ด : ", word_input)
164
 
165
  if st.button('nํ–‰์‹œ ์ œ์ž‘ํ•˜๊ธฐ'):
166
+ st.write(mind(word_input))
167
 
168
 
169