Paula Leonova committed on
Commit
d4be6e6
1 Parent(s): 2b16dfe

Add evaluation metrics

Browse files
Files changed (2) hide show
  1. app.py +20 -0
  2. requirements.txt +1 -0
app.py CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
5
  import base64
6
  from typing import Sequence
7
  import streamlit as st
 
 
8
 
9
  from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
10
  from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
@@ -102,7 +104,16 @@ if submit_button:
102
  plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
103
 
104
  data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
 
105
  data2 = pd.merge(data, data_ex_text, on = ['label'])
 
 
 
 
 
 
 
 
106
  st.markdown("### Data Table")
107
 
108
  with st.spinner('Generating a table of results and a download link...'):
@@ -112,5 +123,14 @@ if submit_button:
112
  unsafe_allow_html = True
113
  )
114
  st.dataframe(data2)
 
 
 
 
 
 
 
 
 
115
  st.success('All done!')
116
  st.balloons()
 
5
  import base64
6
  from typing import Sequence
7
  import streamlit as st
8
+ from sklearn.metrics import classification_report
9
+
10
 
11
  from models import create_nest_sentences, load_summary_model, summarizer_gen, load_model, classifier_zero
12
  from utils import plot_result, plot_dual_bar_chart, examples_load, example_long_text_load
 
104
  plot_dual_bar_chart(topics, scores, topics_ex_text, scores_ex_text)
105
 
106
  data_ex_text = pd.DataFrame({'label': topics_ex_text, 'scores_from_full_text': scores_ex_text})
107
+
108
  data2 = pd.merge(data, data_ex_text, on = ['label'])
109
+
110
+ if len(glabels) > 0:
111
+ gdata = pd.DataFrame({'label': glabels})
112
+ gdata['is_true_label'] = 1
113
+
114
+ data2 = pd.merge(data2, gdata, how = 'left', on = ['label'])
115
+ data2['is_true_label'].fillna(0, inplace = True)
116
+
117
  st.markdown("### Data Table")
118
 
119
  with st.spinner('Generating a table of results and a download link...'):
 
123
  unsafe_allow_html = True
124
  )
125
  st.dataframe(data2)
126
+
127
+ if len(glabels) > 0:
128
+ with st.spinner('Evaluating output against ground truth...'):
129
+ report = classification_report(y_true = data2[['is_true_label']],
130
+ y_pred = (data2[['scores_from_full_text']] >= threshold_value) * 1.0,
131
+ output_dict=True)
132
+ df_report = pd.DataFrame(report).transpose()
133
+ st.dataframe(df_report)
134
+
135
  st.success('All done!')
136
  st.balloons()
requirements.txt CHANGED
@@ -3,5 +3,6 @@ pandas
3
  streamlit
4
  plotly
5
  torch
 
6
  spacy>=2.2.0,<3.0.0
7
  https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm
 
3
  streamlit
4
  plotly
5
  torch
6
+ sklearn
7
  spacy>=2.2.0,<3.0.0
8
  https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.0/en_core_web_sm-2.2.0.tar.gz#egg=en_core_web_sm