test_annotation / app.py
maybeMayank's picture
updated storage
2dcb694
import os
import json
import random
import datetime
import streamlit as st
from huggingface_hub import Repository
# --- Authentication ---
# Shared app password, read from the environment. When the variable is unset
# os.getenv returns None, which never equals the typed string, so nobody can
# log in until it is configured.
PASSWORD = os.getenv("ANNOTATION_APP_PASSWORD")

if "auth" not in st.session_state:
    st.session_state.auth = False

if not st.session_state.auth:
    st.title("Login")
    pwd = st.text_input("Enter password:", type="password")
    if st.button("Login"):
        if pwd == PASSWORD:
            st.session_state.auth = True
            # Restart the script right away. Without this, the st.stop() below
            # still ends the current run and the login form stays on screen
            # until the user interacts with the page a second time.
            st.rerun()
        else:
            st.error("Incorrect password, try again.")
    # Nothing below this point may execute for an unauthenticated session.
    st.stop()
# --- Annotator Identification ---
# The annotator name is collected once per session. Later code uses it to seed
# the per-annotator shuffle and to build the "annotations_<name>.jsonl" path.
if "annotator" not in st.session_state:
    st.title("Annotator Info")
    name = st.text_input("Enter your name or ID:")
    if st.button("Start Annotation"):
        # The name ends up inside a file path, so refuse path separators that
        # could make the annotation file land outside the dataset folder, and
        # refuse names that are empty or only whitespace.
        is_valid = bool(name.strip()) and "/" not in name and "\\" not in name
        if is_valid:
            st.session_state.annotator = name
            # Restart the script so the annotation UI appears immediately;
            # otherwise the st.stop() below ends this run and the form stays
            # visible until a second interaction.
            st.rerun()
        else:
            st.error("Please enter a valid name or ID.")
    # Do not proceed until a name has been stored for this session.
    st.stop()
ANNOTATOR = st.session_state.annotator
# --- JSONL Loader ---
def load_jsonl(path):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        path: Filesystem path of the JSONL file.

    Returns:
        One parsed object per non-blank line, in file order. An empty list
        is returned when ``path`` does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        # Blank or whitespace-only lines (for example an extra trailing
        # newline) are skipped; json.loads would otherwise raise on them.
        return [json.loads(line) for line in f if line.strip()]
# --- Load Data Pairs ---
# The two result files are paired record-by-record in file order: item i of
# the first file is shown next to item i of the second file.
JSONL1 = os.getenv("JSONL_FILE1_PATH", "data/eval_v2_results.jsonl")
JSONL2 = os.getenv("JSONL_FILE2_PATH", "data/noisy_eval_v2_results.jsonl")
data1 = load_jsonl(JSONL1)
data2 = load_jsonl(JSONL2)
# zip() stops at the shorter input, so a length mismatch would silently drop
# the trailing records of the longer file. Surface it instead of hiding it.
if len(data1) != len(data2):
    st.warning(
        f"Input files have different lengths ({len(data1)} vs {len(data2)}); "
        "only the first "
        f"{min(len(data1), len(data2))} records of each are paired."
    )
pairs = list(zip(data1, data2))
# --- Stable Shuffle per Annotator ---
# Seeding a dedicated generator with the annotator name gives each annotator
# their own ordering that is reproduced on every rerun, which is what lets the
# resume index (number of saved annotations) point at the right pair.
shuffle_rng = random.Random(ANNOTATOR)
shuffle_rng.shuffle(pairs)
# --- HF Dataset Repo Init ---
def init_repo():
    """Open the local clone of the annotation dataset repo, cloning if needed.

    Reads ``HF_TOKEN`` and ``HF_DATASET_REPO`` from the environment.

    Returns:
        A ``(repo, repo_dir)`` tuple: the ``Repository`` handle and the name
        of the local directory that holds the clone.
    """
    HF_TOKEN = os.getenv("HF_TOKEN")
    HF_DATASET_REPO = os.getenv("HF_DATASET_REPO")
    repo_dir = "hf_dataset"
    # Decide on the presence of git metadata, not on the bare directory: a
    # folder left behind by an interrupted or failed clone would otherwise be
    # treated as an existing clone on every later run and never be repaired.
    if not os.path.exists(os.path.join(repo_dir, ".git")):
        # NOTE(review): `Repository` and its `use_auth_token` argument are
        # deprecated in newer huggingface_hub releases - confirm against the
        # version pinned for this app before upgrading.
        repo = Repository(
            local_dir=repo_dir,
            clone_from=HF_DATASET_REPO,
            repo_type="dataset",
            use_auth_token=HF_TOKEN,
        )
    else:
        # Reuse the existing clone as-is; no pull is performed here.
        repo = Repository(local_dir=repo_dir)
    return repo, repo_dir
# --- Annotation Storage & Resume State ---
# Key of the text field that is displayed for each side of a pair.
FIELD = "discharge_instructions"

repo, repo_dir = init_repo()

# Every annotator writes to a dedicated JSONL file inside the dataset clone.
ann_filename = "annotations_{}.jsonl".format(ANNOTATOR)
ann_path = os.path.join(repo_dir, ann_filename)

# --- Load Existing Annotations ---
# Re-read from disk on every script run so the list mirrors the saved file.
existing = load_jsonl(ann_path)

# --- Resume Index Setup ---
# Only on the first run of a session: start at the first pair that has no
# saved annotation yet, i.e. at position len(existing).
if "idx" not in st.session_state:
    st.session_state["pairs"] = pairs
    st.session_state["idx"] = len(existing)
# --- Annotation UI: edit a past rating or rate the next pair ---
# Allowed answers, in the order they are offered by the radio widget.
_CHOICES = ("A", "B", "Equal")


def _utc_timestamp():
    """Return the current UTC time as a naive ISO-8601 string."""
    # Produces the same text format as datetime.utcnow().isoformat() without
    # calling the deprecated utcnow(), so old and new records stay uniform.
    now = datetime.datetime.now(datetime.timezone.utc)
    return now.replace(tzinfo=None).isoformat()


def _show_pair(pair):
    """Render both records of ``pair`` side by side as Option A / Option B."""
    left, right = st.columns(2)
    layout = (
        (left, "Option A", pair[0], "opt_a"),
        (right, "Option B", pair[1], "opt_b"),
    )
    for column, title, record, widget_key in layout:
        with column:
            st.subheader(title)
            # The return value is ignored: the text area is only a viewer.
            # A real (hidden) label is passed because Streamlit discourages
            # empty widget labels for accessibility reasons.
            st.text_area(
                title,
                value=record.get(FIELD, f"{FIELD} not found"),
                height=300,
                key=widget_key,
                label_visibility="collapsed",
            )


def _save_annotations(records):
    """Overwrite the annotator's JSONL file with ``records``, one per line."""
    with open(ann_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in records)


def _push_annotations(commit_message):
    """Stage, commit and push the annotator's file to the dataset repo."""
    repo.git_add(ann_filename)
    repo.git_commit(commit_message)
    repo.git_push()


# --- Sidebar: Edit Past Annotations ---
st.sidebar.header("Your Annotations")
can_edit = len(existing) > 0
edit_mode = st.sidebar.checkbox("Edit previous annotation", disabled=not can_edit)

if edit_mode:
    sel = st.sidebar.selectbox(
        "Select annotation to edit", list(range(1, len(existing) + 1))
    )
    rec = existing[sel - 1]
    # Find the pair this annotation was made for by matching both record ids.
    # The result is None when no pair matches (e.g. the input files changed),
    # which previously surfaced as a NameError further down.
    edit_pair = next(
        (
            p
            for p in pairs
            if p[0].get("hadm_id") == rec["id1"]
            and p[1].get("hadm_id") == rec["id2"]
        ),
        None,
    )
    st.header(f"Editing annotation {sel} of {len(existing)}")
    if edit_pair is None:
        st.error(
            "The pair for this annotation was not found in the current data files."
        )
    else:
        _show_pair(edit_pair)
        # Preselect the stored answer; fall back to the first option when the
        # stored value is missing or not one of the allowed choices.
        stored_choice = rec.get("choice")
        default_index = (
            _CHOICES.index(stored_choice) if stored_choice in _CHOICES else 0
        )
        choice = st.radio("Which is better?", _CHOICES, index=default_index)
        if st.button("Update Annotation"):
            # rec is the same dict object as existing[sel - 1].
            rec["choice"] = choice
            rec["timestamp"] = _utc_timestamp()
            _save_annotations(existing)
            _push_annotations(f"Update annotation {sel} by {ANNOTATOR}")
            st.success("Annotation updated!")
            st.rerun()
else:
    idx = st.session_state.idx
    if idx < len(pairs):
        current = pairs[idx]
        st.header(f"Pair {idx + 1} of {len(pairs)}")
        _show_pair(current)
        choice = st.radio("Which is better?", _CHOICES, key="choice")
        if st.button("Submit Rating"):
            record = {
                "id1": current[0].get("hadm_id"),
                "id2": current[1].get("hadm_id"),
                "choice": choice,
                "annotator": ANNOTATOR,
                "timestamp": _utc_timestamp(),
            }
            existing.append(record)
            _save_annotations(existing)
            # Advance as soon as the record is on disk. If the push below
            # raises, the session index still matches the saved file, so the
            # same pair is not offered (and recorded) a second time.
            st.session_state.idx += 1
            _push_annotations(f"Add annotation {idx + 1} by {ANNOTATOR}")
            st.success("Rating submitted!")
            st.rerun()
    else:
        st.balloons()
        st.write("All pairs annotated. Thank you!")