# roboadvisor/finnlp/utils/data_processor.py
# Page header captured with the file: "kristada673's picture",
# "Upload 8 files", revision 931d5fb.
import os
import pickle
from typing import List
import numpy as np
import pandas as pd
class DataProcessor:
    """Facade that forwards pipeline steps to a source-specific processor.

    The constructor looks up ``data_source`` in a table of supported
    sources, imports ``meta.data_processors.<data_source>`` and instantiates
    the processor class found there.  Every pipeline method delegates to
    that processor and then mirrors ``processor.dataframe`` into
    ``self.dataframe``.
    """

    # Supported ``data_source`` values mapped to the class name that is
    # imported from the module ``meta.data_processors.<data_source>``.
    _PROCESSOR_CLASS_NAMES = {
        "akshare": "Akshare",
        "alpaca": "Alpaca",
        "alphavantage": "Alphavantage",
        "baostock": "Baostock",
        "binance": "Binance",
        "ccxt": "Ccxt",
        "iexcloud": "Iexcloud",
        "joinquant": "Joinquant",
        "quandl": "Quandl",
        "quantconnect": "Quantconnect",
        "ricequant": "Ricequant",
        "tushare": "Tushare",
        "wrds": "Wrds",
        "yahoofinance": "Yahoofinance",
    }

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        """Create the processor that backs this facade.

        :param data_source: one of the keys of ``_PROCESSOR_CLASS_NAMES``
        :param start_date: forwarded unchanged to the processor
        :param end_date: forwarded unchanged to the processor
        :param time_interval: forwarded unchanged to the processor
        :param kwargs: extra keyword arguments forwarded to the processor
            constructor
        :raises ValueError: if ``data_source`` is not supported, or if the
            processor constructor raises
        :raises ImportError: if the processor module or class cannot be
            imported
        """
        self.data_source = data_source
        self.start_date = start_date
        self.end_date = end_date
        self.time_interval = time_interval
        self.dataframe = pd.DataFrame()

        class_name = self._PROCESSOR_CLASS_NAMES.get(data_source)
        if class_name is None:
            # Fail with the real reason instead of falling through to the
            # "account info" error raised for constructor failures below.
            raise ValueError(f"{self.data_source} is NOT supported yet.")

        # Imported lazily so that only the dependencies of the selected
        # data source have to be installed.
        import importlib

        module = importlib.import_module(f"meta.data_processors.{data_source}")
        try:
            processor_cls = getattr(module, class_name)
        except AttributeError as err:
            # Keep the exception type a ``from ... import ...`` would raise.
            raise ImportError(
                f"cannot import name {class_name!r} from {module.__name__!r}"
            ) from err

        try:
            self.processor = processor_cls(
                data_source, start_date, end_date, time_interval, **kwargs
            )
        except Exception as err:
            # Chain the cause so the underlying failure stays visible.
            raise ValueError(
                f"Please input correct account info for {self.data_source}!"
            ) from err
        print(f"{self.data_source} successfully connected")

    def download_data(self, ticker_list):
        """Have the processor download ``ticker_list`` and mirror its dataframe."""
        self.processor.download_data(ticker_list=ticker_list)
        self.dataframe = self.processor.dataframe

    def clean_data(self):
        """Push ``self.dataframe`` to the processor, clean it, and mirror it back."""
        self.processor.dataframe = self.dataframe
        self.processor.clean_data()
        self.dataframe = self.processor.dataframe

    def add_technical_indicator(
        self, tech_indicator_list: List[str], select_stockstats_talib: int = 0
    ):
        """Delegate indicator computation to the processor.

        The indicator list is remembered on the instance because
        ``df_to_array`` passes it to the processor later.

        :param tech_indicator_list: indicator names forwarded to the processor
        :param select_stockstats_talib: forwarded unchanged to the processor
        """
        self.tech_indicator_list = tech_indicator_list
        self.processor.add_technical_indicator(
            tech_indicator_list, select_stockstats_talib
        )
        self.dataframe = self.processor.dataframe

    def add_turbulence(self):
        """Delegate to the processor's ``add_turbulence`` and mirror its dataframe."""
        self.processor.add_turbulence()
        self.dataframe = self.processor.dataframe

    def add_vix(self):
        """Delegate to the processor's ``add_vix`` and mirror its dataframe."""
        self.processor.add_vix()
        self.dataframe = self.processor.dataframe

    def df_to_array(self, if_vix: bool) -> tuple:
        """Convert the processor's dataframe into arrays.

        Requires ``self.tech_indicator_list`` to be set, which
        ``add_technical_indicator`` and ``run`` do.

        :param if_vix: forwarded unchanged to the processor
        :return: ``(price_array, tech_array, turbulence_array)`` as returned
            by the processor, with NaNs in ``tech_array`` replaced by 0
        """
        price_array, tech_array, turbulence_array = self.processor.df_to_array(
            self.tech_indicator_list, if_vix
        )
        # fill nan with 0 for technical indicators
        tech_array[np.isnan(tech_array)] = 0
        return price_array, tech_array, turbulence_array

    def data_split(self, df, start, end, target_date_col="time"):
        """Select the rows of ``df`` whose date lies in ``[start, end)``.

        :param df: (df) pandas dataframe with ``target_date_col`` and "tic"
            columns
        :param start: inclusive lower bound compared against the date column
        :param end: exclusive upper bound compared against the date column
        :param target_date_col: name of the date column
        :return: (df) pandas dataframe sorted by date and "tic"; rows that
            share a date share an index value, numbered from 0 in date order
        """
        in_range = (df[target_date_col] >= start) & (df[target_date_col] < end)
        data = df[in_range].sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def fillna(self):
        """Push ``self.dataframe`` to the processor, fill NaNs, and mirror it back."""
        self.processor.dataframe = self.dataframe
        self.processor.fillna()
        self.dataframe = self.processor.dataframe

    def run(
        self,
        ticker_list: List[str],
        technical_indicator_list: List[str],
        if_vix: bool,
        cache: bool = False,
        select_stockstats_talib: int = 0,
    ):
        """Run download/clean (or load from cache), add indicators, build arrays.

        :param ticker_list: tickers to download; also part of the cache
            file name
        :param technical_indicator_list: indicator names to add
        :param if_vix: when true, ``add_vix`` is called before conversion
        :param cache: when true, the cleaned dataframe is read from / written
            to a pickle file under ``./cache``
        :param select_stockstats_talib: forwarded to
            ``add_technical_indicator``
        :return: ``(price_array, tech_array, turbulence_array)`` from
            ``df_to_array``
        :raises ValueError: if ``time_interval`` is "1s" and the data source
            is not "binance"
        """
        if self.time_interval == "1s" and self.data_source != "binance":
            raise ValueError(
                "Currently 1s interval data is only supported with 'binance' as data source"
            )

        cache_filename = (
            "_".join(
                [
                    *ticker_list,
                    self.data_source,
                    self.start_date,
                    self.end_date,
                    self.time_interval,
                ]
            )
            + ".pickle"
        )
        cache_dir = "./cache"
        cache_path = os.path.join(cache_dir, cache_filename)

        if cache and os.path.isfile(cache_path):
            print(f"Using cached file {cache_path}")
            self.tech_indicator_list = technical_indicator_list
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; only use cache directories that are trusted.
            with open(cache_path, "rb") as handle:
                self.processor.dataframe = pickle.load(handle)
            # Keep the facade's copy in step with the processor's.
            self.dataframe = self.processor.dataframe
        else:
            self.download_data(ticker_list)
            self.clean_data()
            if cache:
                # exist_ok avoids the check-then-create race of
                # os.path.exists + os.mkdir.
                os.makedirs(cache_dir, exist_ok=True)
                with open(cache_path, "wb") as handle:
                    pickle.dump(
                        self.dataframe,
                        handle,
                        protocol=pickle.HIGHEST_PROTOCOL,
                    )

        self.add_technical_indicator(technical_indicator_list, select_stockstats_talib)
        if if_vix:
            self.add_vix()
        # df_to_array already replaces NaNs in tech_array with 0.
        return self.df_to_array(if_vix)
def test_joinquant():
    """Manual smoke test that drives DataProcessor against "joinquant".

    Calls each pipeline step individually and then ``run`` with caching
    enabled.  The credentials below are placeholders.
    """
    # Alternative start date: "2019-09-01"
    start = "2020-09-01"
    end = "2021-09-11"
    # supported time interval: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
    interval = "1d"
    indicators = [
        "macd",
        "boll_ub",
        "boll_lb",
        "rsi_30",
        "dx_30",
        "close_30_sma",
        "close_60_sma",
    ]
    credentials = dict(username="xxx", password="xxx")
    tickers = ["000612.XSHE", "601808.XSHG"]

    processor = DataProcessor(
        data_source="joinquant",
        start_date=start,
        end_date=end,
        time_interval=interval,
        **credentials,
    )

    # Step-by-step pipeline.
    processor.download_data(ticker_list=tickers)
    processor.clean_data()
    processor.add_turbulence()
    processor.add_technical_indicator(indicators)
    processor.add_vix()

    # Same pipeline through the one-shot entry point, with caching on.
    _prices, _techs, _turbulence = processor.run(
        tickers, indicators, if_vix=False, cache=True
    )
# if __name__ == "__main__":
# # test_joinquant()
# # test_binance()
# # test_yahoofinance()
# test_baostock()
# # test_quandl()