leavoigt commited on
Commit
47756f1
1 Parent(s): f5a6f87

Update appStore/target.py

Browse files
Files changed (1) hide show
  1. appStore/target.py +38 -24
appStore/target.py CHANGED
@@ -1,27 +1,41 @@
1
- # # set path
2
- # import glob, os, sys;
3
- # sys.path.append('../utils')
4
-
5
- # #import needed libraries
6
- # import seaborn as sns
7
- # import matplotlib.pyplot as plt
8
- # import numpy as np
9
- # import pandas as pd
10
- # import streamlit as st
11
- # from st_aggrid import AgGrid
12
- # from utils.target_classifier import load_targetClassifier, target_classification
13
- # import logging
14
- # logger = logging.getLogger(__name__)
15
- # from utils.config import get_classifier_params
16
- # from io import BytesIO
17
- # import xlsxwriter
18
- # import plotly.express as px
19
- # from pandas.api.types import (
20
- # is_categorical_dtype,
21
- # is_datetime64_any_dtype,
22
- # is_numeric_dtype,
23
- # is_object_dtype,
24
- # is_list_like)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  # # Declare all the necessary variables
27
  # classifier_identifier = 'target'
 
1
+ # set path
2
+ import glob, os, sys;
3
+ sys.path.append('../utils')
4
+
5
+ #import needed libraries
6
+ import seaborn as sns
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import pandas as pd
10
+ import streamlit as st
11
+ from utils.vulnerability_classifier import load_targetClassifier, target_classification
12
+ import logging
13
+ logger = logging.getLogger(__name__)
14
+ from utils.config import get_classifier_params
15
+ from utils.preprocessing import paraLengthCheck
16
+ from io import BytesIO
17
+ import xlsxwriter
18
+ import plotly.express as px
19
+ from utils.vulnerability_classifier import label_dict
20
+
21
+
22
+ def app():
23
+ ### Main app code ###
24
+ with st.container():
25
+
26
+ if 'key1' in st.session_state:
27
+
28
+ # Load the existing dataset
29
+ df = st.session_state.key1
30
+
31
+ # Load the classifier model
32
+ classifier = load_targetClassifier(classifier_name=params['model_name'])
33
+ st.session_state['{}_classifier'.format(classifier_identifier)] = classifier
34
+
35
+
36
+ df = target_classification(haystack_doc=df,
37
+ threshold= params['threshold'])
38
+ st.session_state.key1 = df
39
 
40
  # # Declare all the necessary variables
41
  # classifier_identifier = 'target'