|
import streamlit as st |
|
import pandas as pd |
|
import json |
|
import requests |
|
import os |
|
|
|
# Per-model generation results on the validation split.
glm2b = pd.read_csv('blocklm-2b-512-validation-170000-4-False-0-dialog.csv')

glm10b = pd.read_csv('blocklm-10b-1024-validation-126000-4-False-0-dialog.csv')

gptj = pd.read_csv('checkpoints-validation-gpt-j-6B-4-False-0-dialog.csv')

gptjt = pd.read_csv('checkpoints-validation-gpt-jt-6B-4-False-0-dialog.csv')

# Columns shown to annotators; every result CSV above carries all of them.
_SHARED_COLUMNS = ['persona', 'knowledge', 'context', 'prompted text', 'ground truth']


def _select_inference(df, source_column):
    """Return *df*'s shared columns plus *source_column* renamed to 'inference'."""
    return df[_SHARED_COLUMNS + [source_column]].rename(columns={source_column: 'inference'})


# Each model contributes two variants: its 'origin' and its 'parallel' outputs.
glm2b_orig = _select_inference(glm2b, 'origin')

glm2b_para = _select_inference(glm2b, 'parallel')

glm10b_orig = _select_inference(glm10b, 'origin')

glm10b_para = _select_inference(glm10b, 'parallel')

gptj_orig = _select_inference(gptj, 'origin')

gptj_para = _select_inference(gptj, 'parallel')

gptjt_orig = _select_inference(gptjt, 'origin')

gptjt_para = _select_inference(gptjt, 'parallel')

# The order of this list matters: the user's 'model_list' / 'data_idx' state
# stores integer indices into it.
csv_map = [glm2b_orig, glm2b_para, glm10b_orig, glm10b_para, gptj_orig, gptj_para, gptjt_orig, gptjt_para]

# NOTE(review): not referenced in the code visible here; kept for compatibility.
truth = glm2b['ground truth'].to_list()

# Close the file deterministically and read it as UTF-8 (the JSON standard
# encoding) instead of relying on the platform's locale default.
with open('data/users.json', encoding='utf-8') as users_file:
    users = json.load(users_file)
|
|
|
|
|
def _as_text(value):
    """Return *value* as a string; missing CSV cells (NaN/None) become ''.

    pandas.read_csv yields NaN (a float) for empty cells, which has no
    .split() method.
    """
    if isinstance(value, str):
        return value
    return '' if pd.isna(value) else str(value)


def _hard_wrap(lines, width=67):
    """Hard-wrap each string in *lines* into chunks of at most *width* characters.

    The chunks are joined with newlines. An empty input line produces no
    chunks, so blank lines are dropped from the output.
    """
    return '\n'.join(
        line[start:start + width]
        for line in lines
        for start in range(0, len(line), width)
    )


def new_data(user_data):
    """Update the progress bar and display the user's current sample.

    *user_data* is the annotation state returned by the backend. This function
    reads the keys 'start_process', 'stop_process', 'all_process',
    'model_list', 'data_idx' and 'process', and does not modify the dict.
    It writes into the module-level placeholders p, k, c, pr, g, infer and
    the progress bar process_bar.

    Returns:
        'finish' when 'all_process' has reached 'stop_process' (only the
        progress bar is updated); otherwise 'not finish'.
    """
    start = user_data['start_process']
    stop = user_data['stop_process']
    done = user_data['all_process']

    # Guard against an empty assignment (stop == start), which previously
    # divided by zero, and clamp because st.progress only accepts [0.0, 1.0].
    span = stop - start
    fraction = (done - start) / span if span > 0 else 1.0
    process_bar.progress(min(max(fraction, 0.0), 1.0), text='进度')

    if done == stop:
        return 'finish'

    # model_list[done] picks which csv_map entry this step uses;
    # process[csv_idx] is the user's position within that entry's data_idx.
    csv_idx = user_data['model_list'][done]
    row = user_data['data_idx'][csv_idx][user_data['process'][csv_idx]]
    sample = csv_map[csv_idx].iloc[row]

    # Drop the final context line. The old `context.remove(context[-1])`
    # removed the FIRST line equal to the last one (e.g. an earlier blank
    # line when the context ends with '\n'), not the last line itself.
    context_lines = _as_text(sample.context).split('\n')[:-1]

    p.text(_hard_wrap(_as_text(sample.persona).split('\n')))

    k.text(_hard_wrap(_as_text(sample.knowledge).split('\n')))

    c.text(_hard_wrap(context_lines))

    pr.text(_hard_wrap(_as_text(sample['prompted text']).split('\n')))

    g.text(_hard_wrap(_as_text(sample['ground truth']).split('\n')))

    infer.text(_hard_wrap(_as_text(sample.inference).split('\n')))

    return 'not finish'
|
|
|
|
|
# --- Page setup ------------------------------------------------------------
st.set_page_config(layout="wide")

st.title('FoCus Annotation')

# Login form: username ("请输入用户名") and password ("请输入密码") inputs side
# by side, followed by the login button ("登录").
t1, t2 = st.columns(2)

with t1:
    username = st.text_input("请输入用户名")

with t2:
    password = st.text_input("请输入密码", type="password")

login_btn = st.button('登录')

# Collapsible context panels. Each st.empty() is a placeholder whose text is
# set later via .text() (by new_data or by the logged-out branch).
col1, col2 = st.columns(2)

with col1:
    # Persona ("人设").
    with st.expander("人设"):
        p = st.empty()
    # Dialogue context ("对话上下文").
    with st.expander('对话上下文'):
        c = st.empty()

with col2:
    # Knowledge ("知识").
    with st.expander("知识"):
        k = st.empty()
    with st.expander("Prompted Text"):
        pr = st.empty()

# Ground-truth response ("真实值") next to the model output to be annotated
# ("待标注样本").
a1, a2 = st.columns(2)

with a1:
    st.markdown("**真实值**")
    g = st.empty()

with a2:
    st.markdown("**待标注样本**")
    infer = st.empty()
|
|
|
# Rating controls, one per column. The "next" button handler stores these
# values into the user's state for the current sample.
cc, kc, pc, hc, fc = st.columns(5)

# NOTE(review): the columns above were created before this container, and the
# widgets are written into those columns, so the container appears to add
# only an empty element to the page.
with st.container():
    with cc:
        # Dialogue consistency ("对话一致性"), score 0-2.
        cs = st.selectbox("对话一致性", [0,1,2], key='cs')
    with kc:
        # Knowledge consistency ("知识一致性"), score 0-2.
        ks = st.selectbox("知识一致性", [0,1,2], key='ks')
    with pc:
        # Persona consistency ("人设一致性"), score 0-2.
        ps = st.selectbox("人设一致性", [0,1,2], key='ps')
    with hc:
        # Hallucination ("幻觉现象"), 0 or 1.
        hs = st.selectbox("幻觉现象", [0,1], key='hs')
    with fc:
        # Fluency ("流畅度"), score 0-2.
        fs = st.selectbox("流畅度", [0,1,2], key='fs')

# Progress bar ("进度" = progress); new_data() updates it on each render.
process_bar = st.progress(0.0, text='进度')

# Navigation: previous ("上一个") and next ("下一个") sample.
col3, col4 = st.columns(2)

with st.container():
    with col3:
        prev = st.button('上一个')
    with col4:
        succ = st.button('下一个')
|
|
|
|
|
# Seconds to wait for the annotation backend. requests applies no timeout by
# default, so a stalled server would otherwise hang the page indefinitely.
_REQUEST_TIMEOUT = 30

# NOTE(review): used as `process % _SCORE_SLOTS` when storing ratings;
# presumably each per-model score list holds 250 entries — confirm against
# the backend.
_SCORE_SLOTS = 250


def _credentials_valid(name, secret):
    """Return True if *name* is a known user whose stored password equals *secret*.

    NOTE(review): users.json appears to hold plaintext passwords; consider
    storing salted hashes instead.
    """
    import hmac

    stored = users.get(name)
    if not isinstance(stored, str):
        return False
    # Constant-time comparison, so response timing does not reveal how many
    # leading characters of the password matched.
    return hmac.compare_digest(stored.encode('utf-8'), secret.encode('utf-8'))


def _post_to_backend(payload):
    """POST *payload* as UTF-8 JSON to the backend at $URL and return the response.

    Raises:
        RuntimeError: if the URL environment variable is unset.
        requests.HTTPError: if the backend answers with a 4xx/5xx status.
        requests.RequestException: on connection failures or timeouts.
    """
    url = os.environ.get("URL")
    if not url:
        raise RuntimeError('The URL environment variable is not set')
    response = requests.post(
        url,
        data=json.dumps(payload, ensure_ascii=False).encode('utf-8'),
        timeout=_REQUEST_TIMEOUT,
    )
    response.raise_for_status()
    return response


def _fetch_user_data(name):
    """Fetch and decode the saved annotation state for user *name*."""
    response = _post_to_backend({'FocusUser': name})
    return json.loads(response.content.decode('utf-8'))


def _save_user_data(name, state):
    """Send the updated annotation state *state* for user *name* to the backend."""
    _post_to_backend({'Focus': state, 'username': name})


logged_in = _credentials_valid(username, password)

# Every rerun: fetch the user's state once and show the current sample, or put
# a "log in to start annotating" ("登录后开始标注") message in every panel.
if logged_in:
    user_data = _fetch_user_data(username)
    result = new_data(user_data)
else:
    for placeholder in (p, c, pr, g, k, infer):
        placeholder.text("登录后开始标注")

# Login button: "登录成功" = login succeeded, "您已完成标注" = all done.
if login_btn:
    if logged_in:
        st.success('登录成功')
        # The current sample was already rendered above from this same state.
        if result == 'finish':
            st.success('您已完成标注')
    else:
        st.error('用户名或密码错误,请先注册。若已有账号,但忘记密码,请联系管理员修改密码')

# "Next" button: record the five ratings for the current sample, advance,
# save, then show the following sample. "请先登录" = please log in first.
if succ:
    if not logged_in:
        st.error('请先登录')
    elif user_data['all_process'] == user_data['stop_process']:
        st.success('您已完成标注')
    else:
        csv_idx = user_data['model_list'][user_data['all_process']]
        slot = user_data['process'][csv_idx] % _SCORE_SLOTS
        for field, score in (
            ('context_relevance', cs),
            ('knowledge_relevance', ks),
            ('persona_consistency', ps),
            ('hallucination', hs),
            ('fluency', fs),
        ):
            user_data[field][csv_idx][slot] = score
        user_data['process'][csv_idx] += 1
        user_data['all_process'] += 1
        # Save before rendering: if the save fails, the page stops here instead
        # of showing the next sample as though the ratings had been stored.
        _save_user_data(username, user_data)
        result = new_data(user_data)
        if result == 'finish':
            st.success('您已完成标注')

# "Previous" button: step back one sample, save the rewound position, and
# show it. "已是首个数据" = already at the first item.
if prev:
    if not logged_in:
        st.error('请先登录')
    elif user_data['all_process'] == user_data['start_process']:
        # The first sample is already displayed by the render at the top.
        st.error('已是首个数据')
    else:
        user_data['all_process'] -= 1
        csv_idx = user_data['model_list'][user_data['all_process']]
        user_data['process'][csv_idx] -= 1
        # Save first so the UI never shows a position the backend did not accept.
        _save_user_data(username, user_data)
        result = new_data(user_data)
|
|