Joshua Lochner commited on
Commit
ad7fc61
1 Parent(s): a6de017

Use classifier category if transformer generates unknown category

Browse files
Files changed (1) hide show
  1. src/predict.py +6 -4
src/predict.py CHANGED
@@ -106,7 +106,7 @@ class ClassifierArguments:
106
  default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
107
 
108
 
109
- def add_predictions(predictions, classifier_args): # classifier, vectorizer,
110
  """Use classifier to filter predictions"""
111
  if not predictions:
112
  return predictions
@@ -134,8 +134,10 @@ def add_predictions(predictions, classifier_args): # classifier, vectorizer,
134
  if classifier_category is None and classifier_probability > classifier_args.min_probability:
135
  continue # Ignore
136
 
137
- if classifier_category is not None and classifier_probability > 0.5: # TODO make param
138
- # Confident enough to overrule, so we update category
 
 
139
  prediction['category'] = classifier_category
140
 
141
  prediction['probability'] = predicted_probabilities[prediction['category']]
@@ -173,7 +175,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
173
 
174
  # TODO add back
175
  if classifier_args is not None:
176
- predictions = add_predictions(predictions, classifier_args)
177
 
178
  return predictions
179
 
106
  default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
107
 
108
 
109
+ def filter_and_add_probabilities(predictions, classifier_args): # classifier, vectorizer,
110
  """Use classifier to filter predictions"""
111
  if not predictions:
112
  return predictions
134
  if classifier_category is None and classifier_probability > classifier_args.min_probability:
135
  continue # Ignore
136
 
137
+ if (prediction['category'] not in predicted_probabilities) \
138
+ or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
139
+ # Unknown category or we are confident enough to overrule,
140
+ # so change category to what was predicted by classifier
141
  prediction['category'] = classifier_category
142
 
143
  prediction['probability'] = predicted_probabilities[prediction['category']]
175
 
176
  # TODO add back
177
  if classifier_args is not None:
178
+ predictions = filter_and_add_probabilities(predictions, classifier_args)
179
 
180
  return predictions
181