site / app.py
Haderstafed's picture
Update app.py
88ac3d2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.stattools import adfuller
import gradio as gr
import pickle
from datasets import load_dataset
dataset = load_dataset("Haderstafed/DataSetic")
model_path = hf_hub_download(repo_id="Haderstafed/wine", filename="model_sprice.pkl")
model_path = hf_hub_download(repo_id="Haderstafed/wine", filename="model_mprice.pkl")
data = pd.DataFrame(dataset['train'])
# Проверка стационарности ряда с помощью теста Дики-Фуллера
def test_stationarity(timeseries):
result = adfuller(timeseries)
print('ADF статистика:', result[0])
print('p-зачение (простое значение):', result[1])
if result[1] <= 0.05:
print("Ряд стационарен")
else:
print("Ряд не стационарен")
# Функция для загрузки моделей из файлов
def load_models():
with open('model_sprice.pkl', 'rb') as f:
model_fit_sprice = pickle.load(f)
with open('model_mprice.pkl', 'rb') as f:
model_fit_mprice = pickle.load(f)
return model_fit_sprice, model_fit_mprice
# Прогнозирование
def forecast_prices(data):
# Загрузка обученных моделей
model_fit_sprice, model_fit_mprice = load_models()
# Прогнозирование на 365 дней вперед
forecast_sprice = model_fit_sprice.forecast(steps=365)
forecast_mprice = model_fit_mprice.forecast(steps=365)
# Создание датафрейма
forecast_dates = pd.date_range(start=data.index[-1] + pd.Timedelta(days=1), periods=365)
forecast_df = pd.DataFrame({
'date': forecast_dates,
'Sprice_forecast': forecast_sprice,
'Mprice_forecast': forecast_mprice
})
forecast_df.set_index('date', inplace=True)
return forecast_df
# Функция для создания графиков
def plot_forecasts(forecast_df):
plt.figure(figsize=(14, 6))
plt.subplot(1, 2, 1)
plt.plot(data['Sprice'], label='Известные данные')
plt.plot(forecast_df.index, forecast_df['Sprice_forecast'], label='Прогноз', color='r')
plt.title('Прогноз начальной стоимости поездки')
plt.xlabel('Дата/Год')
plt.ylabel('Цена/руб')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(data['Mprice'], label='Известные данные')
plt.plot(forecast_df.index, forecast_df['Mprice_forecast'], label='Прогноз', color='r')
plt.title('Прогноз поминутной стоимости поездки')
plt.xlabel('Дата/Год')
plt.ylabel('Цена/руб')
plt.legend()
plt.tight_layout()
plt.savefig('forecast_plot.png')
plt.close()
# Расчёт стоимости поездки
def cost(date_str, distance):
date = pd.to_datetime(date_str, format='%d/%m/%Y')
forecast_df = forecast_prices(data)
if date in forecast_df.index:
sprice = forecast_df.loc[date, 'Sprice_forecast']
mprice = forecast_df.loc[date, 'Mprice_forecast']
travel_time_seconds = distance / 3 # берём среднюю скорость поездки в 3 м/с
total_cost = sprice + (mprice * (travel_time_seconds / 60))
return f"Итоговая средняя для Ростовской области стоимость поездки: {total_cost:.2f} руб."
return "Дата не найдена в прогнозе. Максимальная дата прогноза на данный момент 01/07/2025, так же используйте формат 'День/Месяц/Год' "
#Gradio
def gradio_interface(date_str, distance):
forecast_df = forecast_prices(data)
plot_forecasts(forecast_df)
cost_message = cost(date_str, distance)
return 'forecast_plot.png', cost_message
txt = gr.Textbox(label="Введите дату (dd/mm/yyyy)")
num = gr.Number(label="Введите расстояние (в метрах)")
iface = gr.Interface(fn=gradio_interface,inputs=[txt,num],outputs=["image", "text"],title="Прогноз цен на прокат электросамокатов")
iface.launch()