Spaces:
Sleeping
Sleeping
import streamlit as st | |
import os | |
import pathlib | |
import pandas as pd | |
from collections import defaultdict | |
import json | |
import ast | |
import copy | |
import re | |
import tqdm | |
import pandas as pd | |
from collections import Counter | |
import string | |
import os | |
import streamlit as st | |
import difflib | |
from html import escape | |
def generate_diff_html_word_level(text1, text2):
    """
    Render a word-level diff of ``text1`` -> ``text2`` as an HTML ``<pre>`` block.

    Both texts are tokenised with ``str.split()``, so any run of whitespace
    (including newlines) collapses to a single space in the output. Removed
    words are wrapped in a pink ``<del>``, added words in a green ``<ins>``.
    A replaced run yields its ``<del>`` immediately followed by its ``<ins>``.
    All word text is HTML-escaped before being embedded.

    Args:
        text1: The original text.
        text2: The revised text.

    Returns:
        ``'<pre style="white-space: pre-wrap;">...</pre>'`` with the diff pieces
        separated by single spaces.
    """
    del_open = '<del style="background-color: #fbb6ce;">'
    ins_open = '<ins style="background-color: #b7e4c7;">'

    words1 = text1.split()
    words2 = text2.split()

    def _span(words):
        # Re-join a slice of words and make it safe to embed in HTML.
        return escape(' '.join(words))

    pieces = []
    matcher = difflib.SequenceMatcher(None, words1, words2)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            pieces.append(_span(words1[i1:i2]))
            continue
        # 'replace' emits both halves (deletion first); 'delete' and
        # 'insert' emit only their own half.
        if tag in ('replace', 'delete'):
            pieces.append(del_open + _span(words1[i1:i2]) + '</del>')
        if tag in ('replace', 'insert'):
            pieces.append(ins_open + _span(words2[j1:j2]) + '</ins>')

    # Single spaces between pieces mirror the whitespace collapse done by split().
    final_html = ' '.join(pieces)
    return f'<pre style="white-space: pre-wrap;">{final_html}</pre>'
# NOTE(review): streamlit is imported at the top of the file, before this
# variable is set; if the protobuf runtime is already loaded by then, this
# setting may have no effect — confirm it is still needed.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
# Wide layout: the page is split further down into a narrow navigation
# column and a wide editor column.
st.set_page_config(layout="wide")
# NOTE(review): neither of these module-level names is read anywhere in the
# visible code; they may be leftovers from an earlier version.
current_checkboxes = []
query_input = None
def convert_df(df):
    """
    Serialise ``df`` to UTF-8 encoded CSV bytes, without the index column.

    The result is handed to ``st.download_button`` as the file payload.
    This function is not cached, so the conversion is redone on every call.
    """
    csv_text = df.to_csv(index=False, quotechar='"')
    return csv_text.encode('utf-8')
def escape_markdown(text):
    """
    Backslash-escape Markdown/LaTeX-significant characters in ``text``.

    First, every doubled backtick pair (``````) is turned into a plain double
    quote. Then every character in the special set (backslash, backtick,
    ``*_{}[]()#+-.!|`` and ``$``) is prefixed with exactly one backslash so it
    renders literally.

    Args:
        text: Raw text to display.

    Returns:
        The escaped text.
    """
    # Backslash is in the set so that existing backslashes are escaped too.
    # ``$`` is handled here once. Pre-escaping it separately would escape its
    # backslash a second time and render a stray "\".
    special_chars = frozenset('\\`*_{}[]()#+-.!|$')
    text = text.replace("``", "\"")
    return "".join(f"\\{char}" if char in special_chars else char for char in text)
# Seed the instance cursor once per session; -1 is the value the navigation
# callbacks use for the "Overview" entry.
if "cur_instance_num" not in st.session_state:
    st.session_state["cur_instance_num"] = -1
def validate(config_option, file_loaded):
    """
    Stop the Streamlit script when an option needs a file that was not given.

    A ``config_option`` equal to the string ``"None"`` needs no file. For any
    other option, a ``file_loaded`` of ``None`` shows an error naming the
    option and then calls ``st.stop()``.
    """
    if config_option == "None" or file_loaded is not None:
        return
    st.error("Please upload a file for " + config_option)
    st.stop()
# Heading for the sidebar options panel.
st.sidebar.title("Options")
def load_chunked_data(path="chunked_data.jsonl"):
    """
    Load the OCR source records from a JSON Lines file into a DataFrame.

    Each non-blank line must hold one JSON object. Blank or whitespace-only
    lines (e.g. trailing newlines) are skipped. The ``prompt`` column is
    renamed to ``text``.

    Args:
        path: JSONL file to read; defaults to ``chunked_data.jsonl`` in the
            current working directory.

    Returns:
        pandas.DataFrame with one row per record.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as f:
        records = [json.loads(line) for line in f if line.strip()]
    return pd.DataFrame(records).rename(columns={"prompt": "text"})
def load_generated_data(path="generated_data.json"):
    """
    Load generated outputs into a DataFrame with one row per venue.

    The file must be a JSON object with an ``"outputs"`` mapping of
    venue -> record. A record is either a dict or a string holding a Python
    dict literal. Each row holds ``venue`` plus the record's fields. If a
    record contains its own ``"venue"`` key, that value wins.

    Args:
        path: JSON file to read; defaults to ``generated_data.json`` in the
            current working directory.

    Returns:
        pandas.DataFrame with one row per entry in ``outputs``.

    Raises:
        ValueError: if a record is neither a dict nor a string that parses
            to a dict.
        KeyError: if the top-level object has no ``"outputs"`` key.
    """
    with open(path, "r", encoding="utf-8") as fin:
        outputs = json.load(fin)["outputs"]

    rows = []
    for venue, value in outputs.items():
        # String records are dict reprs; literal_eval only parses Python
        # literals, it never executes code.
        if isinstance(value, str):
            value = ast.literal_eval(value)
        if not isinstance(value, dict):
            raise ValueError(f"Invalid type {type(value)}: {value}")
        rows.append({"venue": venue, **value})
    return pd.DataFrame(rows)
# Read both data files from the working directory: the OCR source first, then
# the generated outputs. No caching is applied here, so both files are re-read
# every time this module-level code executes.
original_df = load_chunked_data()
generated_data = load_generated_data()
def combine_text(item):
    """
    Flatten one generated record into display text.

    Every field except ``venue`` whose value is a non-empty sequence becomes
    one line: its elements joined with ``", "`` and terminated by ``"\\n"``.
    Every line after the first is prefixed with ``","``. Fields that are
    ``None``, a float (how a missing DataFrame cell arrives, as NaN), the
    literal string ``"[]"``, or empty are skipped.

    Args:
        item: Mapping of field name -> value for one venue.

    Returns:
        The combined text, or ``""`` when no field contributes a line.
    """
    lines = []
    for key, value in item.items():
        if key == "venue" or value is None or value == "[]":
            continue
        if isinstance(value, float) or len(value) == 0:
            continue
        lines.append(", ".join(value))
    # Joining "line\n" chunks with "," puts a comma in front of every line but
    # the first. An item with no contributing fields yields "" instead of
    # failing on an index into an empty string.
    return ",".join(f"{line}\n" for line in lines)
# Look-up tables keyed by venue: raw OCR text, and the flattened generated text.
original_map = {
    record["venue"]: record["text"]
    for record in original_df.to_dict(orient="records")
}
generated_map = {
    record["venue"]: combine_text(record)
    for record in generated_data.to_dict(orient="records")
}

# Narrow navigation column on the left, wide editor column on the right.
col1, col2 = st.columns([1, 3], gap="large")
st.sidebar.success("All files uploaded")
with col1:
    # Navigation entries: the unique venues from the OCR data, sorted.
    ids = original_df["venue"].tolist()
    set_of_cols = set(ids)
    # Created before the title so the nav widgets placed in it below render
    # above "Instances".
    container_for_nav = st.container()
    name_of_columns = sorted(set_of_cols)
    instances_to_use = name_of_columns
    st.title("Instances")

    def sync_from_drop():
        """
        Selectbox on_change callback: mirror the choice into the index widget.

        "Overview" maps to index -1; any venue maps to its position in
        ``name_of_columns``. Both ``number_of_col`` and ``cur_instance_num``
        are updated in session state.
        """
        if st.session_state.selectbox_instance == "Overview":
            st.session_state.number_of_col = -1
            st.session_state.cur_instance_num = -1
        else:
            index_of_obj = name_of_columns.index(st.session_state.selectbox_instance)
            st.session_state.number_of_col = index_of_obj
            st.session_state.cur_instance_num = index_of_obj

    def sync_from_number():
        """
        Number-input on_change callback: mirror the index into the selectbox.

        Index -1 selects "Overview"; any other index selects the venue at that
        position in ``name_of_columns``.
        """
        st.session_state.cur_instance_num = st.session_state.number_of_col
        if st.session_state.number_of_col == -1:
            st.session_state.selectbox_instance = "Overview"
        else:
            st.session_state.selectbox_instance = name_of_columns[st.session_state.number_of_col]

    # The two widgets are kept in sync through their session-state keys and
    # the callbacks above.
    number_of_col = container_for_nav.number_input(min_value=-1, step=1, max_value=len(instances_to_use) - 1, on_change=sync_from_number, label=f"Select instance by index (up to **{len(instances_to_use) - 1}**)", key="number_of_col")
    selectbox_instance = container_for_nav.selectbox("Select instance by ID", ["Overview"] + name_of_columns, on_change=sync_from_drop, key="selectbox_instance")
    st.divider()
with col2:
    # number_of_col is -1 for "Overview", otherwise an index into instances_to_use.
    inst_index = number_of_col
    if inst_index >= 0:
        inst_num = instances_to_use[inst_index]
        st.markdown("<h1 style='text-align: center; color: black;text-decoration: underline;'>Editor</h1>", unsafe_allow_html=True)
        container = st.container()
        container.subheader(f"Venue: {inst_num}")
        container.divider()

        original_text = original_map[inst_num]
        # A venue can be present in the OCR data but missing from the generated
        # outputs; show an empty generation instead of raising KeyError.
        generated_text = generated_map.get(inst_num, "")

        container.subheader("Original OCR Text")
        container.markdown(original_text)
        container.divider()
        container.subheader("Generated Text")
        container.markdown(generated_text)
        container.divider()

        # Diff only when both sides are real strings. A missing DataFrame cell
        # arrives as a float NaN, which has no .split().
        if isinstance(original_text, str) and isinstance(generated_text, str):
            container.subheader("Diff")
            processed_diff = generate_diff_html_word_level(original_text, generated_text)
            with container.container(border=True):
                st.markdown(processed_diff, unsafe_allow_html=True)

        # Editable copy, pre-filled with the generated text.
        editable_text = container.text_area("Edit the generated text", value=generated_text, height=300)
        container.divider()

        # Offer the (possibly edited) text for this venue as a one-row CSV.
        st.download_button(
            f"Download {inst_num} as CSV",
            convert_df(pd.DataFrame([{"venue": inst_num, "text": editable_text}])),
            f"{inst_num}.csv",
            "text/csv",
            key=f"download_{inst_num}"
        )
    else:
        # inst_index == -1: nothing selected.
        st.title("Overview")