"""Streamlit page that runs the net-zero classifier.

Exposes ``app()``, which classifies the DataFrame stored in
``st.session_state.key1`` using the ``'netzero'`` classifier configuration and
writes the classified DataFrame back to the same session-state key.
"""

# set path
import glob
import os
import sys

sys.path.append('../utils')

# import needed libraries
import logging
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
import streamlit as st
import xlsxwriter

from utils.netzero_classifier import load_netzeroClassifier, netzero_classification
from utils.config import get_classifier_params

# NOTE(review): glob, os, seaborn, matplotlib and numpy are not referenced in
# this module; BytesIO, xlsxwriter and plotly are referenced only by the
# disabled helpers at the bottom. Kept in case other code relies on them.

logger = logging.getLogger(__name__)

# Declare all the necessary variables
classifier_identifier = 'netzero'
# Classifier settings; this module reads the 'model_name' and 'threshold' keys.
params = get_classifier_params(classifier_identifier)


def app():
    """Classify the session's document DataFrame for net-zero targets.

    When ``'key1'`` is present in ``st.session_state``:

    * loads the classifier named by ``params['model_name']`` and stores it in
      session state under ``'netzero_classifier'``;
    * runs ``netzero_classification`` on the DataFrame with
      ``params['threshold']``, showing a spinner (with an extra "this might
      take some time" hint when more than 100 rows are labelled ``'TARGET'``);
    * writes the returned DataFrame back to ``st.session_state.key1``.

    Does nothing when ``'key1'`` is not in session state.

    The DataFrame must contain a ``'Target Label'`` column; it is presumably
    produced by an earlier target-classification step — TODO confirm.
    """
    ### Main app code ###
    with st.container():
        if 'key1' not in st.session_state:
            return

        df = st.session_state.key1

        # Load the classifier model and store it in session state.
        classifier = load_netzeroClassifier(classifier_name=params['model_name'])
        st.session_state['{}_classifier'.format(classifier_identifier)] = classifier

        # Many TARGET rows means a slow classification run, so extend the
        # spinner text with a patience hint. (Previously this message was
        # computed but never displayed.)
        n_targets = (df['Target Label'] == 'TARGET').sum()
        if n_targets > 100:
            warning_msg = ": This might take sometime, please sit back and relax."
        else:
            warning_msg = ""

        with st.spinner("Running Netzero Classification{}".format(warning_msg)):
            df = netzero_classification(haystack_doc=df,
                                        threshold=params['threshold'])

        st.session_state.key1 = df


# ---------------------------------------------------------------------------
# Disabled helpers (commented out in the original; indentation reconstructed).
#
# NOTE(review): before re-enabling these, please check the following:
#   * to_excel: ``writer.save()`` was removed in pandas 2.0. Use
#     ``writer.close()`` or ``with pd.ExcelWriter(...) as writer:`` instead.
#     Also, ``workbook`` is assigned but never used.
#   * netzero_display: it checks for 'key1' in session state but reads
#     ``st.session_state.key2``. ``_lab_dict`` is not defined in this module.
#     ``range_val`` is computed twice.
# ---------------------------------------------------------------------------

# @st.cache_data
# def to_excel(df):
#     len_df = len(df)
#     output = BytesIO()
#     writer = pd.ExcelWriter(output, engine='xlsxwriter')
#     df.to_excel(writer, index=False, sheet_name='Sheet1')
#     workbook = writer.book
#     worksheet = writer.sheets['Sheet1']
#     worksheet.data_validation('E2:E{}'.format(len_df),
#                               {'validate': 'list',
#                                'source': ['No', 'Yes', 'Discard']})
#     writer.save()
#     processed_data = output.getvalue()
#     return processed_data

# def netzero_display():
#     if 'key1' in st.session_state:
#         df = st.session_state.key2
#         hits = df[df['Netzero Label'] == 'NETZERO']
#         range_val = min(5, len(hits))
#         if range_val != 0:
#             count_df = df['Netzero Label'].value_counts()
#             count_df = count_df.rename('count')
#             count_df = count_df.rename_axis('Netzero Label').reset_index()
#             count_df['Label_def'] = count_df['Netzero Label'].apply(lambda x: _lab_dict[x])
#
#             fig = px.bar(count_df, y="Label_def", x="count", orientation='h', height=200)
#             c1, c2 = st.columns([1, 1])
#             with c1:
#                 st.plotly_chart(fig, use_container_width=True)
#
#             hits = hits.sort_values(by=['Netzero Score'], ascending=False)
#             st.write("")
#             st.markdown("###### Top few NetZero Target Classified paragraph/text results ######")
#             range_val = min(5, len(hits))
#             for i in range(range_val):
#                 # the page number reflects the page that contains the main paragraph
#                 # according to split limit, the overlapping part can be on a separate page
#                 st.write('**Result {}** `page {}` (Relevancy Score: {:.2f})'.format(
#                     i + 1, hits.iloc[i]['page'], hits.iloc[i]['Netzero Score']))
#                 st.write("\t Text: \t{}".format(hits.iloc[i]['text']))
#         else:
#             st.info("🤔 No Netzero target found")