Spaces:
Runtime error
Runtime error
block gr
Browse files
app.py
CHANGED
@@ -1,7 +1,135 @@
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
def
|
4 |
-
return "
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
def update(name):
    """Return the greeting string shown in the demo's output textbox."""
    greeting = "Welcome to Gradio, {}!".format(name)
    return greeting
# Minimal Gradio Blocks demo: one input textbox, one output textbox,
# and a Run button that calls update() on the input.
demo = gr.Blocks()

with demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        inp = gr.Textbox(placeholder="What is your name?")
        out = gr.Textbox()
    btn = gr.Button("Run")
    # Wire the button: reads `inp`, writes update()'s return value into `out`.
    btn.click(fn=update, inputs=inp, outputs=out)

# NOTE(review): launch() starts the app here; the console chatbot defined in
# main() further down is never invoked from this file -- confirm whether it
# is dead code or meant to be called elsewhere.
demo.launch()
def main(model_name):
    """Run an interactive Korean console chatbot on a KoGPT-family model.

    Prompts the user on stdin for dialogue metadata (topic, plus each
    speaker's region/age/sex; blank answers are filled randomly), builds a
    ``<sep>``-delimited context prefix, then loops user/bot turns: each bot
    utterance is produced with ``model.generate`` until the user types
    'end'. Runs forever until the user answers "no" to the restart prompt.

    Args:
        model_name: checkpoint name/path passed to
            ``AutoModelForCausalLM.from_pretrained``.

    NOTE(review): the Korean literals below are runtime strings and are kept
    byte-for-byte as found (they appear mojibake-encoded in this copy of the
    file) -- do not "fix" their encoding without checking the original.
    """
    # Fix: warnings/time/random were used but never imported anywhere in
    # this file, so the original raised NameError on first use.
    import random
    import time
    import warnings

    warnings.filterwarnings("ignore")

    # NOTE(review): AutoTokenizer / AutoModelForCausalLM are still undefined
    # here -- the file needs `from transformers import AutoTokenizer,
    # AutoModelForCausalLM` at the top before main() can run.
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    # Dialogue-control markers plus PII-masking tags used by the training data.
    special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ#', '#@๊ณ์ #', '#@์ ์#', '#@์ ๋ฒ#', '#@๊ธ์ต#', '#@๋ฒํธ#', '#@์ฃผ์#', '#@์์#', '#@๊ธฐํ#']}
    num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))  # grow embeddings for the added tokens
    model = model.cuda()  # assumes a CUDA device is available -- TODO confirm

    info = ""  # <sep>-joined metadata prefix prepended to every generation

    while True:
        if info == "":
            # First round: explain the metadata questions that follow.
            print(
                f"์ง๊ธ๋ถํฐ ๋ํ ์ ๋ณด๋ฅผ ์๋ ฅ ๋ฐ๊ฒ ์ต๋๋ค.\n"
                f"๊ฐ ์ง๋ฌธ์ ๋๋ต ํ Enter ํด์ฃผ์ธ์.\n"
                f"์๋ฌด ์๋ ฅ ์์ด Enter ํ ๊ฒฝ์ฐ, ๋ฏธ๋ฆฌ ์ง์ ๋ ๊ฐ ์ค ๋๋ค์ผ๋ก ์ ํ๊ฒ ๋ฉ๋๋ค.\n"
            )

            time.sleep(1)

            yon = "no"  # force metadata collection on the first round
        else:
            yon = input(
                f"์ด์ ๋ํ ์ ๋ณด๋ฅผ ๊ทธ๋๋ก ์ ์งํ ๊น์? (yes : ์ ์ง, no : ์๋ก ์์ฑ) :"
            )

        if yon == "no":
            info = "์ผ์ ๋ํ "

            topic = input("๋ํ ์ฃผ์ ๋ฅผ ์ ํด์ฃผ์ธ์ (e.g. ์ฌ๊ฐ ์ํ, ์ผ๊ณผ ์ง์, ๊ฐ์ธ ๋ฐ ๊ด๊ณ, etc...) :")
            if topic == "":
                topic = random.choice(['์ฌ๊ฐ ์ํ', '์์ฌ/๊ต์ก', '๋ฏธ์ฉ๊ณผ ๊ฑด๊ฐ', '์์๋ฃ', '์๊ฑฐ๋(์ผํ)', '์ผ๊ณผ ์ง์', '์ฃผ๊ฑฐ์ ์ํ', '๊ฐ์ธ ๋ฐ ๊ด๊ณ', 'ํ์ฌ'])
                print(topic)  # echo the randomly chosen value
            info += topic + "<sep>"

            def ask_info(who, ment):
                # Collect one speaker's region/age/sex; blank answers are
                # replaced by a random preset and echoed back. Returns
                # "P0x:<region> <age> <sex><sep>".
                print(ment)
                text = who + ":"
                addr = input("์ด๋ ์ฌ์ธ์? (e.g. ์์ธํน๋ณ์, ์ ์ฃผ๋, etc...) :").strip()
                if addr == "":
                    addr = random.choice(['์์ธํน๋ณ์', '๊ฒฝ๊ธฐ๋', '๋ถ์ฐ๊ด์ญ์', '๋์ ๊ด์ญ์', '๊ด์ฃผ๊ด์ญ์', '์ธ์ฐ๊ด์ญ์', '๊ฒฝ์๋จ๋', '์ธ์ฒ๊ด์ญ์', '์ถฉ์ฒญ๋ถ๋', '์ ์ฃผ๋', '๊ฐ์๋', '์ถฉ์ฒญ๋จ๋', '์ ๋ผ๋ถ๋', '๋๊ตฌ๊ด์ญ์', '์ ๋ผ๋จ๋', '๊ฒฝ์๋ถ๋', '์ธ์ขํน๋ณ์์น์', '๊ธฐํ'])
                    print(addr)
                text += addr + " "

                age = input("๋์ด๊ฐ? (e.g. 20๋, 70๋ ์ด์, etc...) :").strip()
                if age == "":
                    age = random.choice(['20๋', '30๋', '50๋', '20๋ ๋ฏธ๋ง', '60๋', '40๋', '70๋ ์ด์'])
                    print(age)
                text += age + " "

                sex = input("์ฑ๋ณ์ด? (e.g. ๋จ์ฑ, ์ฌ์ฑ, etc... (?)) :").strip()
                if sex == "":
                    sex = random.choice(['๋จ์ฑ', '์ฌ์ฑ'])
                    print(sex)
                text += sex + "<sep>"
                return text

            info += ask_info(who="P01", ment=f"\n๋น์ ์ ๋ํด ์๋ ค์ฃผ์ธ์.\n")
            info += ask_info(who="P02", ment=f"\n์ฑ๋ด์ ๋ํด ์๋ ค์ฃผ์ธ์.\n")

        # Show the collected metadata, one field per line, before chatting.
        pp = info.replace('<sep>', '\n')
        print(
            f"\n----------------\n"
            f"<์๋ ฅ ์ ๋ณด ํ์ธ> (P01 : ๋น์ , P02 : ์ฑ๋ด)\n"
            f"{pp}"
            f"----------------\n"
            f"๋ํ๋ฅผ ์ข๋ฃํ๊ณ ์ถ์ผ๋ฉด ์ธ์ ๋ ์ง 'end' ๋ผ๊ณ ๋งํด์ฃผ์ธ์~\n"
        )
        talk = []       # alternating P01/P02 turn strings, in order
        switch = True   # one-shot "approaching max length" warning flag
        switch2 = True  # one-shot "dropping oldest turns" warning flag
        while True:
            inp = "P01<sos>"
            myinp = input("๋น์ : ")
            if myinp == "end":
                print("๋ํ ์ข๋ฃ!")
                break
            inp += myinp + "<eos>"
            talk.append(inp)
            talk.append("P02<sos>")  # open the bot turn; generate() completes it

            # Re-tokenize the full context; warn (once each) when nearing /
            # exceeding the 512-token generation budget, dropping the oldest
            # utterance on overflow.
            while True:
                now_inp = info + "".join(talk)
                inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
                seq_len = inpu.input_ids.size(1)
                if seq_len > 512 * 0.8 and switch:
                    print(
                        f"<์ฃผ์> ํ์ฌ ๋ํ ๊ธธ์ด๊ฐ ๊ณง ์ต๋ ๊ธธ์ด์ ๋๋ฌํฉ๋๋ค. ({seq_len} / 512)"
                    )
                    switch = False

                if seq_len >= 512 and switch2:
                    print("<์ฃผ์> ๋ํ ๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ์ด์ก๊ธฐ ๋๋ฌธ์, ์ดํ ๋ํ๋ ๋งจ ์์ ๋ฐํ๋ฅผ ์กฐ๊ธ์ฉ ์ง์ฐ๋ฉด์ ์งํ๋ฉ๋๋ค.")
                    talk = talk[1:]  # drop the oldest utterance and re-measure
                    switch2 = False
                    # NOTE(review): switch2 is never reset, so this trims only
                    # once per dialogue -- later overflows fall through to
                    # break; confirm whether that is intended.
                else:
                    break

            out = model.generate(
                inputs=inpu.input_ids.cuda(),
                attention_mask=inpu.attention_mask.cuda(),
                max_length=512,
                do_sample=True,  # sampled decoding: replies vary run to run
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.encode('<eos>')[0]
            )
            output = tokenizer.batch_decode(out)
            # Print only the newly generated tail; [:-5] strips the trailing
            # 5 chars -- presumably the decoded '<eos>' marker; TODO confirm.
            print("์ฑ๋ด : " + output[0][len(now_inp):-5])
            talk[-1] += output[0][len(now_inp):]

        again = input(f"๋ค๋ฅธ ๋ํ๋ฅผ ์์ํ ๊น์? (yes : ์๋ก์ด ์์, no : ์ข๋ฃ) :")
        if again == "no":
            break