import os

# Device pinning must happen BEFORE torch is imported, otherwise CUDA has
# already enumerated devices and these variables are ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import torch
import pickle
import gradio as gr
import textstat
from sentence_transformers import SentenceTransformer, util
| |
|
| | |
| | LANG_CODE = "en" |
| | CHUNKS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_chunks.pkl" |
| | EMBS_PATH = f"/home/mshahidul/readctrl/data/vector_db/db_model/wiki_{LANG_CODE}_embs.pt" |
| | TARGET_DOCS_PATH = f"/home/mshahidul/readctrl/data/synthetic_dataset_diff_labels/syn_data_diff_labels_{LANG_CODE}_v1.json" |
| | SAVE_PATH = f"/home/mshahidul/readctrl/data/data_annotator_data/manual_selections_{LANG_CODE}.json" |
| |
|
| | |
| | print("Loading Model and Tensors...") |
| | model = SentenceTransformer('all-MiniLM-L6-v2') |
| |
|
| | with open(CHUNKS_PATH, "rb") as f: |
| | wiki_chunks = pickle.load(f) |
| |
|
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | wiki_embs = torch.load(EMBS_PATH).to(device) |
| |
|
| | with open(TARGET_DOCS_PATH, "r") as f: |
| | raw_targets = json.load(f) |
| |
|
| | target_list = [] |
| | for item in raw_targets: |
| | for label, text in item['diff_label_texts'].items(): |
| | target_list.append({ |
| | "index": item['index'], |
| | "label": label, |
| | "text": text |
| | }) |
| |
|
| | |
def get_resume_index():
    """Return the position in target_list of the first unsaved item.

    Falls back to 0 when there is no save file (or it is unreadable), and
    to the LAST position when every item is already done, so the UI still
    has something to display.
    """
    if not os.path.exists(SAVE_PATH):
        return 0

    try:
        with open(SAVE_PATH, "r") as f:
            completed = json.load(f)
        # Identity of a finished work item is the (doc index, label) pair.
        done_keys = {(entry['index'], entry['label']) for entry in completed}
    except Exception as e:
        print(f"Error loading save file: {e}")
        return 0

    pending = (pos for pos, item in enumerate(target_list)
               if (item['index'], item['label']) not in done_keys)
    return next(pending, len(target_list) - 1)
| |
|
# Computed once at startup; seeds the Gradio session state below.
START_INDEX = get_resume_index()
print(f"Resuming from index: {START_INDEX}")
| |
|
| | |
def get_candidates(target_text, top_k=20):
    """Retrieve the top_k wiki chunks most similar to *target_text*.

    Encodes the text with the sentence-transformer, runs cosine semantic
    search against the preloaded embedding matrix, and returns the matching
    chunk texts, best match first.
    """
    query = model.encode(target_text, convert_to_tensor=True).to(device)
    hits = util.semantic_search(query, wiki_embs, top_k=top_k)[0]
    return [wiki_chunks[hit['corpus_id']] for hit in hits]
| |
|
def calculate_stats(text):
    """Return a one-line readability summary (word count + FKGL) for *text*."""
    # Empty/None text has no meaningful stats.
    if not text:
        return "N/A"
    word_count = len(text.split())
    grade = textstat.flesch_kincaid_grade(text)
    # NOTE(review): "π" is mojibake left over from an encoding mishap;
    # reproduced as-is to keep output identical.
    return f"π Words: {word_count} | π FKGL: {grade}"
| |
|
def save_selection(target_idx, label, original_text, selected_wiki):
    """Persist one (document, label) -> wiki-anchor choice to SAVE_PATH.

    Re-saving the same (index, label) pair replaces the earlier entry, so
    the annotator can revise a selection.  Returns a status string for the
    UI.
    """
    entry = {
        "index": target_idx,
        "label": label,
        "original_text": original_text,
        "selected_wiki_anchor": selected_wiki,
        "wiki_fkgl": textstat.flesch_kincaid_grade(selected_wiki),
        "doc_fkgl": textstat.flesch_kincaid_grade(original_text),
    }

    # Load previous selections; a corrupt or unreadable file falls back to
    # an empty list rather than crashing the UI.  (Was a bare `except:`,
    # which also swallowed KeyboardInterrupt/SystemExit — narrowed.)
    existing_data = []
    if os.path.exists(SAVE_PATH):
        try:
            with open(SAVE_PATH, "r") as f:
                existing_data = json.load(f)
        except (json.JSONDecodeError, OSError) as e:
            print(f"Warning: could not read {SAVE_PATH} ({e}); starting fresh")
            existing_data = []

    # Drop any earlier entry for the same (index, label) before appending.
    existing_data = [d for d in existing_data
                     if not (d['index'] == target_idx and d['label'] == label)]
    existing_data.append(entry)

    with open(SAVE_PATH, "w") as f:
        json.dump(existing_data, f, indent=2)
    # NOTE(review): the original literal was garbled/split by an encoding
    # mishap; rejoined onto one line ("β" is mojibake for a check mark).
    return f"β Saved: ID {target_idx} ({label})"
| |
|
| | |
# ---------------------------------------------------------------------------
# Gradio UI: shows one target document at a time beside the top-k retrieved
# Wikipedia chunks and lets the annotator confirm the best anchor.
# NOTE(review): several label strings contain mojibake (e.g. "π", "β¬οΈ")
# from an encoding mishap; kept as-is except where the garbling had split a
# literal across lines and broke the syntax.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="Wiki Anchor Selector") as demo:
    gr.Markdown("# π ReadCtrl: Anchor Selection (Resume Mode)")

    # Per-session cursor into target_list; seeded from the resume scan.
    current_idx = gr.State(START_INDEX)

    with gr.Row():
        with gr.Column(scale=1):
            # Left column: the target document being annotated.
            target_info = gr.Markdown("### Loading...")
            label_display = gr.Textbox(label="Target Readability Level", interactive=False)
            display_text = gr.Textbox(label="Medical Text", lines=12, interactive=False)
            target_stats = gr.Markdown("Stats: ...")

        with gr.Column(scale=2):
            # Right column: retrieved wiki candidates.  The dropdown's
            # displayed label is "Candidate N" but its VALUE is the full
            # chunk text, which is what gets saved.
            wiki_dropdown = gr.Dropdown(
                label="Select Candidate Number",
                choices=[],
                interactive=True
            )
            full_wiki_view = gr.Textbox(label="Wikipedia Chunk Preview", lines=12, interactive=False)
            wiki_stats = gr.Markdown("Stats: ...")

    status_msg = gr.Markdown("### *Status: Ready*")

    with gr.Row():
        prev_btn = gr.Button("β¬οΈ Previous")
        save_btn = gr.Button("πΎ Confirm & Save", variant="primary")
        next_btn = gr.Button("Next / Skip β‘οΈ")

    # The 8 components every navigation event refreshes, in the exact order
    # load_item returns them.  Defined once instead of repeated per handler.
    _NAV_OUTPUTS = [target_info, label_display, display_text, target_stats,
                    wiki_dropdown, full_wiki_view, wiki_stats, status_msg]

    def load_item(idx):
        """Build the 8 _NAV_OUTPUTS values for target_list[idx]."""
        if not (0 <= idx < len(target_list)):
            return "End", "None", "", "", gr.update(choices=[], value=None), "", "", "Finished all items!"

        doc = target_list[idx]
        candidates = get_candidates(doc['text'], top_k=20)

        info = f"### Document {idx + 1} of {len(target_list)} (ID: {doc['index']})"
        t_stats = calculate_stats(doc['text'])

        # Guard: an empty retrieval result previously raised IndexError on
        # candidates[0]; show an explicit status instead.
        if not candidates:
            return (info, doc['label'].upper(), doc['text'], t_stats,
                    gr.update(choices=[], value=None), "", "",
                    f"No candidates retrieved for index {idx}")

        dropdown_choices = [(f"Candidate {i+1}", c) for i, c in enumerate(candidates)]

        return (
            info,
            doc['label'].upper(),
            doc['text'],
            t_stats,
            gr.update(choices=dropdown_choices, value=candidates[0]),
            candidates[0],
            calculate_stats(candidates[0]),
            f"Currently viewing index {idx}"
        )

    def on_dropdown_change(selected_text):
        """Mirror the chosen candidate into the preview + its stats."""
        if not selected_text:
            return "", ""
        return selected_text, calculate_stats(selected_text)

    def handle_next(idx):
        """Advance the cursor (clamped at the end) and reload the view."""
        new_idx = min(len(target_list) - 1, idx + 1)
        return [new_idx] + list(load_item(new_idx))

    def handle_prev(idx):
        """Step the cursor back (clamped at 0) and reload the view."""
        new_idx = max(0, idx - 1)
        return [new_idx] + list(load_item(new_idx))

    # Initial render on page load.
    demo.load(load_item, inputs=[current_idx], outputs=_NAV_OUTPUTS)

    wiki_dropdown.change(on_dropdown_change, inputs=wiki_dropdown,
                         outputs=[full_wiki_view, wiki_stats])

    # Save reads the CURRENT dropdown value (full chunk text) plus the
    # displayed document text, keyed by the item's original (index, label).
    save_btn.click(lambda i, t, w: save_selection(target_list[i]['index'], target_list[i]['label'], t, w),
                   inputs=[current_idx, display_text, wiki_dropdown],
                   outputs=[status_msg])

    next_btn.click(handle_next, inputs=[current_idx], outputs=[current_idx] + _NAV_OUTPUTS)
    prev_btn.click(handle_prev, inputs=[current_idx], outputs=[current_idx] + _NAV_OUTPUTS)
| |
|
| | if __name__ == "__main__": |
| | demo.launch(server_name="0.0.0.0", server_port=7861, share=True) |