# Spaces status: Runtime error
import os | |
import pickle | |
from typing import List | |
import numpy as np | |
import pandas as pd | |
class DataProcessor:
    """Facade that forwards data-preparation steps to a source-specific processor.

    The concrete processor class is chosen from ``data_source`` and imported
    lazily, so only the selected backend's module needs to be importable.
    After each step the processor's ``dataframe`` is mirrored onto
    ``self.dataframe``.
    """

    # data_source name -> (module path, class name). Modules are imported on
    # demand in __init__ so unused backends are never loaded.
    _PROCESSORS = {
        "akshare": ("meta.data_processors.akshare", "Akshare"),
        "alpaca": ("meta.data_processors.alpaca", "Alpaca"),
        "alphavantage": ("meta.data_processors.alphavantage", "Alphavantage"),
        "baostock": ("meta.data_processors.baostock", "Baostock"),
        "binance": ("meta.data_processors.binance", "Binance"),
        "ccxt": ("meta.data_processors.ccxt", "Ccxt"),
        "iexcloud": ("meta.data_processors.iexcloud", "Iexcloud"),
        "joinquant": ("meta.data_processors.joinquant", "Joinquant"),
        "quandl": ("meta.data_processors.quandl", "Quandl"),
        "quantconnect": ("meta.data_processors.quantconnect", "Quantconnect"),
        "ricequant": ("meta.data_processors.ricequant", "Ricequant"),
        "tushare": ("meta.data_processors.tushare", "Tushare"),
        "wrds": ("meta.data_processors.wrds", "Wrds"),
        "yahoofinance": ("meta.data_processors.yahoofinance", "Yahoofinance"),
    }

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        """Select and instantiate the processor for ``data_source``.

        :param data_source: key of the backend to use (see ``_PROCESSORS``)
        :param start_date: forwarded unchanged to the processor
        :param end_date: forwarded unchanged to the processor
        :param time_interval: forwarded unchanged to the processor
        :param kwargs: extra keyword arguments forwarded to the processor
            (e.g. account credentials)
        :raises ValueError: if ``data_source`` is not supported, or if the
            processor's constructor fails
        :raises ImportError: if the selected backend module cannot be imported
        """
        # Local import: keeps the file's top-level import block untouched.
        import importlib

        self.data_source = data_source
        self.start_date = start_date
        self.end_date = end_date
        self.time_interval = time_interval
        self.dataframe = pd.DataFrame()

        try:
            module_path, class_name = self._PROCESSORS[data_source]
        except KeyError:
            # Fail with the real reason instead of falling through to the
            # "account info" error below.
            raise ValueError(f"{data_source} is NOT supported yet.") from None

        # Import outside the try block so a missing backend surfaces as
        # ImportError rather than a credentials error.
        processor_cls = getattr(importlib.import_module(module_path), class_name)

        try:
            self.processor = processor_cls(
                data_source, start_date, end_date, time_interval, **kwargs
            )
        except Exception as exc:
            # Keep the original cause attached for debugging.
            raise ValueError(
                f"Please input correct account info for {self.data_source}!"
            ) from exc
        print(f"{self.data_source} successfully connected")

    def download_data(self, ticker_list):
        """Ask the processor to download ``ticker_list`` and mirror its dataframe."""
        self.processor.download_data(ticker_list=ticker_list)
        self.dataframe = self.processor.dataframe

    def clean_data(self):
        """Run the processor's cleaning step on the current dataframe."""
        # Push our dataframe down first: callers may have replaced it.
        self.processor.dataframe = self.dataframe
        self.processor.clean_data()
        self.dataframe = self.processor.dataframe

    def add_technical_indicator(
        self, tech_indicator_list: List[str], select_stockstats_talib: int = 0
    ):
        """Delegate technical-indicator computation to the processor.

        :param tech_indicator_list: indicator names; remembered on the instance
            because ``df_to_array`` needs them later
        :param select_stockstats_talib: forwarded unchanged to the processor
            (presumably selects the stockstats or talib backend -- confirm
            against the processor implementations)
        """
        self.tech_indicator_list = tech_indicator_list
        self.processor.add_technical_indicator(
            tech_indicator_list, select_stockstats_talib
        )
        self.dataframe = self.processor.dataframe

    def add_turbulence(self):
        """Delegate to the processor's ``add_turbulence`` and mirror its dataframe."""
        self.processor.add_turbulence()
        self.dataframe = self.processor.dataframe

    def add_vix(self):
        """Delegate to the processor's ``add_vix`` and mirror its dataframe."""
        self.processor.add_vix()
        self.dataframe = self.processor.dataframe

    def df_to_array(self, if_vix: bool):
        """Convert the processor's dataframe to arrays.

        Requires ``self.tech_indicator_list`` to have been set, which
        ``add_technical_indicator`` and ``run`` do.

        :param if_vix: forwarded unchanged to the processor
        :return: tuple ``(price_array, tech_array, turbulence_array)`` where
            NaNs in ``tech_array`` have been replaced by 0 in place
        """
        price_array, tech_array, turbulence_array = self.processor.df_to_array(
            self.tech_indicator_list, if_vix
        )
        # fill nan with 0 for technical indicators
        tech_array[np.isnan(tech_array)] = 0
        return price_array, tech_array, turbulence_array

    def data_split(self, df, start, end, target_date_col="time"):
        """Split the dataset into training or testing using date.

        :param df: (df) pandas dataframe with ``target_date_col`` and "tic" columns
        :param start: inclusive lower bound on ``target_date_col``
        :param end: exclusive upper bound on ``target_date_col``
        :param target_date_col: name of the date column to filter and sort on
        :return: (df) pandas dataframe sorted by date then "tic", whose index
            is the ordinal of each distinct date (rows sharing a date share
            an index value)
        """
        in_range = (df[target_date_col] >= start) & (df[target_date_col] < end)
        data = df[in_range].sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def fillna(self):
        """Run the processor's ``fillna`` step on the current dataframe."""
        self.processor.dataframe = self.dataframe
        self.processor.fillna()
        self.dataframe = self.processor.dataframe

    def run(
        self,
        ticker_list: List[str],
        technical_indicator_list: List[str],
        if_vix: bool,
        cache: bool = False,
        select_stockstats_talib: int = 0,
    ):
        """Run download -> clean -> indicators -> (vix) -> arrays in one call.

        :param ticker_list: tickers to download; also part of the cache file name
        :param technical_indicator_list: indicator names to compute
        :param if_vix: when true, ``add_vix`` is called; also forwarded to
            ``df_to_array``
        :param cache: when true, the cleaned dataframe is read from / written
            to a pickle under ``./cache``
        :param select_stockstats_talib: forwarded to ``add_technical_indicator``
        :return: tuple ``(price_array, tech_array, turbulence_array)``
        :raises ValueError: for a "1s" interval with any source but "binance"
        """
        if self.time_interval == "1s" and self.data_source != "binance":
            raise ValueError(
                "Currently 1s interval data is only supported with 'binance' as data source"
            )

        # The cache key covers tickers, source, date range and interval.
        # list() lets callers pass any sequence of tickers, not only a list.
        cache_filename = (
            "_".join(
                list(ticker_list)
                + [
                    self.data_source,
                    self.start_date,
                    self.end_date,
                    self.time_interval,
                ]
            )
            + ".pickle"
        )
        cache_dir = "./cache"
        cache_path = os.path.join(cache_dir, cache_filename)

        if cache and os.path.isfile(cache_path):
            print(f"Using cached file {cache_path}")
            self.tech_indicator_list = technical_indicator_list
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only use cache directories that this process controls.
            with open(cache_path, "rb") as handle:
                self.processor.dataframe = pickle.load(handle)
            # Keep the facade in sync with the processor, as every other
            # step does.
            self.dataframe = self.processor.dataframe
        else:
            self.download_data(ticker_list)
            self.clean_data()
            if cache:
                # exist_ok avoids the check-then-create race.
                os.makedirs(cache_dir, exist_ok=True)
                with open(cache_path, "wb") as handle:
                    pickle.dump(
                        self.dataframe,
                        handle,
                        protocol=pickle.HIGHEST_PROTOCOL,
                    )

        self.add_technical_indicator(technical_indicator_list, select_stockstats_talib)
        if if_vix:
            self.add_vix()
        # df_to_array already replaces NaNs in tech_array with 0.
        return self.df_to_array(if_vix)
def test_joinquant():
    """Manual smoke test: drive a joinquant-backed DataProcessor end to end.

    Runs each step individually (download, clean, turbulence, indicators,
    vix) and then calls ``run`` with caching enabled. The credentials are
    placeholders and must be replaced before this can succeed.
    """
    indicators = [
        "macd",
        "boll_ub",
        "boll_lb",
        "rsi_30",
        "dx_30",
        "close_30_sma",
        "close_60_sma",
    ]
    tickers = ["000612.XSHE", "601808.XSHG"]
    credentials = {"username": "xxx", "password": "xxx"}

    # supported time interval: '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
    processor = DataProcessor(
        data_source="joinquant",
        start_date="2020-09-01",
        end_date="2021-09-11",
        time_interval="1d",
        **credentials,
    )

    # Step-by-step pipeline.
    processor.download_data(ticker_list=tickers)
    processor.clean_data()
    processor.add_turbulence()
    processor.add_technical_indicator(indicators)
    processor.add_vix()

    # One-shot pipeline; unpacking checks that three arrays come back.
    price_array, tech_array, turbulence_array = processor.run(
        tickers, indicators, if_vix=False, cache=True
    )
# if __name__ == "__main__": | |
# # test_joinquant() | |
# # test_binance() | |
# # test_yahoofinance() | |
# test_baostock() | |
# # test_quandl() | |