Update utils/target_classifier.py
Browse files
utils/target_classifier.py
CHANGED
@@ -23,7 +23,8 @@ def get_target_labels(preds):
|
|
23 |
|
24 |
# Get label names
|
25 |
preds_list = preds.tolist()
|
26 |
-
|
|
|
27 |
predictions_names=[]
|
28 |
|
29 |
# loop through each prediction
|
@@ -112,13 +113,13 @@ def target_classification(haystack_doc:pd.DataFrame,
|
|
112 |
|
113 |
# Get predictions
|
114 |
predictions = classifier_model(list(haystack_doc.text))
|
115 |
-
st.write(
|
116 |
-
st.write(predictions)
|
117 |
|
118 |
# Get labels for predictions
|
119 |
pred_labels = get_target_labels(predictions)
|
120 |
-
st.write(pred_labels)
|
121 |
-
|
122 |
# Save labels
|
123 |
haystack_doc['Target Label'] = pred_labels
|
124 |
|
|
|
23 |
|
24 |
# Get label names
|
25 |
preds_list = preds.tolist()
|
26 |
+
st.write('preds_list')
|
27 |
+
st.write(preds_list)
|
28 |
predictions_names=[]
|
29 |
|
30 |
# loop through each prediction
|
|
|
113 |
|
114 |
# Get predictions
|
115 |
predictions = classifier_model(list(haystack_doc.text))
|
116 |
+
st.write('predictions')
|
117 |
+
st.write(predictions[:10])
|
118 |
|
119 |
# Get labels for predictions
|
120 |
pred_labels = get_target_labels(predictions)
|
121 |
+
st.write('pred_labels')
|
122 |
+
st.write(pred_labels[:10])
|
123 |
# Save labels
|
124 |
haystack_doc['Target Label'] = pred_labels
|
125 |
|