# Residue from the hosting page, kept as comments so the module still parses:
# mtyrrell's picture
# remove defined bullets from base prompt
# 0a16a11
# raw
# history blame
# 3.93 kB
# set path
import glob, os, sys;
sys.path.append('../utils')
#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.target_classifier import load_targetClassifier, target_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
from utils.target_classifier import label_dict
from appStore.rag import run_query
# Module-level classifier settings used by app().
# Key under which the target classifier's settings are looked up in the config.
classifier_identifier = "target"
# Settings for that key; app() reads params['model_name'] and params['threshold'].
params = get_classifier_params(classifier_identifier)
@st.cache_data
def to_excel(df, sectorlist):
    """Serialize ``df`` to an in-memory .xlsx workbook with dropdown validation.

    The frame is written to sheet 'Sheet1' without its index. List
    (dropdown) data validation is attached to every data row; row 1 is
    the header and is excluded:

    * column S: 'No', 'Yes', 'Discard'
    * columns T, U, V, W, X: the entries of ``sectorlist`` plus 'Blank'

    Args:
        df: DataFrame to export. NOTE(review): the validated column letters
            are fixed, so this assumes the review columns of ``df`` land in
            S-X. Confirm against the caller's column order.
        sectorlist: Options offered in the T-X dropdowns.

    Returns:
        The workbook contents as ``bytes``.
    """
    len_df = len(df)
    # Row 1 holds the header, so data rows span 2 .. len_df + 1 (inclusive).
    last_row = len_df + 1
    output = BytesIO()
    # Closing the writer (via the context manager) flushes the workbook into
    # `output`. ExcelWriter.save() was removed in pandas 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']
        # With no data rows there is nothing to validate. A range such as
        # 'S2:S1' would otherwise wrap onto the header row.
        if len_df > 0:
            worksheet.data_validation(
                f'S2:S{last_row}',
                {'validate': 'list',
                 'source': ['No', 'Yes', 'Discard']})
            for col in ('T', 'U', 'V', 'W', 'X'):
                # Fresh list per call so no validation shares a mutable source.
                worksheet.data_validation(
                    f'{col}2:{col}{last_row}',
                    {'validate': 'list',
                     'source': list(sectorlist) + ['Blank']})
    return output.getvalue()
def app():
    """Run the target classifier over paragraphs that mention specific groups.

    Reads the DataFrame stored at ``st.session_state.key1``. It keeps only
    rows whose 'Vulnerability Label' is non-empty and does not contain
    'Other', then loads the target classifier and stores it in session state
    under '<classifier_identifier>_classifier'. The kept rows are classified
    with ``params['threshold']``, and 'Target Label' is renamed to
    'Specific action/target/measure mentioned'. The result is stored at
    ``st.session_state.key2``.

    Does nothing when 'key1' is absent from the session state.
    """
    ### Main app code ###
    with st.container():
        if 'key1' not in st.session_state:
            return

        source_df = st.session_state.key1

        def _mentions_group(labels):
            # Keep paragraphs tagged with at least one group, excluding 'Other'.
            return len(labels) > 0 and 'Other' not in labels

        grouped_df = source_df[source_df['Vulnerability Label'].apply(_mentions_group)]

        # The classifier is stored in session state before classification runs.
        model = load_targetClassifier(classifier_name=params['model_name'])
        st.session_state[f'{classifier_identifier}_classifier'] = model

        classified_df = target_classification(haystack_doc=grouped_df,
                                              threshold=params['threshold'])
        classified_df.rename(
            columns={'Target Label': 'Specific action/target/measure mentioned'},
            inplace=True)
        st.session_state.key2 = classified_df
def target_display():
    """Render an LLM-generated summary of the findings for each vulnerability label.

    Reads the classified DataFrame from ``st.session_state['key2']``. Each row's
    'Vulnerability Label' list is exploded into one row per label, and the
    'text' of all paragraphs sharing a label is joined with '; '. For every
    label, the label is written to the page and ``run_query`` is called with
    the joined text as context.
    """
    df = st.session_state['key2']

    ### RAG Output by group ##
    # One row per (paragraph, label) pair.
    df_expand = df.explode('Vulnerability Label')
    # Concatenate all paragraph texts that share a label.
    df_agg = (df_expand.groupby('Vulnerability Label')['text']
              .agg('; '.join)
              .reset_index())

    st.markdown("----")
    st.markdown('**DOCUMENT FINDINGS SUMMARY BY VULNERABILITY LABEL:**')

    # Build a RAG query per label, send it, and let run_query render the response.
    for label, context in zip(df_agg['Vulnerability Label'], df_agg['text']):
        st.write(label)
        run_query(context=context, label=label)