IanRonk commited on
Commit
5ae3a16
1 Parent(s): 737faca

Add conditional

Browse files
Files changed (1) hide show
  1. functions/model_infer.py +7 -4
functions/model_infer.py CHANGED
@@ -39,8 +39,11 @@ def predict_from_document(sentences):
39
  preprop = preprocess(sentences)
40
  prediction = model.predict(preprop)
41
  # Set the prediction threshold to 0.8 instead of 0.5, now use mean
42
- output = (
43
- prediction.flatten()[: len(sentences)]
44
- >= np.mean(prediction) * 1.20 # + np.std(prediction)
45
- ).astype(int)
 
 
 
46
  return output, prediction.flatten()[: len(sentences)]
 
39
  preprop = preprocess(sentences)
40
  prediction = model.predict(preprop)
41
  # Set the prediction threshold to 0.8 instead of 0.5, now use mean
42
+ if np.mean(prediction) < 0.5:
43
+ output = (prediction.flatten()[: len(sentences)] >= 0.5).astype(int)
44
+ else:
45
+ output = (
46
+ prediction.flatten()[: len(sentences)]
47
+ >= np.mean(prediction) * 1.20 # + np.std(prediction)
48
+ ).astype(int)
49
  return output, prediction.flatten()[: len(sentences)]