rscolati commited on
Commit
745f0f8
1 Parent(s): 27feda5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +67 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import requests
5
+ import hopsworks
6
+ import joblib
7
+
8
+ project = hopsworks.login()
9
+ fs = project.get_feature_store()
10
+
11
+ mr = project.get_model_registry()
12
+ model = mr.get_model("titanic_modal", version=1)
13
+ model_dir = model.download()
14
+ model = joblib.load(model_dir + "/titanic_model.pkl")
15
+
16
+
17
+ def titanic(age, sex, pclass):
18
+ input_list = []
19
+
20
+ input_list.append(int(age))
21
+ input_list.append(int(sex)) # value returned by dropdown is index of option selected
22
+ input_list.append(int(pclass+1)) # index starts at 0 so increment by 1
23
+
24
+ # Bin input age to bin index of range
25
+ if 0 < input_list[0] <= 20:
26
+ input_list[0] = 0
27
+ elif 20 < input_list[0] <= 50:
28
+ input_list[0] = 1
29
+ elif 50 < input_list[0] <= 75:
30
+ input_list[0] = 2
31
+ elif input_list[0] > 75:
32
+ input_list[0] = 3
33
+ else:
34
+ # we should just assume < 0 = 0..
35
+ print("Incorrect age value set. Try again.")
36
+
37
+ #print(input_list)
38
+ # 'res' is a list of predictions returned as the label.
39
+ #res = model.predict(np.asarray(input_list).reshape(1, -1), ntree_limit=model.best_ntree_limit) # for xgboost
40
+ #print(np.asarray(input_list).reshape(1, -1))
41
+ res = model.predict(np.asarray(input_list).reshape(1, -1))
42
+ # We add '[0]' to the result of the transformed 'res', because 'res' is a list, and we only want
43
+ # the first element.
44
+ #print(res[0]) # 0/1
45
+ # below is just for testing
46
+ if res[0] == 0: #ded
47
+ passenger_url = "https://media.istockphoto.com/id/157612035/sv/foto/shipwreck.jpg?s=612x612&w=0&k=20&c=BSVml8_SqgvSmEijAprhniyp_Wa_l5qIIVIxhmmBgBQ="
48
+ else:
49
+ passenger_url = "https://i.chzbgr.com/full/5420028160/hD88BD9FE/like-a-boss"
50
+ img = Image.open(requests.get(passenger_url, stream=True).raw)
51
+ return img
52
+
53
+
54
+ demo = gr.Interface(
55
+ fn=titanic,
56
+ title="Titanic Passenger Survival Predictive Analytics",
57
+ description="Experiment with some passenger features to predict whether your passenger would have survived or not.",
58
+ allow_flagging="never",
59
+ inputs=[
60
+ gr.inputs.Number(default=1, label="Age"),
61
+ gr.inputs.Dropdown(choices=["Male", "Female"], type="index", label="Sex"),
62
+ gr.inputs.Dropdown(choices=["Class 1", "Class 2", "Class 3"],
63
+ type="index", label="Pclass"),
64
+ ],
65
+ outputs=gr.Image(type="pil"))
66
+
67
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ hopsworks
2
+ joblib
3
+ scikit-learn