DavidMoreda commited on
Commit
c1e3ba9
·
1 Parent(s): 0d542f6

Subida app

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import the necessary libraries
2
+ import gradio as gr
3
+ import joblib
4
+ import numpy as np
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ # Download both models from HF Hub at startup
8
+ # Both models live in the same repository: brjapon/iris-dt
9
+ dt_path = hf_hub_download(repo_id="DavidMoreda/iris-dt", filename="iris_dt.joblib", repo_type="model")
10
+ lr_path = hf_hub_download(repo_id="DavidMoreda/iris-dt", filename="iris_logreg.joblib", repo_type="model")
11
+
12
+ models = {
13
+ "Decision Tree": joblib.load(dt_path),
14
+ "Logistic Regression": joblib.load(lr_path),
15
+ }
16
+
17
+ LABELS = {0: "Iris-setosa", 1: "Iris-versicolor", 2: "Iris-virginica"}
18
+
19
+ # The function now accepts a model_choice parameter (from a Gradio dropdown)
20
+ def predict_iris(model_choice, sepal_length, sepal_width, petal_length, petal_width):
21
+ pipeline = models[model_choice]
22
+ input = np.array([[sepal_length, sepal_width, petal_length, petal_width]])
23
+ prediction = pipeline.predict(input)
24
+ return LABELS.get(int(prediction[0]), "Invalid prediction")
25
+
26
+ interface = gr.Interface(
27
+ fn=predict_iris,
28
+ inputs=[
29
+ gr.Dropdown(choices=["Decision Tree", "Logistic Regression"], label="Model", value="Decision Tree"),
30
+ gr.Number(label="Sepal Length (cm)"),
31
+ gr.Number(label="Sepal Width (cm)"),
32
+ gr.Number(label="Petal Length (cm)"),
33
+ gr.Number(label="Petal Width (cm)"),
34
+ ],
35
+ outputs="text",
36
+ live=True,
37
+ title="Iris Species Identifier",
38
+ description="Choose a model and enter the four measurements to predict the Iris species.",
39
+ flagging_mode="manual",
40
+ flagging_dir="flagged"
41
+ )
42
+
43
+ if __name__ == "__main__":
44
+ interface.launch()
45
+
46
+ '''
47
+ # The Flag button allows users (or testers) to mark or “flag”
48
+ # a particular input-output interaction for later review.
49
+ # When someone clicks Flag, Gradio saves the input values (and often the output) to a log.csv file
50
+ # letting you keep track of interesting or potentially problematic cases for debugging or analysis later on
51
+ '''