antonbol commited on
Commit
e8c8ffd
1 Parent(s): 946d734
Files changed (2) hide show
  1. app.py +26 -21
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,44 +2,49 @@ import gradio as gr
2
  import numpy as np
3
  from PIL import Image
4
  import requests
5
-
6
  import hopsworks
7
  import joblib
 
8
 
9
  project = hopsworks.login()
10
  fs = project.get_feature_store()
11
 
12
 
13
  mr = project.get_model_registry()
14
- model = mr.get_model("iris_modal", version=1)
15
  model_dir = model.download()
16
- model = joblib.load(model_dir + "/iris_model.pkl")
17
-
 
18
 
19
- def iris(sepal_length, sepal_width, petal_length, petal_width):
20
- input_list = []
21
- input_list.append(sepal_length)
22
- input_list.append(sepal_width)
23
- input_list.append(petal_length)
24
- input_list.append(petal_width)
25
  # 'res' is a list of predictions returned as the label.
26
- res = model.predict(np.asarray(input_list).reshape(1, -1))
27
  # We add '[0]' to the result of the transformed 'res', because 'res' is a list, and we only want
28
  # the first element.
29
- flower_url = "https://raw.githubusercontent.com/featurestoreorg/serverless-ml-course/main/src/01-module/assets/" + res[0] + ".png"
30
- img = Image.open(requests.get(flower_url, stream=True).raw)
31
  return img
32
 
33
- demo = gr.Interface(
34
- fn=iris,
35
- title="Iris Flower Predictive Analytics",
36
- description="Experiment with sepal/petal lengths/widths to predict which flower it is.",
37
  allow_flagging="never",
38
  inputs=[
39
- gr.inputs.Number(default=1.0, label="sepal length (cm)"),
40
- gr.inputs.Number(default=1.0, label="sepal width (cm)"),
41
- gr.inputs.Number(default=1.0, label="petal length (cm)"),
42
- gr.inputs.Number(default=1.0, label="petal width (cm)"),
 
 
 
 
 
 
43
  ],
44
  outputs=gr.Image(type="pil"))
45
 
 
2
  import numpy as np
3
  from PIL import Image
4
  import requests
5
+ from feature_engineering import feat_eng
6
  import hopsworks
7
  import joblib
8
+ import pandas as pd
9
 
10
  project = hopsworks.login()
11
  fs = project.get_feature_store()
12
 
13
 
14
  mr = project.get_model_registry()
15
+ model = mr.get_model("titanic_modal_simple_classifier", version=1)
16
  model_dir = model.download()
17
+ model = joblib.load(model_dir + "/titanic_model.pkl")
18
+ leo_url = "https://media.tenor.com/FghTtX3ZgbAAAAAC/drowning-leo.gif"
19
+ rose_url = "https://media4.giphy.com/media/6A5zBPtbknIGY/giphy.gif?cid=ecf05e477syp5zeoheii45de76uicvgu0nuegojslz3zgodt&rid=giphy.gif&ct=g"
20
 
21
+ def titanic(pclass, name, sex, age, sibsp, parch, ticket, fare, cabin, embarked):
22
+ df_pre = pd.DataFrame({"PassengerId":[-1], "Pclass": [pclass], "Name": [name], "Sex": [sex], "Age": [age], "SibSp": [sibsp], "Parch": [parch], "Ticket": [ticket], "Fare": [fare], "Cabin": [cabin], "Embarked": [embarked]})
23
+ df_post = feat_eng(df_pre)
 
 
 
24
  # 'res' is a list of predictions returned as the label.
25
+ res = model.predict(df_post)[0]
26
  # We add '[0]' to the result of the transformed 'res', because 'res' is a list, and we only want
27
  # the first element.
28
+
29
+ img = Image.open(leo_url) if res == 0 else Image.open(rose_url)
30
  return img
31
 
32
+ demo = gr.Interface(
33
+ fn=titanic,
34
+ title="Titanic Survival Predictive Analytics",
35
+ description="Experiment with Titanic Passenger data to predict survival",
36
  allow_flagging="never",
37
  inputs=[
38
+ gr.inputs.Number(default=1.0, label="pclass, [1,2,3]"),
39
+ gr.inputs.Textbox(default="Anton", label="name"),
40
+ gr.inputs.Textbox(default="male", label="sex, male or female"),
41
+ gr.inputs.Number(default=25, label="age"),
42
+ gr.inputs.Number(default=2, label="sibsb"),
43
+ gr.inputs.Number(default=2, label="parch"),
44
+ gr.inputs.Textbox(default="blabla", label="Ticket"),
45
+ gr.inputs.Number(default=200, label="Fare"),
46
+ gr.inputs.Textbox(default="blabla", label="Cabin"),
47
+ gr.inputs.Textbox(default="blabla", label="Embarked: [S, C, Q]")
48
  ],
49
  outputs=gr.Image(type="pil"))
50
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ seaborn
5
  pandas
6
  numpy
7
  dataframe-image
8
- modal-client
 
 
5
  pandas
6
  numpy
7
  dataframe-image
8
+ modal-client
9
+ gradio