irena commited on
Commit
377072d
1 Parent(s): 8ed51ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -26
app.py CHANGED
@@ -11,45 +11,65 @@ fs = project.get_feature_store()
11
 
12
 
13
  mr = project.get_model_registry()
14
- model = mr.get_model("titanic_modal", version=3)
15
  model_dir = model.download()
16
- model = joblib.load(model_dir + "/titanic_model.pkl")
17
 
18
- def titanic(pclass, sex, sibsp, parch):
19
- input_list = []
20
-
21
- input_list.append(int(sex)) # value returned by dropdown is index of option selected
22
- input_list.append(int(pclass+1)) # index starts at 0 so increment by 1
23
- input_list.append(int(sibsp))
24
- input_list.append(int(parch))
25
-
26
- print(input_list)
27
- # 'res' is a list of predictions returned as the label.
28
- #res = model.predict(np.asarray(input_list).reshape(1, -1), ntree_limit=model.best_ntree_limit) # for xgboost
29
- print(np.asarray(input_list).reshape(1, -1))
30
- res = model.predict(np.asarray(input_list).reshape(1, -1))
31
- # We add '[0]' to the result of the transformed 'res', because 'res' is a list, and we only want
32
- # the first element.
33
- print(res[0]) # 0/1
34
- # below is just for testing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  if res[0] == 0: #ded
36
- passenger_url = "https://media.istockphoto.com/id/157612035/sv/foto/shipwreck.jpg?s=612x612&w=0&k=20&c=BSVml8_SqgvSmEijAprhniyp_Wa_l5qIIVIxhmmBgBQ="
 
 
37
  else:
38
- passenger_url = "https://i.chzbgr.com/full/5420028160/hD88BD9FE/like-a-boss"
 
39
  img = Image.open(requests.get(passenger_url, stream=True).raw)
40
  return img
41
-
42
  demo = gr.Interface(
43
  fn=titanic,
44
  title="Titanic Passenger Survival Predictive Analytics",
45
  description="If one person is on titanic, predict whether he or she will survive.",
46
  allow_flagging="never",
47
- inputs=[
48
  gr.inputs.Dropdown(choices=["Male", "Female"], type="index", label="sex"),
49
- gr.inputs.Dropdown(choices=["Class 1","Class 2","Class 3"], type="index", label="pclass"),
50
- gr.inputs.Number(default=1.0, label="SibSp"),
51
- gr.inputs.Number(default=1.0, label="parch"),
 
 
52
  ],
 
53
  outputs=gr.Image(type="pil"))
54
 
55
  demo.launch()
 
11
 
12
 
13
  mr = project.get_model_registry()
14
+ model = mr.get_model("titanic_survival_modal", version=1)
15
  model_dir = model.download()
16
+ model = joblib.load(model_dir + "/titanic_survival_model.pkl")
17
 
18
+ def titanic(pclass, sex, age, fare, embarked, title, isalone):
19
+ input_list = []
20
+ input_list.append(int(pclass+1))
21
+ input_list.append(int(sex))
22
+ if age<=16:
23
+ input_list.append(0)
24
+ elif age>16 and age<=32:
25
+ input_list.append(1)
26
+ elif age>32 and age<=48:
27
+ input_list.append(2)
28
+ elif age>48 and age<=64:
29
+ input_list.append(3)
30
+ else:
31
+ input_list.append(4)
32
+ if fare<=7.91:
33
+ input_list.append(0)
34
+ elif fare>7.91 and fare<=14.454:
35
+ input_list.append(1)
36
+ elif fare>14.454 and fare<=31:
37
+ input_list.append(2)
38
+ else:
39
+ input_list.append(3)
40
+ if embarked=='C':
41
+ input_list.append(1)
42
+ elif embarked=='S':
43
+ input_list.append(0)
44
+ else:
45
+ input_list.append(2)
46
+ input_list.append(title)
47
+ input_list.append(isalone)
48
+ res = model.predict(np.asarray(input_list,dtype=object).reshape(1,-1))
49
  if res[0] == 0: #ded
50
+ person="dead"
51
+ passenger_url = "https://raw.githubusercontent.com/irena123333/titanic-prediction/main/dead.png"
52
+
53
  else:
54
+ person="survived"
55
+ passenger_url = "https://raw.githubusercontent.com/irena123333/titanic-prediction/main/survived.png"
56
  img = Image.open(requests.get(passenger_url, stream=True).raw)
57
  return img
58
+
59
  demo = gr.Interface(
60
  fn=titanic,
61
  title="Titanic Passenger Survival Predictive Analytics",
62
  description="If one person is on titanic, predict whether he or she will survive.",
63
  allow_flagging="never",
64
+ inputs=[gr.inputs.Dropdown(choices=["Class 1","Class 2","Class 3"], type="index", label="pclass"),
65
  gr.inputs.Dropdown(choices=["Male", "Female"], type="index", label="sex"),
66
+ gr.inputs.Slider(0,150,label='Age'),
67
+ gr.inputs.Number(default=8.0, label="Fare"),
68
+ gr.inputs.Radio(default='S', label="Embarkation Port", choices=['C', 'Q', 'S']),
69
+ gr.inputs.Dropdown(choices=["Master","Miss","Mr","Mrs","Other"], type="index", label="Title"),
70
+ gr.inputs.Dropdown(choices=["False", "True"], type="index", label="IsAlone"),
71
  ],
72
+
73
  outputs=gr.Image(type="pil"))
74
 
75
  demo.launch()