StockPredict / core /data.py
aromidvar1355's picture
Update core/data.py
4a34e93 verified
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
try:
import yfinance as yf
except ImportError:
raise ImportError("yfinance must be installed to fetch financial data.")
def load_data(data_src="yahoo", ticker="AAPL", file_upload=None, start="2020-01-01", end="2023-01-01"):
    """Load a single price series as a DataFrame with a ``value`` column.

    Args:
        data_src: ``"yahoo"`` to download prices with ``yf.download``, or
            ``"csv"`` to parse ``file_upload`` with ``pd.read_csv``.
        ticker: Symbol passed to ``yf.download`` (yahoo source only).
        file_upload: Anything ``pd.read_csv`` accepts (path or file-like);
            required for the csv source.
        start: Start of the date range passed to ``yf.download`` (yahoo only).
        end: End of the date range passed to ``yf.download`` (yahoo only).

    Returns:
        For ``"yahoo"``: the ``Close`` column renamed to ``value`` with NaN rows
        dropped, and the download's index turned into a regular column
        (presumably the dates — depends on what yfinance returns).
        For ``"csv"``: a frame with only a ``value`` column (taken from
        ``value`` if present, else ``Close``), NaN rows dropped and the index
        reset to ``0..n-1``.

    Raises:
        ValueError: for an unknown ``data_src``, a missing CSV upload, a failed
            download or CSV parse (original error chained as ``__cause__``),
            a CSV without a ``value``/``Close`` column, or when no non-NaN
            rows remain.
    """
    if data_src == "yahoo":
        # No separate ``yf.Ticker(ticker).info`` lookup: it is an extra request
        # that can fail on its own and is not a reliable existence check; an
        # unknown symbol simply yields an empty download, handled below.
        try:
            df = yf.download(ticker, start=start, end=end, progress=False)
        except Exception as e:
            raise ValueError(f"Error fetching data for ticker '{ticker}': {e}") from e

        no_data_msg = (
            f"No data found for ticker '{ticker}' in the specified date range. "
            "Please check the symbol and dates."
        )
        if df is None or df.empty:
            raise ValueError(no_data_msg)

        # Depending on the yfinance version, download() may return
        # (field, ticker) MultiIndex columns even for one symbol; keep only the
        # field level so 'Close' is a plain column and the result is flat.
        if isinstance(df.columns, pd.MultiIndex):
            df.columns = df.columns.get_level_values(0)

        df = df[['Close']].dropna().rename(columns={'Close': 'value'})
        if df.empty:
            # Every Close value was NaN.
            raise ValueError(no_data_msg)
        df = df.reset_index()

    elif data_src == "csv":
        if file_upload is None:
            raise ValueError("CSV file upload required but not provided.")
        try:
            df = pd.read_csv(file_upload)
        except Exception as e:
            raise ValueError(f"Failed to read uploaded CSV file: {e}") from e

        # Prefer an explicit 'value' column; fall back to 'Close'.
        if 'value' not in df.columns:
            if 'Close' not in df.columns:
                raise ValueError("CSV must contain a 'value' or 'Close' column.")
            df = df[['Close']].rename(columns={'Close': 'value'})

        df = df[['value']].dropna().reset_index(drop=True)
        if df.empty:
            # Mirror the yahoo branch: an empty series is an input error, not
            # something to hand to the preprocessing/scaling step.
            raise ValueError("CSV contains no non-empty 'value'/'Close' data.")

    else:
        raise ValueError("Invalid data source. 'csv' or 'yahoo' expected.")

    return df
def preprocess_data(df, column, window_size=30):
    """Standardize one column and cut it into sliding windows for next-step prediction.

    Each sample is ``window_size`` consecutive scaled values; its target is
    the scaled value that immediately follows the window.

    Args:
        df: DataFrame containing ``column``.
        column: Name of the column to use; its values are cast to float.
        window_size: Number of past steps per input window; must be >= 1.

    Returns:
        ``(X, y, scaler)`` where, with ``n = len(df)``:
        ``X`` has shape ``(n - window_size, window_size, 1)``,
        ``y`` has shape ``(n - window_size, 1)``, and ``scaler`` is the
        ``StandardScaler`` fitted on the column (its ``inverse_transform``
        maps scaled values back to the original units). When
        ``n <= window_size`` both arrays are empty but keep those trailing
        dimensions, so downstream reshaping/indexing still works.

    Raises:
        ValueError: if ``window_size < 1``.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}.")

    # NOTE(review): the scaler is fit on the entire series. If callers later
    # split X/y into train/test, test-period statistics leak into the scaling;
    # confirm that is acceptable for how this is used.
    scaler = StandardScaler()
    series = df[[column]].values.astype(float)  # shape (n, 1)
    scaled = scaler.fit_transform(series)

    n_samples = len(scaled) - window_size
    if n_samples <= 0:
        # np.array([]) would have shape (0,), losing the window/feature axes.
        return np.empty((0, window_size, 1)), np.empty((0, 1)), scaler

    # Window i covers rows i .. i+window_size-1; its target is row i+window_size.
    X = np.array([scaled[i:i + window_size] for i in range(n_samples)])
    y = scaled[window_size:].copy()
    return X, y, scaler