import os import csv import streamlit as st import polars as pl from io import BytesIO, StringIO from gliner import GLiNER from gliner_file import run_ner import time st.set_page_config( page_title="GliNER", page_icon="🧊", layout="wide", initial_sidebar_state="expanded" ) # Modified function to load data from either an Excel or CSV file @st.cache_data def load_data(file): _, file_ext = os.path.splitext(file.name) if file_ext.lower() in [".xls", ".xlsx"]: return pl.read_excel(file) elif file_ext.lower() == ".csv": file.seek(0) # Go back to the beginning of the file try: sample = file.read(4096).decode( "utf-8" ) # Try to decode the sample in UTF-8 encoding = "utf-8" except UnicodeDecodeError: encoding = "latin1" # Switch to 'latin1' if UTF-8 fails file.seek(0) sample = file.read(4096).decode(encoding) file.seek(0) dialect = csv.Sniffer().sniff(sample) # Detect the delimiter file.seek(0) if encoding != "utf-8": file_content = file.read().decode(encoding) file = StringIO(file_content) else: file_content = file.read().decode("utf-8") file = StringIO(file_content) return pl.read_csv( file, separator=dialect.delimiter, truncate_ragged_lines=True, ignore_errors=True, ) else: raise ValueError("The uploaded file must be a CSV or Excel file.") # Function to perform NER and update the UI def perform_ner(filtered_df, selected_column, labels_list): ner_results_dict = {label: [] for label in labels_list} progress_bar = st.progress(0) progress_text = st.empty() start_time = time.time() # Record start time for total runtime for index, row in enumerate(filtered_df.to_pandas().itertuples(), 1): iteration_start_time = time.time() # Start time for this iteration if st.session_state.stop_processing: progress_text.text("Process stopped by the user.") break text_to_analyze = getattr(row, selected_column) ner_results = run_ner( st.session_state.gliner_model, text_to_analyze, labels_list ) for label in labels_list: texts = ner_results.get(label, []) concatenated_texts = ", ".join(texts) ner_results_dict[label].append(concatenated_texts) progress = index / filtered_df.height progress_bar.progress(progress) iteration_time = ( time.time() - iteration_start_time ) # Calculate runtime for this iteration total_time = time.time() - start_time # Calculate total elapsed time so far progress_text.text( f"Progress: {index}/{filtered_df.height} - {progress * 100:.0f}% (Iteration: {iteration_time:.2f}s, Total: {total_time:.2f}s)" ) end_time = time.time() # Record end time total_execution_time = end_time - start_time # Calculate total runtime progress_text.text( f"Processing complete! Total execution time: {total_execution_time:.2f}s" ) for label, texts in ner_results_dict.items(): filtered_df = filtered_df.with_columns(pl.Series(name=label, values=texts)) return filtered_df def main(): st.title("Online NER with GliNER") st.markdown("Prototype v0.1") # Ensure the stop_processing flag is initialized if "stop_processing" not in st.session_state: st.session_state.stop_processing = False uploaded_file = st.sidebar.file_uploader("Choose a file") if uploaded_file is None: st.warning("Please upload a file.") return try: df = load_data(uploaded_file) except ValueError as e: st.error(str(e)) return selected_column = st.selectbox("Select the column for NER:", df.columns, index=0) filter_text = st.text_input("Filter column by input text", "") ner_labels = st.text_input( "Enter all your different labels, separated by a comma", "" ) filtered_df = ( df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*")) if filter_text else df ) st.dataframe(filtered_df) if st.button("Start NER"): if not ner_labels: st.warning("Please enter some labels for NER.") else: # Load GLiNER model if not already loaded if "gliner_model" not in st.session_state: with st.spinner("Loading GLiNER model... Please wait."): st.session_state.gliner_model = GLiNER.from_pretrained( "urchade/gliner_largev2" ) st.session_state.gliner_model.eval() labels_list = ner_labels.split(",") updated_df = perform_ner(filtered_df, selected_column, labels_list) st.dataframe(updated_df) def to_excel(df): output = BytesIO() df.to_pandas().to_excel(output, index=False, engine="openpyxl") return output.getvalue() df_excel = to_excel(updated_df) st.download_button( label="📥 Download Excel", data=df_excel, file_name="ner_results.xlsx", mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ) st.button( "Stop Processing", on_click=lambda: setattr(st.session_state, "stop_processing", True), ) if __name__ == "__main__": main()