JYYong committed on
Commit
273c30f
•
1 Parent(s): fa23262
Files changed (1) hide show
  1. app.py +132 -4
app.py CHANGED
@@ -1,7 +1,135 @@
1
  import gradio as gr
2
 
3
def greet(name):
    """Build the demo's hello message for the given name."""
    pieces = ["Hello ", name, "!!"]
    return "".join(pieces)
5
 
6
# Wire the greeting function into a simple text-in / text-out UI and serve it.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
def update(name):
    """Return the welcome message shown in the demo's output textbox."""
    return "Welcome to Gradio, {}!".format(name)
5
 
6
# Assemble the Blocks UI: a prompt, an input/output textbox row, and a Run
# button that feeds the input through update().
demo = gr.Blocks()

with demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        inp = gr.Textbox(placeholder="What is your name?")
        out = gr.Textbox()
    btn = gr.Button("Run")
    btn.click(fn=update, inputs=inp, outputs=out)

# NOTE(review): launch() blocks until the server stops, and nothing below this
# line is ever called — the main() defined later in the file is dead code as
# committed; confirm whether it was meant to be wired into the UI.
demo.launch()
17
+
18
+
19
def main(model_name, max_len=512):
    """Run an interactive terminal chat session on top of a KoGPT model.

    First collects conversation metadata from stdin (topic plus a short
    region/age/gender profile for each speaker, with random fallbacks on
    empty input), then loops: read a user utterance, generate the bot's
    reply, print it — until the user types "end", then optionally restart.

    Parameters
    ----------
    model_name : str
        Model id or path passed to ``AutoModelForCausalLM.from_pretrained``.
    max_len : int, optional
        Token budget for one conversation prompt (default 512, matching the
        previously hard-coded limit). The tokenizer truncates at ``2 * max_len``.

    Notes
    -----
    Interactive only (stdin/stdout); requires a CUDA device — ``model.cuda()``
    has no CPU fallback.
    """
    # These names were used but never imported at module level (the module
    # only imports gradio); import locally so the function is self-contained.
    import random
    import time
    import warnings

    from transformers import AutoModelForCausalLM, AutoTokenizer

    warnings.filterwarnings("ignore")

    # Tokenizer is pinned to KoGPT's vocabulary regardless of model_name.
    tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
    # Structural and PII-placeholder tokens used by the conversation format.
    special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@์ด๋ฆ„#', '#@๊ณ„์ •#', '#@์‹ ์›#', '#@์ „๋ฒˆ#', '#@๊ธˆ์œต#', '#@๋ฒˆํ˜ธ#', '#@์ฃผ์†Œ#', '#@์†Œ์†#', '#@๊ธฐํƒ€#']}
    tokenizer.add_special_tokens(special_tokens_dict)
    # Hoisted out of the generation loop: the id is constant per tokenizer.
    eos_id = tokenizer.encode('<eos>')[0]

    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.resize_token_embeddings(len(tokenizer))  # account for added special tokens
    model = model.cuda()

    def ask_info(who, ment):
        """Prompt for one speaker's region/age/gender and return it encoded
        as ``"<who>:<region> <age> <gender><sep>"`` (random fallback on empty
        answers)."""
        print(ment)
        text = who + ":"

        addr = input("์–ด๋”” ์‚ฌ์„ธ์š”? (e.g. ์„œ์šธํŠน๋ณ„์‹œ, ์ œ์ฃผ๋„, etc...) :").strip()
        if addr == "":
            addr = random.choice(['์„œ์šธํŠน๋ณ„์‹œ', '๊ฒฝ๊ธฐ๋„', '๋ถ€์‚ฐ๊ด‘์—ญ์‹œ', '๋Œ€์ „๊ด‘์—ญ์‹œ', '๊ด‘์ฃผ๊ด‘์—ญ์‹œ', '์šธ์‚ฐ๊ด‘์—ญ์‹œ', '๊ฒฝ์ƒ๋‚จ๋„', '์ธ์ฒœ๊ด‘์—ญ์‹œ', '์ถฉ์ฒญ๋ถ๋„', '์ œ์ฃผ๋„', '๊ฐ•์›๋„', '์ถฉ์ฒญ๋‚จ๋„', '์ „๋ผ๋ถ๋„', '๋Œ€๊ตฌ๊ด‘์—ญ์‹œ', '์ „๋ผ๋‚จ๋„', '๊ฒฝ์ƒ๋ถ๋„', '์„ธ์ข…ํŠน๋ณ„์ž์น˜์‹œ', '๊ธฐํƒ€'])
        print(addr)
        text += addr + " "

        age = input("๋‚˜์ด๊ฐ€? (e.g. 20๋Œ€, 70๋Œ€ ์ด์ƒ, etc...) :").strip()
        if age == "":
            age = random.choice(['20๋Œ€', '30๋Œ€', '50๋Œ€', '20๋Œ€ ๋ฏธ๋งŒ', '60๋Œ€', '40๋Œ€', '70๋Œ€ ์ด์ƒ'])
        print(age)
        text += age + " "

        sex = input("์„ฑ๋ณ„์ด? (e.g. ๋‚จ์„ฑ, ์—ฌ์„ฑ, etc... (?)) :").strip()
        if sex == "":
            sex = random.choice(['๋‚จ์„ฑ', '์—ฌ์„ฑ'])
        print(sex)
        text += sex + "<sep>"
        return text

    info = ""

    while True:
        if info == "":
            # First session: explain the metadata prompts, no previous info
            # to reuse, so force fresh entry.
            print(
                "์ง€๊ธˆ๋ถ€ํ„ฐ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ์ž…๋ ฅ ๋ฐ›๊ฒ ์Šต๋‹ˆ๋‹ค.\n"
                "๊ฐ ์งˆ๋ฌธ์— ๋Œ€๋‹ต ํ›„ Enter ํ•ด์ฃผ์„ธ์š”.\n"
                "์•„๋ฌด ์ž…๋ ฅ ์—†์ด Enter ํ•  ๊ฒฝ์šฐ, ๋ฏธ๋ฆฌ ์ง€์ •๋œ ๊ฐ’ ์ค‘ ๋žœ๋ค์œผ๋กœ ์ •ํ•˜๊ฒŒ ๋ฉ๋‹ˆ๋‹ค.\n"
            )
            time.sleep(1)
            yon = "no"
        else:
            yon = input(
                "์ด์ „ ๋Œ€ํ™” ์ •๋ณด๋ฅผ ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ• ๊นŒ์š”? (yes : ์œ ์ง€, no : ์ƒˆ๋กœ ์ž‘์„ฑ) :"
            )

        if yon == "no":
            info = "์ผ์ƒ ๋Œ€ํ™” "

            topic = input("๋Œ€ํ™” ์ฃผ์ œ๋ฅผ ์ •ํ•ด์ฃผ์„ธ์š” (e.g. ์—ฌ๊ฐ€ ์ƒํ™œ, ์ผ๊ณผ ์ง์—…, ๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„, etc...) :")
            if topic == "":
                topic = random.choice(['์—ฌ๊ฐ€ ์ƒํ™œ', '์‹œ์‚ฌ/๊ต์œก', '๋ฏธ์šฉ๊ณผ ๊ฑด๊ฐ•', '์‹์Œ๋ฃŒ', '์ƒ๊ฑฐ๋ž˜(์‡ผํ•‘)', '์ผ๊ณผ ์ง์—…', '์ฃผ๊ฑฐ์™€ ์ƒํ™œ', '๊ฐœ์ธ ๋ฐ ๊ด€๊ณ„', 'ํ–‰์‚ฌ'])
            print(topic)
            info += topic + "<sep>"

            info += ask_info(who="P01", ment="\n๋‹น์‹ ์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")
            info += ask_info(who="P02", ment="\n์ฑ—๋ด‡์— ๋Œ€ํ•ด ์•Œ๋ ค์ฃผ์„ธ์š”.\n")

        pp = info.replace('<sep>', '\n')
        print(
            f"\n----------------\n"
            f"<์ž…๋ ฅ ์ •๋ณด ํ™•์ธ> (P01 : ๋‹น์‹ , P02 : ์ฑ—๋ด‡)\n"
            f"{pp}"
            f"----------------\n"
            f"๋Œ€ํ™”๋ฅผ ์ข…๋ฃŒํ•˜๊ณ  ์‹ถ์œผ๋ฉด ์–ธ์ œ๋“ ์ง€ 'end' ๋ผ๊ณ  ๋งํ•ด์ฃผ์„ธ์š”~\n"
        )

        talk = []
        warn_near = True  # one-shot warning when the prompt nears max_len
        warn_over = True  # one-shot warning the first time old turns are trimmed
        while True:
            turn = "P01<sos>"
            myinp = input("๋‹น์‹  : ")
            if myinp == "end":
                print("๋Œ€ํ™” ์ข…๋ฃŒ!")
                break
            turn += myinp + "<eos>"
            talk.append(turn)
            talk.append("P02<sos>")

            # Re-tokenize and drop the oldest utterance until the prompt fits
            # the token budget.
            while True:
                now_inp = info + "".join(talk)
                inpu = tokenizer(now_inp, max_length=2 * max_len, truncation='longest_first', return_tensors='pt')
                seq_len = inpu.input_ids.size(1)
                if seq_len > max_len * 0.8 and warn_near:
                    print(
                        f"<์ฃผ์˜> ํ˜„์žฌ ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๊ณง ์ตœ๋Œ€ ๊ธธ์ด์— ๋„๋‹ฌํ•ฉ๋‹ˆ๋‹ค. ({seq_len} / {max_len})"
                    )
                    warn_near = False

                if seq_len >= max_len:
                    if warn_over:
                        print("<์ฃผ์˜> ๋Œ€ํ™” ๊ธธ์ด๊ฐ€ ๋„ˆ๋ฌด ๊ธธ์–ด์กŒ๊ธฐ ๋•Œ๋ฌธ์—, ์ดํ›„ ๋Œ€ํ™”๋Š” ๋งจ ์•ž์˜ ๋ฐœํ™”๋ฅผ ์กฐ๊ธˆ์”ฉ ์ง€์šฐ๋ฉด์„œ ์ง„ํ–‰๋ฉ๋‹ˆ๋‹ค.")
                        warn_over = False
                    # BUG FIX: the original also gated the trim on the one-shot
                    # flag, so after the first trim the loop exited with an
                    # over-budget prompt; keep trimming every time it overflows.
                    talk = talk[1:]
                else:
                    break

            out = model.generate(
                inputs=inpu.input_ids.cuda(),
                attention_mask=inpu.attention_mask.cuda(),
                max_length=max_len,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=eos_id,
            )
            output = tokenizer.batch_decode(out)
            # Strip the prompt prefix; the final 5 chars are assumed to be the
            # literal '<eos>' text appended by generation — TODO confirm the
            # reply always ends with <eos> rather than hitting max_length.
            print("์ฑ—๋ด‡ : " + output[0][len(now_inp):-5])
            talk[-1] += output[0][len(now_inp):]

        again = input("๋‹ค๋ฅธ ๋Œ€ํ™”๋ฅผ ์‹œ์ž‘ํ• ๊นŒ์š”? (yes : ์ƒˆ๋กœ์šด ์‹œ์ž‘, no : ์ข…๋ฃŒ) :")
        if again == "no":
            break
+