import os
import copy
from collections import defaultdict

import pandas as pd
import plotly.express as px
import streamlit as st

# Fetch the NLTK data on first run; nltk.download is a no-op when the
# packages are already present.
import nltk
nltk.download('punkt')
nltk.download('stopwords')

from dataset_loading import load_jsonl

os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
st.set_page_config(layout="wide")
# Globals used by the (currently disabled) CSV download helper below.
current_checkboxes = []
query_input = None

# A defaultdict whose missing keys recursively create nested defaultdicts, so
# results can be assigned at any depth without pre-initializing levels.
def infinite_defaultdict():
    return defaultdict(infinite_defaultdict)

RESULTS = infinite_defaultdict()
if 'results' not in st.session_state:
st.session_state.results = RESULTS
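# Build CSV bytes from the checked (marked not-relevant) documents for the
# current query; only used by the commented-out download button at the bottom.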
def get_current_data():
cur_query_data = []
cur_query = query_input.replace("\n", "\\n")
for doc_id, checkbox in current_checkboxes:
if checkbox:
cur_query_data.append({
"new_narrative": cur_query,
"qid": st.session_state.selectbox_instance,
"doc_id": doc_id,
"is_relevant": 0
})
    # return the data as CSV bytes for st.download_button
    return pd.DataFrame(cur_query_data).to_csv(index=False).encode("utf-8")
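# Record a single human annotation plus the model's label, keyed by doc id.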
def add_results(doc_id, annotation, model_ground_truth):
if annotation is None:
return
st.session_state.results[doc_id] = {
"annotation": annotation,
"model_ground_truth": model_ground_truth
}
print(st.session_state.results[doc_id])
print("Added to results")
def extract_doc_text(doc: dict):
if "title" in doc and doc["title"].strip() != "":
return doc["title"] + "\n\n" + doc["text"]
else:
return doc["text"]
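# Pull the "Query:" and "Specific Query:" lines out of the prompt and pair
# each generated output with them, one record per generation.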
def create_instance(prompt, output):
if output is None:
return []
    if not isinstance(output, list):
        return []
# get the query and specific query by extracting the values on that line
query = [item for item in prompt.split("\n") if "Query:" in item][0].replace("Query:", "").strip()
instruction = [item for item in prompt.split("\n") if "Specific Query:" in item][0].replace("Specific Query:", "").strip()
# create each instance with the outputs
final_data = []
    for out in output:
final_data.append({
"query": query,
"generation": out,
"instruction": instruction,
})
return final_data
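# Render one instance (query, instruction, generated passage) plus optional
# explanation / model ground-truth toggles. `key_prefix` keeps the checkbox
# keys unique when two instances are rendered side by side.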
def generate_markdown(cur_instance, key_prefix=""):
st.markdown("<h1 style='text-align: center; color: black;text-decoration: underline;'>Instance</h1>", unsafe_allow_html=True)
st.markdown(f"<h3 style='text-align: left; color: black;'>Query</h3><p>{cur_instance['query']}</p>", unsafe_allow_html=True)
st.markdown(f"<h3 style='text-align: left; color: black;'>Instruction</h3><p>{cur_instance['instruction']}</p>", unsafe_allow_html=True)
st.markdown(f"<h3 style='text-align: left; color: black;'>Generated Passage</h3><p>{extract_doc_text(cur_instance['passage'])}</p>", unsafe_allow_html=True)
    # if checked, reveal the model explanation for this instance
    if st.checkbox("Show Explanation", key=f"{key_prefix}show_explanation"):
st.markdown(f"<p>{cur_instance['explanation']}</p>", unsafe_allow_html=True)
    # if checked, reveal the model's ground-truth label (matches_both)
    if st.checkbox("Show Model Ground Truth", key=f"{key_prefix}show_ground_truth"):
st.markdown(f"<p>{cur_instance['matches_both']}</p>", unsafe_allow_html=True)
if 'cur_instance_num' not in st.session_state:
st.session_state.cur_instance_num = -1
if 'number_of_col' not in st.session_state:
st.session_state.number_of_col = -1
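# Halt the app early when a selected option is missing its required upload.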
def validate(config_option, file_loaded):
if config_option != "None" and file_loaded is None:
st.error("Please upload a file for " + config_option)
st.stop()
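# Sidebar inputs: either the bundled example file or one/two uploaded JSONL
# files; each record needs a `joint_id` so the two files can be aligned.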
with st.sidebar:
st.title("Options")
use_default = st.checkbox("Use default data", value=False)
if use_default:
st.write("Using default data")
data = load_jsonl(open("example_output.jsonl", "r"))
# needs ids and mapping
ids = [item["joint_id"] for item in data]
mapping = {item["joint_id"]: item for item in data}
data2 = None
else:
st.header("Input File")
input_file = st.file_uploader("Choose a file", key="data")
data = load_jsonl(input_file)
if data is not None:
# needs ids and mapping
ids = [item["joint_id"] for item in data]
mapping = {item["joint_id"]: item for item in data}
input_file2 = st.file_uploader("Choose a second file", key="data2")
data2 = load_jsonl(input_file2)
if data2 is not None:
# needs ids2 and mapping2
ids2 = [item["joint_id"] for item in data2]
mapping2 = {item["joint_id"]: item for item in data2}
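# Main layout: instance navigation and aggregate results on the left,
# the instance viewer and annotation forms on the right.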
col1, col2 = st.columns([1, 3], gap="large")
if data is not None:
joint_ids = ids if data2 is None else list(set(ids2).intersection(ids))
# print(f"Not using ids {set(ids) - set(joint_ids)} and {set(ids2) - set(joint_ids)}")
with st.sidebar:
st.success("All files uploaded")
with col1:
        container_for_nav = st.container()
        name_of_columns = sorted(joint_ids)
        instances_to_use = name_of_columns
st.title("Instances")
def sync_from_drop():
if st.session_state.selectbox_instance == "Overview":
st.session_state.number_of_col = -1
st.session_state.cur_instance_num = -1
else:
index_of_obj = name_of_columns.index(st.session_state.selectbox_instance)
# print("Index of obj: ", index_of_obj, type(index_of_obj))
st.session_state.number_of_col = index_of_obj
st.session_state.cur_instance_num = index_of_obj
def sync_from_number():
st.session_state.cur_instance_num = st.session_state.number_of_col
# print("Session state number of col: ", st.session_state.number_of_col, type(st.session_state.number_of_col))
if st.session_state.number_of_col == -1:
st.session_state.selectbox_instance = "Overview"
else:
st.session_state.selectbox_instance = name_of_columns[st.session_state.number_of_col]
        number_of_col = container_for_nav.number_input(
            label=f"Select instance by index (up to **{len(instances_to_use) - 1}**)",
            min_value=-1,
            max_value=len(instances_to_use) - 1,
            step=1,
            on_change=sync_from_number,
            key="number_of_col",
        )
        selectbox_instance = container_for_nav.selectbox(
            "Select instance by ID",
            ["Overview"] + name_of_columns,
            on_change=sync_from_drop,
            key="selectbox_instance",
        )
st.divider()
st.title(f"Results ({len(st.session_state.results)})")
# show results if they exist
        confusion_matrix = {
            "True": {"True": 0, "False": 0},
            "False": {"True": 0, "False": 0},
        }
instruction_matrix = copy.deepcopy(confusion_matrix)
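        # Tally human annotations against model labels. Outer key = model
        # prediction, inner key = human label; the query-side model prediction
        # is always True in this tally.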
for key, value_dict in st.session_state.results.items():
print(value_dict)
model_pred_instruct = str(value_dict["model_ground_truth"])
model_pred_query = str(True)
human_pred = value_dict["annotation"]
if human_pred == "Instruction Negative, Query Positive":
instruction_matrix[model_pred_instruct]["False"] += 1
confusion_matrix[model_pred_query]["True"] += 1
elif human_pred == "Instruction Positive, Query Positive":
instruction_matrix[model_pred_instruct]["True"] += 1
confusion_matrix[model_pred_query]["True"] += 1
elif human_pred == "Instruction Positive, Query Negative":
instruction_matrix[model_pred_instruct]["True"] += 1
confusion_matrix[model_pred_query]["False"] += 1
elif human_pred == "Instruction Negative, Query Negative":
instruction_matrix[model_pred_instruct]["False"] += 1
confusion_matrix[model_pred_query]["False"] += 1
        # plot the two confusion matrices as heatmaps, one per tab
tab1, tab2 = st.tabs(["Instruction", "Query"])
with tab1:
            fig = px.imshow(
                [[instruction_matrix["True"]["True"], instruction_matrix["True"]["False"]],
                 [instruction_matrix["False"]["True"], instruction_matrix["False"]["False"]]],
                labels=dict(y="Model Prediction", x="Annotation", color="Count"),
                x=["Relevant", "Not Relevant"],
                y=["Relevant", "Not Relevant"],
            )
fig.update_xaxes(side="top")
fig.update_layout(
height=400,
width=250
)
st.plotly_chart(fig)
if st.checkbox("Show Counts", key="instruction_counts"):
st.write(instruction_matrix)
with tab2:
            fig2 = px.imshow(
                [[confusion_matrix["True"]["True"], confusion_matrix["True"]["False"]],
                 [confusion_matrix["False"]["True"], confusion_matrix["False"]["False"]]],
                labels=dict(y="Model Prediction", x="Annotation", color="Count"),
                x=["Relevant", "Not Relevant"],
                y=["Relevant", "Not Relevant"],
            )
fig2.update_xaxes(side="top")
fig2.update_layout(
height=400,
width=250
)
st.plotly_chart(fig2)
if st.checkbox("Show Counts", key="show_counts"):
st.write(confusion_matrix)
        # show results as a table if the checkbox is selected
        if st.checkbox("Show Results"):
            # use a separate name so we don't shadow the loaded `data`
            results_rows = []
            for key, value_dict in st.session_state.results.items():
                results_rows.append({
                    "ID": key,
                    "Annotation": value_dict["annotation"],
                    "Model Ground Truth": value_dict["model_ground_truth"]
                })
            st.write(pd.DataFrame(results_rows))
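    # Right pane: render the selected instance (side by side when a second
    # file is loaded) and collect annotations through small forms.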
with col2:
# get instance number
inst_index = number_of_col
if inst_index >= 0:
cur_instance = mapping[joint_ids[inst_index]]
if data2 is not None:
cur_instance2 = mapping2[joint_ids[inst_index]]
col1_out, col2_out = st.columns([1, 1], gap="small")
with col1_out:
                    generate_markdown(cur_instance, key_prefix="left_")
# create options for labeling and save to RESULTS
form = st.form("annotate")
                    # annotation selectbox lives inside the form; a unique key
                    # avoids clashing with the identical widget in the right pane
                    annotation = form.selectbox(
                        "Annotation",
                        ["Instruction Negative, Query Positive", "Instruction Positive, Query Positive", "Instruction Positive, Query Negative", "Instruction Negative, Query Negative"],
                        key="annotation_left",
                    )
submitted = form.form_submit_button("Submit")
if submitted:
add_results(joint_ids[inst_index], annotation, cur_instance["matches_both"])
with col2_out:
                    generate_markdown(cur_instance2, key_prefix="right_")
# create options for labeling and save to RESULTS
form2 = st.form("annotate2")
                    # attach the selectbox to form2; a unique key avoids a
                    # DuplicateWidgetID clash with the left-pane widget
                    annotation2 = form2.selectbox(
                        "Annotation",
                        ["Instruction Negative, Query Positive", "Instruction Positive, Query Positive", "Instruction Positive, Query Negative", "Instruction Negative, Query Negative"],
                        key="annotation_right",
                    )
submitted2 = form2.form_submit_button("Submit")
if submitted2:
add_results(joint_ids[inst_index], annotation2, cur_instance2["matches_both"])
else:
generate_markdown(cur_instance)
# create options for labeling and save to RESULTS
form = st.form("annotate")
                # annotation selectbox inside the form; submitted via the button below
annotation = form.selectbox(
"Annotation",
["Instruction Negative, Query Positive", "Instruction Positive, Query Positive", "Instruction Positive, Query Negative", "Instruction Negative, Query Negative"],
)
submitted = form.form_submit_button("Submit")
if submitted:
add_results(joint_ids[inst_index], annotation, cur_instance["matches_both"])
# if st.checkbox("Download data as CSV"):
# st.download_button(
# label="Download data as CSV",
# data=get_current_data(),
                # file_name=f'annotation_query_{inst_index}.csv',
# mime='text/csv',
# )
# none checked
elif inst_index < 0:
st.title("Overview")
else:
st.warning("Please choose an output file from prompting and upload it")