"""
Gradio web application
"""

import os
import json
import requests
import gradio as gr
from dotenv import load_dotenv, find_dotenv

from classification.classifier import Classifier

AWS_API = None


# Initialize API URLs from env file or global settings
def retrieve_api():
    """Initialize API URLs from env file or global settings"""

    env_path = find_dotenv("config_api.env")
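    # A typical config_api.env entry (endpoint value is illustrative, not a real deployment):
    #   AWS_API=https://<api-id>.execute-api.<region>.amazonaws.com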
    if env_path:
        load_dotenv(dotenv_path=env_path)
        print("config_api.env file loaded successfully.")
    else:
        print("config_api.env file not found.")

    # Use the AWS endpoint if set, otherwise fall back to a local container
    global AWS_API
    AWS_API = os.getenv("AWS_API", default="http://localhost:8000")


def initialize_classifier():
    """Initialize ML classifier"""

    cls = Classifier()
    return cls


def predict_class_local(sepl, sepw, petl, petw):
    """ML prediction using direct source code - local"""

    data = list(map(float, [sepl, sepw, petl, petw]))
    cls = initialize_classifier()
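    # load_and_test is expected to return a dict with "predictions" and "probabilities"
    # keys, the same shape consumed by predict() below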
    results = cls.load_and_test(data)
    return results


def predict_class_aws(sepl, sepw, petl, petw):
    """ML prediction using AWS API endpoint"""

    # Local Lambda container (Runtime Interface Emulator) vs. deployed API Gateway route
    if AWS_API == "http://localhost:8080":
        api_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
    else:
        api_endpoint = AWS_API + "/test/classify"

    data = list(map(float, [sepl, sepw, petl, petw]))
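    # The endpoint expects a nested list of feature rows,
    # e.g. {"features": [[5.1, 3.5, 1.4, 0.2]]} (values illustrative); one row per request here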
    json_object = {"features": [data]}

    response = requests.post(api_endpoint, json=json_object, timeout=60)
    if response.status_code != 200:
        # gr.Error must be raised (not just instantiated) to be displayed in the UI
        raise gr.Error(f"API Error: {response.status_code}")

    # The Lambda response wraps the prediction payload in a JSON-encoded "body" field
    response_json = response.json()
    results_dict = json.loads(response_json["body"])
    return results_dict


def predict(sepl, sepw, petl, petw, execution_type):
    """ML prediction - local or via API endpoint"""

    print("ML prediction type: ", execution_type)
    results = None
    if execution_type == "Local":
        results = predict_class_local(sepl, sepw, petl, petw)
    elif execution_type == "AWS API":
        results = predict_class_aws(sepl, sepw, petl, petw)

    # A single sample is submitted per call, hence index 0
    prediction = results["predictions"][0]
    confidence = max(results["probabilities"][0])

    return f"Prediction: {prediction} \t - \t Confidence: {confidence:.3f}"


# Define the Gradio interface
def user_interface():
    """Gradio application"""

    description = """
    Aims: Categorization of different species of iris flowers (Setosa, Versicolor, and Virginica) 
    based on measurements of physical characteristics (sepals and petals).

    Notes: This web application uses two types of machine learning predictions:
       - local prediction (direct source code) 
       - cloud prediction via an AWS API (i.e. use of ECR, Lambda function and API Gateway)
    """

    with gr.Blocks() as demo:
        gr.Markdown("# IRIS classification task - use of AWS Lambda")
        gr.Markdown(description)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    gr_sepl = gr.Slider(
                        minimum=4.0, maximum=8.0, step=0.1, label="Sepal Length (in cm)"
                    )
                    gr_sepw = gr.Slider(
                        minimum=2.0, maximum=5.0, step=0.1, label="Sepal Width (in cm)"
                    )
                    gr_petl = gr.Slider(
                        minimum=1.0, maximum=7.0, step=0.1, label="Petal Length (in cm)"
                    )
                    gr_petw = gr.Slider(
                        minimum=0.1, maximum=2.8, step=0.1, label="Petal Width (in cm)"
                    )
            with gr.Column():
                with gr.Row():
                    gr_execution_type = gr.Radio(
                        ["Local", "AWS API"], value="Local", label="Prediction type"
                    )
                with gr.Row():
                    gr_output = gr.Textbox(label="Prediction output")

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_button = gr.ClearButton()

        submit_btn.click(
            fn=predict,
            inputs=[gr_sepl, gr_sepw, gr_petl, gr_petw, gr_execution_type],
            outputs=[gr_output],
        )
        clear_button.click(lambda: None, inputs=None, outputs=[gr_output], queue=False)
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    retrieve_api()
    user_interface()