|
import gradio as gr |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from sklearn.linear_model import LinearRegression |
|
import numpy as np |
|
from pandas.tseries.offsets import MonthEnd |
|
|
|
# Zillow ZHVI CSV: one row per region (matched on "RegionName") and one column
# per month, labelled YYYY-MM-DD.
ZHVI_CSV_URL = (
    "https://files.zillowstatic.com/research/public_csvs/zhvi/"
    "Zip_zhvi_uc_sfrcondo_tier_0.33_0.67_sm_sa_month.csv"
)

# Monthly value columns are selected by this label pattern.
_DATE_COLUMN_PATTERN = r"^\d{4}-\d{2}-\d{2}$"

# Downloaded tables keyed by URL. The CSV covers every ZIP code, so it is
# fetched once per process rather than on every request; restart the app to
# pick up a newer release of the file.
_ZHVI_CACHE = {}


def _load_zhvi(url=ZHVI_CSV_URL):
    """Return the ZHVI table at *url*, downloading it only on first use."""
    if url not in _ZHVI_CACHE:
        _ZHVI_CACHE[url] = pd.read_csv(url)
    return _ZHVI_CACHE[url]


def _price_history(raw, zip_code, start_date):
    """Extract the monthly price history for one ZIP code from *raw*.

    Returns a DataFrame with ``Date`` and ``Price`` columns in source column
    order. Months without a price are dropped. When *start_date* is
    non-blank, only months on or after it are kept.

    Raises:
        ValueError: if the ZIP code or start date cannot be parsed, the ZIP
            code is not present, or no prices remain after filtering.
    """
    try:
        code = int(zip_code)
    except (TypeError, ValueError):
        raise ValueError(f"Invalid ZIP code: {zip_code!r}") from None

    rows = raw[raw["RegionName"] == code]
    if rows.empty:
        raise ValueError(f"No Zillow data found for ZIP code {zip_code}")

    # Use the first matching row's monthly columns, chosen by label pattern so
    # the code does not depend on which month the table starts at.
    monthly = rows.filter(regex=_DATE_COLUMN_PATTERN).iloc[0]
    history = pd.DataFrame(
        {
            "Date": pd.to_datetime(monthly.index),
            "Price": pd.to_numeric(monthly.to_numpy(), errors="coerce"),
        }
    )
    # Missing months would make LinearRegression.fit raise on NaN input.
    history = history.dropna(subset=["Price"])

    # A blank start date means "use the full history" (pd.to_datetime("")
    # would give NaT and silently filter out every row).
    if start_date is not None and str(start_date).strip():
        try:
            start = pd.to_datetime(start_date)
        except (TypeError, ValueError) as err:
            raise ValueError(f"Invalid start date: {start_date!r}") from err
        history = history[history["Date"] >= start]

    if history.empty:
        raise ValueError(
            f"No price data for ZIP code {zip_code} in the selected period"
        )
    return history.reset_index(drop=True)


def _fit_and_forecast(history, months_ahead):
    """Fit a linear trend to *history* and extrapolate *months_ahead* months.

    Returns ``(future_dates, predicted_prices)``.
    """
    dates = history["Date"]
    # x = calendar months since the first observation, so a month dropped for
    # missing data remains a gap instead of shifting later points left.
    month_number = dates.dt.year * 12 + dates.dt.month
    x = (month_number - month_number.iloc[0]).to_numpy().reshape(-1, 1)

    model = LinearRegression()
    model.fit(x, history["Price"].to_numpy())

    future_x = (x[-1, 0] + np.arange(1, months_ahead + 1)).reshape(-1, 1)
    predicted = model.predict(future_x)
    # Assumes history dates fall on month ends; MonthEnd(i) then lands on
    # the i-th following month end, matching future_x.
    future_dates = [dates.iloc[-1] + MonthEnd(i) for i in range(1, months_ahead + 1)]
    return future_dates, predicted


def plot_and_predict(zip, start_date, prediction_months, data=None):
    """Plot a ZIP code's home-value history and a linear price forecast.

    Args:
        zip: ZIP code to look up, matched via ``int(zip)`` against the
            ``RegionName`` column. (The name shadows the builtin but is kept
            for callers that pass it by keyword.)
        start_date: Earliest month to use; anything ``pd.to_datetime``
            accepts. Blank or None uses the full history.
        prediction_months: Number of months to forecast; must be >= 1.
        data: Optional pre-loaded ZHVI DataFrame. When None, the table is
            downloaded from ``ZHVI_CSV_URL`` and cached for later calls.

    Returns:
        A plotly Figure with "Historical Prices" and "Predicted Prices" traces.

    Raises:
        ValueError: for an invalid or unknown ZIP code, an invalid start
            date, a period with no data, or a forecast horizon below 1.
    """
    # Slider values may arrive as floats; range()/MonthEnd need ints.
    months_ahead = int(prediction_months)
    if months_ahead < 1:
        raise ValueError("Prediction months must be at least 1")

    raw = _load_zhvi() if data is None else data
    history = _price_history(raw, zip, start_date)
    future_dates, predicted_prices = _fit_and_forecast(history, months_ahead)

    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=history["Date"],
            y=history["Price"],
            mode="lines",
            name="Historical Prices",
        )
    )
    fig.add_trace(
        go.Scatter(
            x=future_dates,
            y=predicted_prices,
            mode="lines",
            name="Predicted Prices",
        )
    )
    fig.update_layout(
        title=f"Real Estate Price Prediction for Zip Code {zip}",
        xaxis_title="Date",
        yaxis_title="Price",
        legend_title_text="Data",
    )
    return fig
|
|
|
|
|
# Web UI: ZIP code and start date as free text, forecast horizon as a slider;
# the handler's figure is rendered as a plot.
interface = gr.Interface(
    fn=plot_and_predict,
    inputs=[
        gr.Textbox(label="ZIP Code"),
        gr.Textbox(label="Start Date (YYYY-MM-DD)", placeholder="YYYY-MM-DD"),
        gr.Slider(minimum=1, maximum=60, step=1, label="Prediction Months"),
    ],
    outputs="plot",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # has no side effects. show_error=True displays the handler's exception
    # message in the UI instead of a generic "Error".
    interface.launch(show_error=True)