File size: 3,177 Bytes
47756f1
 
 
 
 
 
 
 
 
 
83a24ec
47756f1
 
 
 
 
 
 
83a24ec
47756f1
dbd62d7
1968c31
dbd62d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47756f1
 
1968c31
47756f1
 
 
 
84050ab
47756f1
 
 
aeb4fb8
2f12850
aeb4fb8
47756f1
 
eb83e3d
47756f1
0a2b1df
47756f1
 
84050ab
f1aec70
3f5271b
f1aec70
 
329d6cf
8f420e0
55c1b89
 
 
 
329d6cf
55c1b89
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# set path
import glob, os, sys; 
sys.path.append('../utils')

#import needed libraries
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from utils.target_classifier import load_targetClassifier, target_classification
import logging
logger = logging.getLogger(__name__)
from utils.config import get_classifier_params
from utils.preprocessing import paraLengthCheck
from io import BytesIO
import xlsxwriter
import plotly.express as px
from utils.target_classifier import label_dict

# Name of the classifier handled by this page. It is passed to
# get_classifier_params below and also forms the prefix of the session-state
# key ("<identifier>_classifier") under which app() stores the loaded model.
classifier_identifier = "target"
params = get_classifier_params(classifier_identifier)

@st.cache_data
def to_excel(df,sectorlist):
    """Serialize *df* to an in-memory ``.xlsx`` workbook with dropdown validations.

    The frame is written, without its index, to a sheet named ``Sheet1``. Column
    headers occupy row 1 and data occupies rows 2 .. len(df) + 1. List
    (dropdown) validations are then attached to fixed column letters over every
    data row:

    * column S: ``'No'``, ``'Yes'``, ``'Discard'``
    * columns T, U, V, W and X: the entries of *sectorlist* followed by ``'Blank'``

    The column letters are hard-coded. NOTE(review): they only line up with the
    intended fields if *df* has the expected column order, so confirm against
    the caller.

    Args:
        df: The DataFrame to export.
        sectorlist: Iterable of option strings offered in columns T–X.

    Returns:
        bytes: The contents of the generated workbook.
    """
    n_rows = len(df)
    sector_options = list(sectorlist) + ['Blank']
    output = BytesIO()
    # Leaving the context manager closes (and thereby saves) the writer.
    # ``ExcelWriter.save()`` was deprecated in pandas 1.5 and removed in 2.0.
    with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
        df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']
        # With no data rows, any "row 2 .. row N" range would be inverted and
        # end up covering the header row, so skip validation entirely.
        if n_rows > 0:
            # Row 1 is the header, so the last data row is n_rows + 1.
            last_row = n_rows + 1
            worksheet.data_validation('S2:S{}'.format(last_row),
                                      {'validate': 'list',
                                       'source': ['No', 'Yes', 'Discard']})
            for col in ('T', 'U', 'V', 'W', 'X'):
                worksheet.data_validation('{0}2:{0}{1}'.format(col, last_row),
                                          {'validate': 'list',
                                           'source': list(sector_options)})
    return output.getvalue()

def app():
    """Classify group-relevant paragraphs with the target classifier.

    If ``st.session_state`` holds ``key1``, the function works through these
    steps inside a Streamlit container:

    1. Take that DataFrame.
    2. Keep only the rows whose ``'Vulnerability Label'`` is non-empty and does
       not contain ``'Other'``.
    3. Load the target classifier named by ``params['model_name']`` and cache
       it in session state under ``'<classifier_identifier>_classifier'``.
    4. Run ``target_classification`` with ``params['threshold']``.
    5. Rename the ``'Target Label'`` column in place.
    6. Store the result in ``st.session_state.key2``.

    When ``key1`` is absent, nothing happens.
    """
    with st.container():
        # Guard: an upstream step must have populated key1 first.
        if 'key1' not in st.session_state:
            return

        paragraphs = st.session_state.key1

        def _references_specific_group(labels):
            # Keep rows that carry at least one label and none of them is 'Other'.
            return len(labels) > 0 and 'Other' not in labels

        keep_mask = paragraphs['Vulnerability Label'].apply(_references_specific_group)
        paragraphs = paragraphs[keep_mask]

        model = load_targetClassifier(classifier_name=params['model_name'])
        st.session_state[f'{classifier_identifier}_classifier'] = model

        classified = target_classification(haystack_doc=paragraphs,
                                           threshold=params['threshold'])

        # In-place rename, matching how the returned frame was always handled.
        classified.rename(
            columns={'Target Label': 'Specific action/target/measure mentioned'},
            inplace=True,
        )

        st.session_state.key2 = classified


def target_display():
    """Render the DataFrame stored under ``st.session_state['key2']``.

    Fails if ``key2`` has not been set yet (e.g. before ``app()`` has stored a
    classification result).
    """
    st.write(st.session_state['key2'])