# Lab4-Demo / app.py
# slyskawa's picture
# Update app.py
# 6b803dc verified
# raw
# history blame contribute delete
# No virus
# 3.02 kB
import gradio as gr
from joblib import load
import numpy as np
from PIL import Image
# Static assets, loaded once at startup and reused by predict() on every call.
# Fitted regression model deserialized with joblib.
ml_model = load("working_model.joblib")
# Base map of California, plus the marker image that gets pasted onto it.
cali_image, otterDot_image = Image.open("cali_map.png"), Image.open("otterDot.png")
# Input sliders, one per model feature: (minimum, maximum, step, label).
# Which predict() parameter each slider feeds is fixed by the order of the
# inputs=[...] list handed to gr.Interface, not by the order defined here.
(
    input_module1,
    input_module2,
    input_module3,
    input_module4,
    input_module5,
    input_module6,
    input_module7,
    input_module8,
) = (
    gr.Slider(-124.35, -114.35, step=0.25, label="Longitude"),
    gr.Slider(32, 41, step=0.25, label="Latitude"),
    gr.Slider(1, 52, step=1, label="Housing Median Age"),
    gr.Slider(1, 39996, step=12, label="Total Rooms"),
    gr.Slider(1, 6441, step=1, label="Total Bedrooms"),
    gr.Slider(3, 35678, step=50, label="Population"),
    gr.Slider(1, 6081, step=10, label="Households"),
    gr.Slider(0, 15, step=0.1, label="Median Income"),
)

# Output components: the formatted price text and the annotated map image.
output_module1, output_module2 = (
    gr.Textbox(label="Predicted Housing Price"),
    gr.Image(label="California Housing Map"),
)
# Map calibration for cali_map.png:
#   - the longitude/latitude window the map is treated as covering;
#   - the pixel box (left/right, top/bottom) that window occupies in the image.
# These values never change, so they are computed once at import rather than
# on every prediction.
LON_MIN, LON_MAX = -124, -114
LAT_MIN, LAT_MAX = 32, 42
X_MIN, X_MAX = 125, 744
Y_MIN, Y_MAX = 86, 625
SCALE_X = (X_MAX - X_MIN) / (LON_MAX - LON_MIN)  # pixels per degree of longitude
SCALE_Y = (Y_MAX - Y_MIN) / (LAT_MAX - LAT_MIN)  # pixels per degree of latitude


def longlat_to_img(longitude, latitude):
    """Map a longitude/latitude pair to integer (x, y) pixel coordinates on the map.

    The mapping is linear between the calibration corners defined above.
    Image y grows downward, so latitude is measured down from LAT_MAX.
    Both coordinates are truncated with int(). Points outside the calibration
    window are extrapolated linearly and land outside the pixel box.
    """
    x = (longitude - LON_MIN) * SCALE_X + X_MIN
    y = (LAT_MAX - latitude) * SCALE_Y + Y_MIN
    return int(x), int(y)


def predict(longitude, latitude, housing_median_age, total_rooms,
            total_bedrooms, population, households, median_income):
    """Predict a district's housing price and mark its location on the map.

    The eight parameters are the slider values, in the order of the
    Interface's ``inputs`` list. They are fed to the model in this same order.

    Returns:
        A tuple ``(price_text, map_image)``:
        - ``price_text`` is the model's prediction formatted as a dollar string
          with thousands separators and two decimals.
        - ``map_image`` is a copy of the California map with the marker image
          centered on (longitude, latitude).
    """
    # A single sample, built directly as a (1, 8) row of features.
    features = np.array([[longitude, latitude, housing_median_age, total_rooms,
                          total_bedrooms, population, households, median_income]])
    prediction = ml_model.predict(features)
    price_text = f"${prediction[0]:,.2f}"

    # Draw on a copy so the shared base map image is never modified.
    calimap = cali_image.copy()
    x, y = longlat_to_img(longitude, latitude)
    # Shift by half the marker size so the marker is centered on (x, y).
    # The marker image also serves as the paste mask. This presumably relies on
    # otterDot.png having an alpha channel; confirm if the marker renders as a
    # solid box.
    calimap.paste(
        otterDot_image,
        (x - otterDot_image.width // 2, y - otterDot_image.height // 2),
        otterDot_image,
    )
    return price_text, calimap
# Sample districts offered as one-click examples in the UI.
# Each raw row contains, in order:
#   - the 8 model inputs, in slider order (longitude, latitude, housing median
#     age, total rooms, total bedrooms, population, households, median income);
#   - a 9th value that appears to be the district's observed median house
#     value (TODO: confirm against the training data).
# The interface has only 8 inputs, so the 9th column is split off here rather
# than handing Gradio rows that are longer than its input list.
_N_INPUTS = 8
_example_rows = [
    [-122.24, 37.85, 52.0, 1467.0, 190.0, 496.0, 177.0, 7.2574, 352100.0],
    [-122.25, 37.84, 52.0, 2535.0, 489.0, 1094.0, 514.0, 3.6591, 299200.0],
    [-118.13, 33.87, 45.0, 1606.0, 300.0, 735.0, 295.0, 4.6765, 198400.0],
    [-118.14, 33.87, 44.0, 1661.0, 315.0, 985.0, 319.0, 4.3942, 219500.0],
    [-116.19, 33.67, 16.0, 1859.0, 476.0, 1994.0, 477.0, 1.7297, 67500.0],
    [-114.56, 33.69, 17.0, 720.0, 174.0, 333.0, 117.0, 1.6509, 85700.0],
    [-114.57, 33.64, 14.0, 1501.0, 337.0, 515.0, 226.0, 3.1917, 73400.0],
    [-120.92, 37.65, 23.0, 505.0, 124.0, 163.0, 129.0, 1.3696, 275000.0],
    [-116.94, 32.78, 17.0, 13559.0, 2656.0, 6990.0, 2533.0, 3.434, 193200.0],
    [-120.98, 37.62, 26.0, 3819.0, 955.0, 3010.0, 932.0, 1.9206, 81300.0],
]
# Inputs only: one value per interface input component.
examples = [row[:_N_INPUTS] for row in _example_rows]
# The trailing reference value of each row, kept for comparison with predictions.
example_targets = [row[_N_INPUTS] for row in _example_rows]
# Wire the components to predict():
#   - slider values are passed positionally, in the order of `inputs`;
#   - the returned (text, image) tuple fills `outputs` in order.
demo = gr.Interface(
    fn=predict,
    inputs=[
        input_module1, input_module2, input_module3, input_module4,
        input_module5, input_module6, input_module7, input_module8,
    ],
    outputs=[output_module1, output_module2],
    examples=examples,
)
# Start the web server with debug=True.
demo.launch(debug=True)