"""Streamlit annotation page for FoCus dialogue samples.

The script loads four CSV files of model outputs, shows one sample at a time
next to its ground truth, and lets a logged-in annotator assign five 0-2
scores.  Per-user progress and scores are read from and written to the HTTP
endpoint named by the ``URL`` environment variable.
"""
import json
import os

import pandas as pd
import requests
import streamlit as st

# Columns that every per-model view shares; the model output column is added
# separately and renamed to 'inference'.
BASE_COLUMNS = ['persona', 'knowledge', 'context', 'prompted text', 'ground truth']

# Seconds to wait for the backend before giving up instead of hanging the page.
REQUEST_TIMEOUT = 30


def _inference_view(frame, column):
    """Return the shared columns of *frame* plus *column* renamed to 'inference'."""
    return frame[BASE_COLUMNS + [column]].rename(columns={column: 'inference'})


glm2b = pd.read_csv('glm2b-nk.csv')
glm10b = pd.read_csv('glm10b-nk.csv')
gptj = pd.read_csv('gptj-nk.csv')
gptjt = pd.read_csv('gptjt-nk.csv')

glm2b_orig = _inference_view(glm2b, 'origin')
glm2b_para = _inference_view(glm2b, 'parallel')
glm10b_orig = _inference_view(glm10b, 'origin')
glm10b_para = _inference_view(glm10b, 'parallel')
gptj_orig = _inference_view(gptj, 'origin')
gptj_para = _inference_view(gptj, 'parallel')
gptjt_orig = _inference_view(gptjt, 'origin')
gptjt_para = _inference_view(gptjt, 'parallel')

# Indexed by the integers stored in user_data['model_list'].
csv_map = [
    glm2b_orig, glm2b_para,
    glm10b_orig, glm10b_para,
    gptj_orig, gptj_para,
    gptjt_orig, gptjt_para,
]
truth = glm2b['ground truth'].to_list()

# Use a context manager so the credentials file is closed after loading.
with open('data/users.json', encoding='utf-8') as users_file:
    users = json.load(users_file)


def _wrap_lines(lines, width):
    """Cut every string in *lines* into chunks of at most *width* characters.

    Returns the chunks joined with newlines.  Empty strings contribute no
    chunk, matching a ``range(0, 0, width)`` loop.
    """
    chunks = []
    for line in lines:
        chunks += [line[i:i + width] for i in range(0, len(line), width)]
    return '\n'.join(chunks)


def _progress_fraction(user_data):
    """Return the share of the assigned range already annotated, within [0, 1]."""
    done = user_data['all_process'] - user_data['start_process']
    total = user_data['stop_process'] - user_data['start_process']
    if total == 0:
        # Empty assignment: nothing to do, so avoid dividing by zero.
        return 1.0
    # st.progress rejects floats outside [0.0, 1.0].
    return min(max(done / total, 0.0), 1.0)


def _is_authenticated(name, secret):
    """Return True when *name* is a known user whose stored password equals *secret*."""
    return name in users and users[name] == secret


def _post(payload):
    """POST *payload* as UTF-8 encoded JSON to the endpoint in the URL env var."""
    return requests.post(
        os.environ.get("URL"),
        data=json.dumps(payload, ensure_ascii=False).encode('utf-8'),
        timeout=REQUEST_TIMEOUT,
    )


def _fetch_user_data(name):
    """Request and decode the stored annotation state for user *name*."""
    response = _post({'FocusUser': name})
    return json.loads(str(response.content, encoding="utf-8"))


def _save_user_data(name, user_data):
    """Send the updated annotation state of user *name* back to the backend."""
    _post({'Focus': user_data, 'username': name})


def new_data(user_data):
    """Render the sample at the user's current position and update the progress bar.

    Args:
        user_data: dict with at least the keys 'all_process', 'start_process',
            'stop_process', 'model_list', 'data_idx' and 'process'.

    Returns:
        'finish' when 'all_process' equals 'stop_process' (nothing is
        rendered), otherwise 'not finish'.
    """
    process_bar.progress(_progress_fraction(user_data), text='进度')
    if user_data['all_process'] == user_data['stop_process']:
        return 'finish'

    csv_idx = user_data['model_list'][user_data['all_process']]
    row = user_data['data_idx'][csv_idx][user_data['process'][csv_idx]]
    sample = csv_map[csv_idx].iloc[row]

    # Drop the last context line by position.  list.remove(value) would delete
    # the first equal line instead, corrupting contexts with repeated turns.
    context_lines = sample.context.split('\n')[:-1]

    p.text(_wrap_lines(sample.persona.split('\n'), 60))
    k.text(_wrap_lines(sample.knowledge.split('\n'), 60))
    c.text(_wrap_lines(context_lines, 120))
    g.text(_wrap_lines(sample['ground truth'].split('\n'), 60))
    infer.text(_wrap_lines(sample.inference.split('\n'), 60))
    return 'not finish'


# ---------------------------------------------------------------- layout ----
st.set_page_config(layout="wide")
st.title('FoCus Annotation')

t1, t2 = st.columns(2)
with t1:
    username = st.text_input("请输入用户名")
with t2:
    password = st.text_input("请输入密码", type="password")
login_btn = st.button('登录')

col1, col2 = st.columns(2)
with col1:
    with st.expander("人设"):
        p = st.empty()
with col2:
    with st.expander("知识"):
        k = st.empty()
with st.expander('对话上下文'):
    c = st.empty()

a1, a2 = st.columns(2)
with a1:
    st.markdown("**真实值**")
    g = st.empty()
with a2:
    st.markdown("**待标注样本**")
    infer = st.empty()

cc, kc, pc, hc, fc = st.columns(5)
with st.container():
    with cc:
        cs = st.selectbox("对话一致性", [0, 1, 2], key='cs')
    with kc:
        ks = st.selectbox("知识一致性", [0, 1, 2], key='ks')
    with pc:
        ps = st.selectbox("人设一致性", [0, 1, 2], key='ps')
    with hc:
        hs = st.selectbox("精炼度", [0, 1, 2], key='hs')
    with fc:
        fs = st.selectbox("流畅度", [0, 1, 2], key='fs')

process_bar = st.progress(0.0, text='进度')

col3, col4 = st.columns(2)
with st.container():
    with col3:
        prev = st.button('上一个')
    with col4:
        succ = st.button('下一个')

# ------------------------------------------------------------- behaviour ----
# Every rerun: show the current sample for valid credentials, otherwise fill
# the placeholders with a "log in first" hint.
if _is_authenticated(username, password):
    result = new_data(_fetch_user_data(username))
else:
    for placeholder in (p, c, g, k, infer):
        placeholder.text("登录后开始标注")

if login_btn:
    if _is_authenticated(username, password):
        st.success('登录成功')
        result = new_data(_fetch_user_data(username))
        if result == 'finish':
            st.success('您已完成标注')
    else:
        # Clear the credentials so the handlers below treat the user as
        # logged out for the rest of this run.
        username = ''
        password = ''
        st.error('用户名或密码错误,请先注册。若已有账号,但忘记密码,请联系管理员修改密码')

if succ:
    if _is_authenticated(username, password):
        user_data = _fetch_user_data(username)
        if user_data['all_process'] == user_data['stop_process']:
            st.success('您已完成标注')
        else:
            csv_idx = user_data['model_list'][user_data['all_process']]
            # Score lists are addressed modulo 100.
            # NOTE(review): presumably each list holds 100 slots per model; confirm.
            slot = user_data['process'][csv_idx] % 100
            scores = {
                'context_relevance': cs,
                'knowledge_relevance': ks,
                'persona_consistency': ps,
                'hallucination': hs,
                'fluency': fs,
            }
            for field, value in scores.items():
                user_data[field][csv_idx][slot] = value
            user_data['process'][csv_idx] += 1
            user_data['all_process'] += 1
            _save_user_data(username, user_data)
            result = new_data(user_data)
            if result == 'finish':
                st.success('您已完成标注')
    else:
        st.error('请先登录')

if prev:
    if _is_authenticated(username, password):
        user_data = _fetch_user_data(username)
        if user_data['all_process'] == user_data['start_process']:
            result = new_data(user_data)
            st.error('已是首个数据')
        else:
            # Step both the global position and the per-model position back.
            user_data['all_process'] -= 1
            csv_idx = user_data['model_list'][user_data['all_process']]
            user_data['process'][csv_idx] -= 1
            result = new_data(user_data)
            _save_user_data(username, user_data)
    else:
        st.error('请先登录')