njgroene commited on
Commit
fed756e
1 Parent(s): 1ea506d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -6
app.py CHANGED
@@ -15,10 +15,15 @@ from torch_mtcnn import show_bboxes
15
  def pipeline(img):
16
  bounding_boxes, landmarks = detect_faces(img)
17
  bb = [bounding_boxes[0,0], bounding_boxes[0,1], bounding_boxes[0,2], bounding_boxes[0,3]]
 
 
 
 
 
 
18
  img_cropped = img.crop(bb)
19
-
20
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21
-
22
  model_fair_7 = torchvision.models.resnet34(pretrained=True)
23
  model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
24
  model_fair_7.load_state_dict(torch.load('res34_fair_align_multi_7_20190809.pt', map_location=torch.device('cpu')))
@@ -80,12 +85,15 @@ def pipeline(img):
80
  result.loc[result['age_preds_fair'] == 6, 'age'] = '50-59'
81
  result.loc[result['age_preds_fair'] == 7, 'age'] = '60-69'
82
  result.loc[result['age_preds_fair'] == 8, 'age'] = '70+'
83
-
84
- return [result['gender'][0],result['age'][0]]
85
 
86
  def predict(image):
87
- predictions = pipeline(image)
88
- return "A " + predictions[0] + " in the age range of " + predictions[1]
 
 
 
89
 
90
  gr.Interface(
91
  predict,
 
15
  def pipeline(img):
16
  bounding_boxes, landmarks = detect_faces(img)
17
  bb = [bounding_boxes[0,0], bounding_boxes[0,1], bounding_boxes[0,2], bounding_boxes[0,3]]
18
+
19
+ if len(bb) == 0:
20
+ raise Exception("Didn't face any faces, try another image!")
21
+ if len(bb) > 1:
22
+ raise Exception("Found more than one face, try a profile picture with only one person in it!")
23
+
24
  img_cropped = img.crop(bb)
25
+
26
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
27
  model_fair_7 = torchvision.models.resnet34(pretrained=True)
28
  model_fair_7.fc = nn.Linear(model_fair_7.fc.in_features, 18)
29
  model_fair_7.load_state_dict(torch.load('res34_fair_align_multi_7_20190809.pt', map_location=torch.device('cpu')))
 
85
  result.loc[result['age_preds_fair'] == 6, 'age'] = '50-59'
86
  result.loc[result['age_preds_fair'] == 7, 'age'] = '60-69'
87
  result.loc[result['age_preds_fair'] == 8, 'age'] = '70+'
88
+
89
+ return "A " + result['gender'][0] + " in the age range of " + result['age'][0]
90
 
91
  def predict(image):
92
+ try :
93
+ predictions = pipeline(image)
94
+ except Exception as e:
95
+ predictions = e
96
+ return predictions
97
 
98
  gr.Interface(
99
  predict,