CullerWhale commited on
Commit
28ebd2e
·
verified ·
1 Parent(s): 525ba71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -7
app.py CHANGED
@@ -42,16 +42,35 @@ labels = learn.dls.vocab
42
  # pred,pred_idx,probs = learn.predict(img)
43
  # return {labels[i]: float(probs[i]) for i in range(len(labels))}
44
 
45
- def predict(img):
 
46
  img = PILImage.create(img)
47
  pred, pred_idx, probs = learn.predict(img)
48
- results = {labels[i]: float(probs[i]) for i in range(len(labels))}
49
- # Adjust results to highlight when 'Survived' meets the 75% threshold
50
- if results['Survived'] >= 0.85:
51
- results['Survived'] = 1.0 # Indicating high confidence of survival
 
 
 
 
 
52
  else:
53
- results['Survived'] = 0.0 # Indicating it did not meet the threshold
54
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  # def predict(img):
 
42
  # pred,pred_idx,probs = learn.predict(img)
43
  # return {labels[i]: float(probs[i]) for i in range(len(labels))}
44
 
45
+
46
+ def predict(img, percent = 0.9):
47
  img = PILImage.create(img)
48
  pred, pred_idx, probs = learn.predict(img)
49
+ # Get the index of 'Not Survived' and 'Survived' labels
50
+ idx_not_survived = labels.index('Not Survived')
51
+ idx_survived = labels.index('Survived')
52
+ prob_not_survived = probs[idx_not_survived]
53
+ prob_survived = probs[idx_survived]
54
+ # Calculate threshold based on desired percent
55
+ threshold = 1 - percent
56
+ if prob_not_survived > threshold:
57
+ return {'Not Survived': float(prob_not_survived)}
58
  else:
59
+ return {'Survived': float(prob_survived)}
60
+
61
+
62
+
63
+
64
+ # def predict(img):
65
+ # img = PILImage.create(img)
66
+ # pred, pred_idx, probs = learn.predict(img)
67
+ # results = {labels[i]: float(probs[i]) for i in range(len(labels))}
68
+ # # Adjust results to highlight when 'Survived' meets the 75% threshold
69
+ # if results['Survived'] >= 0.85:
70
+ # results['Survived'] = 1.0 # Indicating high confidence of survival
71
+ # else:
72
+ # results['Survived'] = 0.0 # Indicating it did not meet the threshold
73
+ # return results
74
 
75
 
76
  # def predict(img):