GLiNER_file / app.py
Olivier CARON
Update app.py
b25dbae verified
import os
import csv
import streamlit as st
import polars as pl
from io import BytesIO, StringIO
from gliner import GLiNER
from gliner_file import run_ner
import time
st.set_page_config(
page_title="GliNER", page_icon="🧊", layout="wide", initial_sidebar_state="expanded"
)
# Modified function to load data from either an Excel or CSV file
@st.cache_data
def load_data(file):
_, file_ext = os.path.splitext(file.name)
if file_ext.lower() in [".xls", ".xlsx"]:
return pl.read_excel(file)
elif file_ext.lower() == ".csv":
file.seek(0) # Go back to the beginning of the file
try:
sample = file.read(4096).decode(
"utf-8"
) # Try to decode the sample in UTF-8
encoding = "utf-8"
except UnicodeDecodeError:
encoding = "latin1" # Switch to 'latin1' if UTF-8 fails
file.seek(0)
sample = file.read(4096).decode(encoding)
file.seek(0)
dialect = csv.Sniffer().sniff(sample) # Detect the delimiter
file.seek(0)
if encoding != "utf-8":
file_content = file.read().decode(encoding)
file = StringIO(file_content)
else:
file_content = file.read().decode("utf-8")
file = StringIO(file_content)
return pl.read_csv(
file,
separator=dialect.delimiter,
truncate_ragged_lines=True,
ignore_errors=True,
)
else:
raise ValueError("The uploaded file must be a CSV or Excel file.")
# Function to perform NER and update the UI
def perform_ner(filtered_df, selected_column, labels_list):
ner_results_dict = {label: [] for label in labels_list}
progress_bar = st.progress(0)
progress_text = st.empty()
start_time = time.time() # Record start time for total runtime
for index, row in enumerate(filtered_df.to_pandas().itertuples(), 1):
iteration_start_time = time.time() # Start time for this iteration
if st.session_state.stop_processing:
progress_text.text("Process stopped by the user.")
break
text_to_analyze = getattr(row, selected_column)
ner_results = run_ner(
st.session_state.gliner_model, text_to_analyze, labels_list
)
for label in labels_list:
texts = ner_results.get(label, [])
concatenated_texts = ", ".join(texts)
ner_results_dict[label].append(concatenated_texts)
progress = index / filtered_df.height
progress_bar.progress(progress)
iteration_time = (
time.time() - iteration_start_time
) # Calculate runtime for this iteration
total_time = time.time() - start_time # Calculate total elapsed time so far
progress_text.text(
f"Progress: {index}/{filtered_df.height} - {progress * 100:.0f}% (Iteration: {iteration_time:.2f}s, Total: {total_time:.2f}s)"
)
end_time = time.time() # Record end time
total_execution_time = end_time - start_time # Calculate total runtime
progress_text.text(
f"Processing complete! Total execution time: {total_execution_time:.2f}s"
)
for label, texts in ner_results_dict.items():
filtered_df = filtered_df.with_columns(pl.Series(name=label, values=texts))
return filtered_df
def main():
st.title("Online NER with GliNER")
st.markdown("Prototype v0.1")
# Ensure the stop_processing flag is initialized
if "stop_processing" not in st.session_state:
st.session_state.stop_processing = False
uploaded_file = st.sidebar.file_uploader("Choose a file")
if uploaded_file is None:
st.warning("Please upload a file.")
return
try:
df = load_data(uploaded_file)
except ValueError as e:
st.error(str(e))
return
selected_column = st.selectbox("Select the column for NER:", df.columns, index=0)
filter_text = st.text_input("Filter column by input text", "")
ner_labels = st.text_input(
"Enter all your different labels, separated by a comma", ""
)
filtered_df = (
df.filter(pl.col(selected_column).str.contains(f"(?i).*{filter_text}.*"))
if filter_text
else df
)
st.dataframe(filtered_df)
if st.button("Start NER"):
if not ner_labels:
st.warning("Please enter some labels for NER.")
else:
# Load GLiNER model if not already loaded
if "gliner_model" not in st.session_state:
with st.spinner("Loading GLiNER model... Please wait."):
st.session_state.gliner_model = GLiNER.from_pretrained(
"urchade/gliner_largev2"
)
st.session_state.gliner_model.eval()
labels_list = ner_labels.split(",")
updated_df = perform_ner(filtered_df, selected_column, labels_list)
st.dataframe(updated_df)
def to_excel(df):
output = BytesIO()
df.to_pandas().to_excel(output, index=False, engine="openpyxl")
return output.getvalue()
df_excel = to_excel(updated_df)
st.download_button(
label="πŸ“₯ Download Excel",
data=df_excel,
file_name="ner_results.xlsx",
mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
st.button(
"Stop Processing",
on_click=lambda: setattr(st.session_state, "stop_processing", True),
)
if __name__ == "__main__":
main()