Spaces:
Sleeping
Sleeping
File size: 4,494 Bytes
9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 1a9ccec 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 9ecca49 8631332 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
"""
Gradio web application
"""
import os
import json
import requests
import gradio as gr
from dotenv import load_dotenv, find_dotenv
from classification.classifier import Classifier
# Base URL of the prediction API; assigned by retrieve_api().
AWS_API = None


def retrieve_api():
    """Load ``config_api.env`` when it can be found and set the ``AWS_API`` global.

    The file is located with ``find_dotenv`` and, if found, loaded with
    ``load_dotenv``; a message reporting either outcome is printed.
    ``AWS_API`` is then read from the ``AWS_API`` environment variable and
    falls back to ``"http://localhost:8000"`` when the variable is not set.
    """
    global AWS_API

    dotenv_file = find_dotenv("config_api.env")
    if not dotenv_file:
        print("config_api.env file not found.")
    else:
        load_dotenv(dotenv_path=dotenv_file)
        print("config_api.env file loaded successfully.")

    # Use of AWS endpoint or local container by default
    AWS_API = os.environ.get("AWS_API", "http://localhost:8000")
def initialize_classifier():
    """Create and return a new ``Classifier`` instance."""
    return Classifier()
def predict_class_local(sepl, sepw, petl, petw):
    """Run the prediction in-process through the ``Classifier`` source code.

    The four measurements are converted to floats (in the order sepal
    length, sepal width, petal length, petal width) and handed as one list
    to ``load_and_test`` on a freshly created classifier.

    Returns:
        Whatever ``Classifier.load_and_test`` returns for the feature list.
    """
    features = [float(value) for value in (sepl, sepw, petl, petw)]
    return initialize_classifier().load_and_test(features)
def predict_class_aws(sepl, sepw, petl, petw):
    """Request a prediction from the HTTP API configured in ``AWS_API``.

    The four measurements are converted to floats and POSTed as the JSON
    document ``{"features": [[sepl, sepw, petl, petw]]}`` with a 60 second
    timeout.

    Returns:
        The object decoded from the JSON string stored under the ``"body"``
        key of the JSON response.

    Raises:
        gr.Error: If the request cannot be completed, or if the API answers
            with a status code other than 200. ``gr.Error`` has to be
            raised (not merely instantiated) for Gradio to display it.
    """
    # The local-container URL uses the Lambda invocation route; any other
    # base URL uses the "/test/classify" route.
    # NOTE(review): retrieve_api() defaults AWS_API to port 8000, so the
    # default value never matches this port-8080 check - confirm which port
    # is intended.
    if AWS_API == "http://localhost:8080":
        api_endpoint = AWS_API + "/2015-03-31/functions/function/invocations"
    else:
        api_endpoint = AWS_API + "/test/classify"

    features = [float(value) for value in (sepl, sepw, petl, petw)]
    payload = {"features": [features]}

    try:
        response = requests.post(api_endpoint, json=payload, timeout=60)
    except requests.exceptions.RequestException as err:
        # Connection failures and timeouts become a user-visible error
        # instead of an unhandled traceback.
        raise gr.Error(f"\t API request failed: {err}") from err

    if response.status_code != 200:
        raise gr.Error(f"\t API Error: {response.status_code}")

    # The response wraps the actual result as a JSON string under "body".
    return json.loads(response.json()["body"])
def predict(sepl, sepw, petl, petw, execution_type):
    """Run a prediction locally or through the API and format the result.

    Args:
        sepl: Sepal length.
        sepw: Sepal width.
        petl: Petal length.
        petw: Petal width.
        execution_type: ``"Local"`` to call ``predict_class_local`` or
            ``"AWS API"`` to call ``predict_class_aws``.

    Returns:
        A string holding the first predicted label and the highest value of
        the first probability row, formatted with three decimals.

    Raises:
        ValueError: If ``execution_type`` is not one of the two supported
            values.
        gr.Error: If the backend result does not contain usable
            ``"predictions"`` / ``"probabilities"`` entries (for example an
            error payload returned by the API path).
    """
    print("ML prediction type: ", execution_type)

    if execution_type == "Local":
        results = predict_class_local(sepl, sepw, petl, petw)
    elif execution_type == "AWS API":
        results = predict_class_aws(sepl, sepw, petl, petw)
    else:
        # Fail explicitly rather than subscripting None further down.
        raise ValueError(f"Unknown execution type: {execution_type!r}")

    try:
        prediction = results["predictions"][0]
        confidence = max(results["probabilities"][0])
    except (KeyError, IndexError) as err:
        # Surface a readable message in the UI instead of a bare KeyError.
        raise gr.Error(f"Prediction failed: {results}") from err

    return f"Prediction: {prediction} \t - \t Confidence: {confidence:.3f}"
# Build and launch the Gradio interface
def user_interface():
    """Assemble the Gradio Blocks UI and launch it.

    The page shows a title and a description, four sliders for the iris
    measurements, a radio selector for the prediction type, an output text
    box, and Submit / Clear buttons. Submit calls ``predict`` with the four
    slider values and the selected prediction type; Clear resets the output
    box. The app is launched with queuing enabled and ``debug=True``.
    """
    # NOTE(review): the layout nesting below was reconstructed from source
    # whose indentation had been flattened - confirm against the original.
    description = (
        "\n"
        "Aims: Categorization of different species of iris flowers (Setosa, Versicolor, and Virginica)\n"
        "based on measurements of physical characteristics (sepals and petals).\n"
        "Notes: This web application uses two types of machine learning predictions:\n"
        "- local prediction (direct source code)\n"
        "- cloud prediction via an AWS API (i.e. use of ECR, Lambda function and API Gateway)\n"
    )

    # (minimum, maximum, label) for each measurement slider, in the argument
    # order expected by predict().
    slider_specs = (
        (4.0, 8.0, "Sepal Length (in cm)"),
        (2.0, 5.0, "Sepal Width (in cm)"),
        (1.0, 7.0, "Petal Length (in cm)"),
        (0.1, 2.8, "Petal Width (in cm)"),
    )

    with gr.Blocks() as demo:
        gr.Markdown("# IRIS classification task - use of AWS Lambda")
        gr.Markdown(description)

        with gr.Row():
            # Left column: measurement inputs.
            with gr.Column():
                with gr.Group():
                    measurement_sliders = [
                        gr.Slider(minimum=low, maximum=high, step=0.1, label=label)
                        for low, high, label in slider_specs
                    ]

            # Right column: prediction type, result and action buttons.
            with gr.Column():
                with gr.Row():
                    execution_type_radio = gr.Radio(
                        ["Local", "AWS API"], value="Local", label="Prediction type"
                    )
                with gr.Row():
                    output_box = gr.Textbox(label="Prediction output")
                with gr.Row():
                    submit_btn = gr.Button("Submit")
                    clear_button = gr.ClearButton()

        submit_btn.click(
            fn=predict,
            inputs=[*measurement_sliders, execution_type_radio],
            outputs=[output_box],
        )
        # Returning None from the callback empties the output box.
        clear_button.click(lambda: None, inputs=None, outputs=[output_box], queue=False)

    demo.queue().launch(debug=True)
# Script entry point: resolve the API URL first, then build and launch the UI.
if __name__ == "__main__":
    retrieve_api()
    user_interface()
|