Basanth commited on
Commit
cdd41d2
1 Parent(s): 91ec262

prediction modfied

Browse files
Files changed (1) hide show
  1. utils.py +5 -3
utils.py CHANGED
@@ -63,9 +63,11 @@ def clean_text(text):
63
 
64
 
65
  def predict_cat(model, text):
66
- p = int(model.predict(text,return_proba=True).max()*100)
67
- cat = model.predict(text)
68
- return p,cat
 
 
69
 
70
 
71
  def grouper(iterable):
 
63
 
64
 
65
  def predict_cat(model, text):
66
+
67
+ logits = model.predict(text,return_proba=True)
68
+ prob = int(logits.max()*100)
69
+ cat= label_df.iloc[logits.argmax()].values[0]
70
+ return prob,cat
71
 
72
 
73
  def grouper(iterable):