JYYong commited on
Commit
4cb51c5
β€’
1 Parent(s): 73dae5e

maybe complete

Browse files
Files changed (2) hide show
  1. app.py +106 -136
  2. flagged/log.csv +4 -0
app.py CHANGED
@@ -1,144 +1,114 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def update(name):
4
- return f"Welcome to Gradio, {name}!"
5
-
6
- demo = gr.Blocks()
7
-
8
- with demo:
9
- gr.Markdown(f"각 μ§ˆλ¬Έμ— λŒ€λ‹΅ ν›„ Enter ν•΄μ£Όμ„Έμš”.\n\n")
10
- with gr.Row():
11
- topic = gr.Textbox(label="Topic", placeholder="λŒ€ν™” 주제λ₯Ό μ •ν•΄μ£Όμ„Έμš” (e.g. μ—¬κ°€ μƒν™œ, 일과 직업, 개인 및 관계, etc...)")
12
- with gr.Row():
13
- with gr.Column():
14
- addr = gr.Textbox(label="지역", placeholder="e.g. μ—¬κ°€ μƒν™œ, 일과 직업, 개인 및 관계, etc...")
15
- age = gr.Textbox(label="λ‚˜μ΄", placeholder="e.g. 20λŒ€ 미만, 40λŒ€, 70λŒ€ 이상, etc...")
16
- sex = gr.Textbox(label="성별", placeholder="e.g. 남성, μ—¬μ„±, etc...")
17
- with gr.Column():
18
- addr = gr.Textbox(label="지역", placeholder="e.g. μ—¬κ°€ μƒν™œ, 일과 직업, 개인 및 관계, etc...")
19
- age = gr.Textbox(label="λ‚˜μ΄", placeholder="e.g. 20λŒ€ 미만, 40λŒ€, 70λŒ€ 이상, etc...")
20
- sex = gr.Textbox(label="성별", placeholder="e.g. 남성, μ—¬μ„±, etc...")
21
- out = gr.Textbox()
22
- btn = gr.Button("Run")
23
- # btn.click(fn=update, inputs=inp, outputs=out)
24
-
25
- demo.launch()
 
 
26
 
27
 
28
- def main(model_name):
29
  warnings.filterwarnings("ignore")
30
 
31
- tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
32
- special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@이름#', '#@계정#', '#@신원#', '#@μ „λ²ˆ#', '#@금육#', '#@번호#', '#@μ£Όμ†Œ#', '#@μ†Œμ†#', '#@기타#']}
33
- num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
34
-
35
- model = AutoModelForCausalLM.from_pretrained(model_name)
36
- model.resize_token_embeddings(len(tokenizer))
37
- model = model.cuda()
38
-
39
- info = ""
40
-
41
- while True:
42
- if info == "":
43
- print(
44
- f"μ§€κΈˆλΆ€ν„° λŒ€ν™” 정보λ₯Ό μž…λ ₯ λ°›κ² μŠ΅λ‹ˆλ‹€.\n"
45
- f"각 μ§ˆλ¬Έμ— λŒ€λ‹΅ ν›„ Enter ν•΄μ£Όμ„Έμš”.\n"
46
- f"아무 μž…λ ₯ 없이 Enter ν•  경우, 미리 μ§€μ •λœ κ°’ 쀑 랜덀으둜 μ •ν•˜κ²Œ λ©λ‹ˆλ‹€.\n"
47
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- time.sleep(1)
50
 
51
- yon = "no"
52
- else:
53
- yon = input(
54
- f"이전 λŒ€ν™” 정보λ₯Ό κ·ΈλŒ€λ‘œ μœ μ§€ν• κΉŒμš”? (yes : μœ μ§€, no : μƒˆλ‘œ μž‘μ„±) :"
55
- )
56
-
57
- if yon == "no":
58
- info = "일상 λŒ€ν™” "
59
-
60
- topic = input("λŒ€ν™” 주제λ₯Ό μ •ν•΄μ£Όμ„Έμš” (e.g. μ—¬κ°€ μƒν™œ, 일과 직업, 개인 및 관계, etc...) :")
61
- if topic == "":
62
- topic = random.choice(['μ—¬κ°€ μƒν™œ', 'μ‹œμ‚¬/ꡐ윑', '미용과 건강', 'μ‹μŒλ£Œ', 'μƒκ±°λž˜(μ‡Όν•‘)', '일과 직업', '주거와 μƒν™œ', '개인 및 관계', '행사'])
63
- print(topic)
64
- info += topic + "<sep>"
65
-
66
- def ask_info(who, ment):
67
- print(ment)
68
- text = who + ":"
69
- addr = input("μ–΄λ”” μ‚¬μ„Έμš”? (e.g. μ„œμšΈνŠΉλ³„μ‹œ, μ œμ£Όλ„, etc...) :").strip()
70
- if addr == "":
71
- addr = random.choice(['μ„œμšΈνŠΉλ³„μ‹œ', '경기도', 'λΆ€μ‚°κ΄‘μ—­μ‹œ', 'λŒ€μ „κ΄‘μ—­μ‹œ', 'κ΄‘μ£Όκ΄‘μ—­μ‹œ', 'μšΈμ‚°κ΄‘μ—­μ‹œ', '경상남도', 'μΈμ²œκ΄‘μ—­μ‹œ', '좩청뢁도', 'μ œμ£Όλ„', '강원도', '좩청남도', '전라뢁도', 'λŒ€κ΅¬κ΄‘μ—­μ‹œ', '전라남도', '경상뢁도', 'μ„Έμ’…νŠΉλ³„μžμΉ˜μ‹œ', '기타'])
72
- print(addr)
73
- text += addr + " "
74
-
75
- age = input("λ‚˜μ΄κ°€? (e.g. 20λŒ€, 70λŒ€ 이상, etc...) :").strip()
76
- if age == "":
77
- age = random.choice(['20λŒ€', '30λŒ€', '50λŒ€', '20λŒ€ 미만', '60λŒ€', '40λŒ€', '70λŒ€ 이상'])
78
- print(age)
79
- text += age + " "
80
-
81
- sex = input("성별이? (e.g. 남성, μ—¬μ„±, etc... (?)) :").strip()
82
- if sex == "":
83
- sex = random.choice(['남성', 'μ—¬μ„±'])
84
- print(sex)
85
- text += sex + "<sep>"
86
- return text
87
-
88
- info += ask_info(who="P01", ment=f"\n당신에 λŒ€ν•΄ μ•Œλ €μ£Όμ„Έμš”.\n")
89
- info += ask_info(who="P02", ment=f"\n챗봇에 λŒ€ν•΄ μ•Œλ €μ£Όμ„Έμš”.\n")
90
-
91
- pp = info.replace('<sep>', '\n')
92
- print(
93
- f"\n----------------\n"
94
- f"<μž…λ ₯ 정보 확인> (P01 : λ‹Ήμ‹ , P02 : 챗봇)\n"
95
- f"{pp}"
96
- f"----------------\n"
97
- f"λŒ€ν™”λ₯Ό μ’…λ£Œν•˜κ³  μ‹ΆμœΌλ©΄ μ–Έμ œλ“ μ§€ 'end' 라고 λ§ν•΄μ£Όμ„Έμš”~\n"
98
- )
99
- talk = []
100
- switch = True
101
- switch2 = True
102
- while True:
103
- inp = "P01<sos>"
104
- myinp = input("λ‹Ήμ‹  : ")
105
- if myinp == "end":
106
- print("λŒ€ν™” μ’…λ£Œ!")
107
- break
108
- inp += myinp + "<eos>"
109
- talk.append(inp)
110
- talk.append("P02<sos>")
111
-
112
- while True:
113
- now_inp = info + "".join(talk)
114
- inpu = tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
115
- seq_len = inpu.input_ids.size(1)
116
- if seq_len > 512 * 0.8 and switch:
117
- print(
118
- f"<주의> ν˜„μž¬ λŒ€ν™” 길이가 곧 μ΅œλŒ€ 길이에 λ„λ‹¬ν•©λ‹ˆλ‹€. ({seq_len} / 512)"
119
- )
120
- switch = False
121
-
122
- if seq_len >= 512 and switch2:
123
- print("<주의> λŒ€ν™” 길이가 λ„ˆλ¬΄ κΈΈμ–΄μ‘ŒκΈ° λ•Œλ¬Έμ—, 이후 λŒ€ν™”λŠ” 맨 μ•žμ˜ λ°œν™”λ₯Ό μ‘°κΈˆμ”© μ§€μš°λ©΄μ„œ μ§„ν–‰λ©λ‹ˆλ‹€.")
124
- talk = talk[1:]
125
- switch2 = False
126
- else:
127
- break
128
-
129
- out = model.generate(
130
- inputs=inpu.input_ids.cuda(),
131
- attention_mask=inpu.attention_mask.cuda(),
132
- max_length=512,
133
- do_sample=True,
134
- pad_token_id=tokenizer.pad_token_id,
135
- eos_token_id=tokenizer.encode('<eos>')[0]
136
- )
137
- output = tokenizer.batch_decode(out)
138
- print("챗봇 : " + output[0][len(now_inp):-5])
139
- talk[-1] += output[0][len(now_inp):]
140
-
141
- again = input(f"λ‹€λ₯Έ λŒ€ν™”λ₯Ό μ‹œμž‘ν• κΉŒμš”? (yes : μƒˆλ‘œμš΄ μ‹œμž‘, no : μ’…λ£Œ) :")
142
- if again == "no":
143
- break
144
-
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+ import warnings
4
+
5
+
6
+ class Chatbot():
7
+ def __init__(self):
8
+ self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
9
+ special_tokens_dict = {'additional_special_tokens': ['<sep>', '<eos>', '<sos>', '#@이름#', '#@계정#', '#@신원#', '#@μ „λ²ˆ#', '#@금육#', '#@번호#', '#@μ£Όμ†Œ#', '#@μ†Œμ†#', '#@기타#']}
10
+ num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict)
11
+
12
+ self.model = AutoModelForCausalLM.from_pretrained("/workspace/test_trainer/checkpoint-10000")
13
+ self.model.resize_token_embeddings(len(self.tokenizer))
14
+ self.model = self.model.cuda()
15
+
16
+ self.info = None
17
+ self.talk = []
18
+
19
+ def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
20
+ def encode(age):
21
+ if age < 20:
22
+ age = "20λŒ€ 미만"
23
+ elif age >= 70:
24
+ age = "70λŒ€ 이상"
25
+ else:
26
+ age = str(age // 10 * 10) + "λŒ€"
27
+ return age
28
+ bot_age = encode(bot_age)
29
+ my_age = encode(my_age)
30
+ self.info = f"일상 λŒ€ν™” {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
31
+ return self.info_check()
32
+
33
+ def info_check(self):
34
+ return self.info.replace('<sep>', '\n').replace('P01', 'λ‹Ήμ‹ ').replace('P02', '챗봇')
35
+
36
+ def reset_talk(self):
37
+ self.talk = []
38
+
39
+ def test(self, myinp):
40
+ state = None
41
+ inp = "P01<sos>" + myinp + "<eos>"
42
+ self.talk.append(inp)
43
+ self.talk.append("P02<sos>")
44
 
45
+ while True:
46
+ now_inp = self.info + "".join(self.talk)
47
+ inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
48
+ seq_len = inputs.input_ids.size(1)
49
+ if seq_len > 512 * 0.8:
50
+ state = f"<주의> ν˜„μž¬ λŒ€ν™” 길이가 곧 μ΅œλŒ€ 길이에 λ„λ‹¬ν•©λ‹ˆλ‹€. ({seq_len} / 512)"
51
+
52
+ if seq_len >= 512:
53
+ state = "<주의> λŒ€ν™” 길이가 λ„ˆλ¬΄ κΈΈμ–΄μ‘ŒκΈ° λ•Œλ¬Έμ—, 이후 λŒ€ν™”λŠ” 맨 μ•žμ˜ λ°œν™”λ₯Ό μ‘°κΈˆμ”© μ§€μš°λ©΄μ„œ μ§„ν–‰λ©λ‹ˆλ‹€."
54
+ talk = talk[1:]
55
+ else:
56
+ break
57
+
58
+ out = self.model.generate(
59
+ inputs=inputs.input_ids.cuda(),
60
+ attention_mask=inputs.attention_mask.cuda(),
61
+ max_length=512,
62
+ do_sample=True,
63
+ pad_token_id=self.tokenizer.pad_token_id,
64
+ eos_token_id=self.tokenizer.encode('<eos>')[0]
65
+ )
66
+ out = self.tokenizer.batch_decode(out)
67
+ real_out = out[0][len(now_inp):-5]
68
+ self.talk[-1] += out[0][len(now_inp):]
69
+ return [(self.talk[i][8:-5], self.talk[i+1][8:-5]) for i in range(0, len(self.talk)-1, 2)]
70
 
71
 
72
+ if __name__ == "__main__":
73
  warnings.filterwarnings("ignore")
74
 
75
+ chatbot = Chatbot()
76
+ demo = gr.Blocks()
77
+
78
+ with demo:
79
+ gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
80
+ with gr.Row():
81
+ with gr.Column():
82
+ topic = gr.Radio(label="Topic", choices=['μ—¬κ°€ μƒν™œ', 'μ‹œμ‚¬/ꡐ윑', '미용과 건강', 'μ‹μŒλ£Œ', 'μƒκ±°λž˜(μ‡Όν•‘)', '일과 직업', '주거와 μƒν™œ', '개인 및 관계', '행사'])
83
+ with gr.Column():
84
+ gr.Markdown(f"Bot's persona")
85
+ bot_addr = gr.Dropdown(label="지역", choices=['μ„œμšΈνŠΉλ³„μ‹œ', '경기도', 'λΆ€μ‚°κ΄‘μ—­μ‹œ', 'λŒ€μ „κ΄‘μ—­μ‹œ', 'κ΄‘μ£Όκ΄‘μ—­μ‹œ', 'μšΈμ‚°κ΄‘μ—­μ‹œ', '경상남도', 'μΈμ²œκ΄‘μ—­μ‹œ', '좩청뢁도', 'μ œμ£Όλ„', '강원도', '좩청남도', '전라뢁도', 'λŒ€κ΅¬κ΄‘μ—­μ‹œ', '전라남도', '경상뢁도', 'μ„Έμ’…νŠΉλ³„μžμΉ˜μ‹œ', '기타'])
86
+ bot_age = gr.Slider(label="λ‚˜μ΄", minimum=10, maximum=80, value=45, step=1)
87
+ bot_sex = gr.Radio(label="성별", choices=["남성", "μ—¬μ„±"])
88
+ with gr.Column():
89
+ gr.Markdown(f"Your persona")
90
+ my_addr = gr.Dropdown(label="지역", choices=['μ„œμšΈνŠΉλ³„μ‹œ', '경기도', 'λΆ€μ‚°κ΄‘μ—­μ‹œ', 'λŒ€μ „κ΄‘μ—­μ‹œ', 'κ΄‘μ£Όκ΄‘μ—­μ‹œ', 'μšΈμ‚°κ΄‘μ—­μ‹œ', '경상남도', 'μΈμ²œκ΄‘μ—­μ‹œ', '좩청뢁도', 'μ œμ£Όλ„', '강원도', '좩청남도', '전라뢁도', 'λŒ€κ΅¬κ΄‘μ—­μ‹œ', '전라남도', '경상뢁도', 'μ„Έμ’…νŠΉλ³„μžμΉ˜μ‹œ', '기타'])
91
+ my_age = gr.Slider(label="λ‚˜μ΄", minimum=10, maximum=80, value=45, step=1)
92
+ my_sex = gr.Radio(label="성별", choices=["남성", "μ—¬μ„±"])
93
+ with gr.Row():
94
+ btn = gr.Button(label="적용")
95
+ state = gr.Textbox(label="μƒνƒœ")
96
+ btn.click(
97
+ fn=chatbot.initialize,
98
+ inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex],
99
+ outputs=state
100
+ )
101
+
102
+ with gr.Column():
103
+ screen = gr.Chatbot(label="읡λͺ…μ˜ μƒλŒ€")
104
+ with gr.Row():
105
+ speak = gr.Textbox(label="μž…λ ₯μ°½")
106
+ btn = gr.Button(label="Talk")
107
+ btn.click(
108
+ fn=chatbot.test,
109
+ inputs=speak,
110
+ outputs=screen
111
+ )
112
+ demo.launch(share=True)
113
 
 
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flagged/log.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
1
+ 'self','output','flag','username','timestamp'
2
+ 'λ­ν•˜κ³  κ³„μ„Έμš”?','[(''μ•ˆλ…•ν•˜μ„Έμš”'', ''λ„΅''), (''λ­ν•˜κ³  κ³„μ„Έμš”?'', ''μ € κ²Œμž„ν•˜λ©΄μ„œ μžˆμ–΄μš©'')]','','','2022-06-29 07:59:03.609856'
3
+ 'λ­ν•˜κ³  κ³„μ„Έμš”?','[(''μ•ˆλ…•ν•˜μ„Έμš”'', ''λ„΅''), (''λ­ν•˜κ³  κ³„μ„Έμš”?'', ''μ € κ²Œμž„ν•˜λ©΄μ„œ μžˆμ–΄μš©'')]','','','2022-06-29 07:59:07.265460'
4
+ 'μ•ˆλ…•ν•˜μ„Έμš”?','[[''μ•ˆλ…•ν•˜μ„Έμš”?'', ''μ•„λ‹ˆ'']]','','','2022-06-29 08:15:33.284872'