mkw18 commited on
Commit
71b0571
1 Parent(s): 3c81662

first commit

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import pandas as pd
import json
import requests
import os

# Per-model inference result CSVs. Each file carries the shared sample
# columns plus two candidate-output columns: 'origin' and 'parallel'.
glm2b = pd.read_csv('blocklm-2b-512-validation-170000-4-False-0-dialog.csv')
glm10b = pd.read_csv('blocklm-10b-1024-validation-126000-4-False-0-dialog.csv')
gptj = pd.read_csv('checkpoints-validation-gpt-j-6B-4-False-0-dialog.csv')
gptjt = pd.read_csv('checkpoints-validation-gpt-jt-6B-4-False-0-dialog.csv')

# Columns every annotation view shares, regardless of which model produced it.
_SHARED_COLS = ['persona', 'knowledge', 'context', 'prompted text', 'ground truth']


def _view(df, source_col):
    """Project `df` onto the shared columns plus `source_col`, renaming
    `source_col` to 'inference' so all views expose a uniform schema."""
    return df[_SHARED_COLS + [source_col]].rename(columns={source_col: 'inference'})


glm2b_orig = _view(glm2b, 'origin')
glm2b_para = _view(glm2b, 'parallel')
glm10b_orig = _view(glm10b, 'origin')
glm10b_para = _view(glm10b, 'parallel')
gptj_orig = _view(gptj, 'origin')
gptj_para = _view(gptj, 'parallel')
gptjt_orig = _view(gptjt, 'origin')
gptjt_para = _view(gptjt, 'parallel')

# Indexed by the model codes (0..7) found in each user's 'model_list'.
csv_map = [glm2b_orig, glm2b_para, glm10b_orig, glm10b_para, gptj_orig, gptj_para, gptjt_orig, gptjt_para]

truth = glm2b['ground truth'].to_list()
# username -> password map used as the login gate below.
users = json.load(open('data/users.json'))
25
+
26
+
27
def new_data(user_data):
    """Render the sample at the user's current annotation cursor.

    Updates the module-level progress bar and the six st.empty()
    placeholders (persona, knowledge, context, prompt, ground truth,
    inference) with 67-character hard-wrapped text.

    Args:
        user_data: annotation-state dict from the backend; reads
            'all_process', 'start_process', 'stop_process', 'model_list',
            'data_idx' and 'process'.

    Returns:
        'finish' when the user has reached stop_process, else 'not finish'.
    """

    def _wrap(text, width=67):
        # Hard-wrap every line of `text` to at most `width` characters so it
        # fits the fixed-width placeholders rendered by st.empty().text().
        wrapped = []
        for line in text.split('\n'):
            wrapped += [line[i:i + width] for i in range(0, len(line), width)]
        return '\n'.join(wrapped)

    done = user_data['all_process'] - user_data['start_process']
    total = user_data['stop_process'] - user_data['start_process']
    # Guard a zero-length assignment range: the original expression raised
    # ZeroDivisionError when start_process == stop_process.
    process_bar.progress(done / total if total else 1.0, text='进度')
    if user_data['all_process'] == user_data['stop_process']:
        return 'finish'

    csv_idx = user_data['model_list'][user_data['all_process']]
    sample = csv_map[csv_idx].iloc[user_data['data_idx'][csv_idx][user_data['process'][csv_idx]]]

    context = sample.context.split('\n')
    # Drop the final turn (the one being predicted). list.pop() removes the
    # last element itself; the original `context.remove(context[-1])` removed
    # the FIRST occurrence equal to it, which is wrong when turns repeat.
    context.pop()

    p.text(_wrap(sample.persona))
    k.text(_wrap(sample.knowledge))
    c.text(_wrap('\n'.join(context)))
    pr.text(_wrap(sample['prompted text']))
    g.text(_wrap(sample['ground truth']))
    infer.text(_wrap(sample.inference))
    return 'not finish'
71
+
72
+
73
# --- Page layout: runs top-to-bottom on every Streamlit rerun. ---
st.set_page_config(layout="wide")
st.title('FoCus Annotation')

# Login row: username / password inputs side by side.
t1, t2 = st.columns(2)
with t1:
    username = st.text_input("请输入用户名")
with t2:
    password = st.text_input("请输入密码", type="password")

# "登录" = login button; True only on the rerun triggered by its click.
login_btn = st.button('登录')

# Sample display area. Each st.empty() is a placeholder that new_data()
# fills with wrapped text; the expanders label the four context panels
# (人设 = persona, 对话上下文 = dialog context, 知识 = knowledge).
col1, col2 = st.columns(2)
with col1:
    with st.expander("人设"):
        p = st.empty()
    with st.expander('对话上下文'):
        c = st.empty()
with col2:
    with st.expander("知识"):
        k = st.empty()
    with st.expander("Prompted Text"):
        pr = st.empty()

# Ground truth (真实值) vs. the model output to annotate (待标注样本).
a1, a2 = st.columns(2)
with a1:
    st.markdown("**真实值**")
    g = st.empty()
with a2:
    st.markdown("**待标注样本**")
    infer = st.empty()

# Five scoring widgets read by the "下一个" (next) handler below:
# context / knowledge / persona consistency (0-2), hallucination (0/1),
# fluency (0-2).
cc, kc, pc, hc, fc = st.columns(5)
with st.container():
    with cc:
        cs = st.selectbox("对话一致性", [0,1,2], key='cs')
    with kc:
        ks = st.selectbox("知识一致性", [0,1,2], key='ks')
    with pc:
        ps = st.selectbox("人设一致性", [0,1,2], key='ps')
    with hc:
        hs = st.selectbox("幻觉现象", [0,1], key='hs')
    with fc:
        fs = st.selectbox("流畅度", [0,1,2], key='fs')

# Annotation progress bar ("进度" = progress); new_data() re-renders it.
process_bar = st.progress(0.0, text='进度')

# Navigation buttons: 上一个 = previous, 下一个 = next (saves scores).
col3, col4 = st.columns(2)
with st.container():
    with col3:
        prev = st.button('上一个')
    with col4:
        succ = st.button('下一个')
126
+
127
# Initial render: if the credentials currently in the inputs are valid,
# fetch the user's annotation state from the backend and show their
# current sample; otherwise fill every placeholder with a login prompt
# ("登录后开始标注" = "log in to start annotating").
if username in users and users[username] == password:
    data = {'FocusUser': username}
    # Backend lookup keyed by 'FocusUser'; URL comes from the environment.
    # NOTE(review): assumes the response body is UTF-8 JSON — confirm
    # against the backend.
    user_data=requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8')).content
    user_data = json.loads(str(user_data, encoding="utf-8"))
    # Offline fallback kept for local development:
    # user_data = json.load(open(f'data/{username}.json'))
    result = new_data(user_data)
else:
    p.text("登录后开始标注")
    c.text("登录后开始标注")
    pr.text("登录后开始标注")
    g.text("登录后开始标注")
    k.text("登录后开始标注")
    infer.text("登录后开始标注")
140
+
141
+
142
# "登录" (login) button handler: re-validates credentials, pulls the
# user's annotation state and renders the current sample.
if login_btn:
    if username in users and users[username] == password:
        st.success('登录成功')
        data = {'FocusUser': username}
        user_data=requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8')).content
        user_data = json.loads(str(user_data, encoding="utf-8"))
        # Offline fallback kept for local development:
        # user_data = json.load(open(f'data/{username}.json'))
        result = new_data(user_data)
        if result == 'finish':
            st.success('您已完成标注')
    else:
        # NOTE(review): these assignments only clear the local variables for
        # the remainder of this rerun; Streamlit re-reads the text inputs on
        # the next rerun, so they have no lasting effect — confirm intent.
        username = ''
        password = ''
        st.error('用户名或密码错误,请先注册。若已有账号,但忘记密码,请联系管理员修改密码')
156
+
157
+
158
# "下一个" (next) button handler: persist the five scores for the current
# sample, advance the cursor, push the state to the backend and render
# the next sample.
if succ:
    if username in users and users[username] == password:
        # Fetch the user's latest annotation state from the backend.
        data = {'FocusUser': username}
        user_data=requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8')).content
        user_data = json.loads(str(user_data, encoding="utf-8"))
        # Offline fallback kept for local development:
        # user_data = json.load(open(f'data/{username}.json'))
        data_idx, process, all_process, model_list = user_data['data_idx'], user_data['process'], user_data['all_process'], user_data['model_list']
        if all_process == user_data['stop_process']:
            st.success('您已完成标注')
        else:
            csv_idx = model_list[all_process]
            # NOTE(review): scores are stored modulo 250 per model slice —
            # presumably each user annotates at most 250 samples per model;
            # confirm against the backend schema. (The original code also
            # fetched the pandas row here into an unused `sample` variable;
            # that dead lookup is removed.)
            slot = process[csv_idx] % 250
            user_data['context_relevance'][csv_idx][slot] = cs
            user_data['knowledge_relevance'][csv_idx][slot] = ks
            user_data['persona_consistency'][csv_idx][slot] = ps
            user_data['hallucination'][csv_idx][slot] = hs
            user_data['fluency'][csv_idx][slot] = fs
            user_data['process'][csv_idx] += 1
            user_data['all_process'] += 1
            # Push the updated state back to the backend ('Focus' payload).
            data = {'Focus': user_data, 'username': username}
            requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8'))
            # json.dump(user_data, open(f'data/{username}.json', 'w'), ensure_ascii=False, indent=2)
            result = new_data(user_data)
            if result == 'finish':
                st.success('您已完成标注')
    else:
        st.error('请先登录')
185
+
186
+
187
# "上一个" (previous) button handler: step the cursor back one sample,
# re-render it, and persist the rewound state to the backend.
if prev:
    if username in users and users[username] == password:
        data = {'FocusUser': username}
        user_data=requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8')).content
        user_data = json.loads(str(user_data, encoding="utf-8"))
        # Offline fallback kept for local development:
        # user_data = json.load(open(f'data/{username}.json'))
        model_list = user_data['model_list']
        if user_data['all_process'] == user_data['start_process']:
            # Already at the first assigned sample ("已是首个数据"):
            # just re-render it and show an error.
            result = new_data(user_data)
            st.error('已是首个数据')
        else:
            # Rewind both the global cursor and the per-model cursor of the
            # model the previous sample belonged to.
            user_data['all_process'] -= 1
            csv_idx = model_list[user_data['all_process']]
            user_data['process'][csv_idx] -= 1
            result = new_data(user_data)
            data = {'Focus': user_data, 'username': username}
            requests.post(os.environ.get("URL"), data=json.dumps(data, ensure_ascii=False).encode('utf-8'))
            # json.dump(user_data, open(f'data/{username}.json', 'w'), ensure_ascii=False, indent=2)
    else:
        st.error('请先登录')
blocklm-10b-1024-validation-126000-4-False-0-dialog.csv ADDED
The diff for this file is too large to render. See raw diff
 
blocklm-2b-512-validation-170000-4-False-0-dialog.csv ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints-validation-gpt-j-6B-4-False-0-dialog.csv ADDED
The diff for this file is too large to render. See raw diff
 
checkpoints-validation-gpt-jt-6B-4-False-0-dialog.csv ADDED
The diff for this file is too large to render. See raw diff
 
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/users.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "focus0": "focus0",
3
+ "focus1": "focus1",
4
+ "focus2": "focus2",
5
+ "focus3": "focus3",
6
+ "focus4": "focus4",
7
+ "focus5": "focus5",
8
+ "focus6": "focus6",
9
+ "focus7": "focus7",
10
+ "focus8": "focus8",
11
+ "focus9": "focus9"
12
+ }