"""Base data-processor utilities shared by the concrete financial data sources."""
import copy | |
import os | |
import urllib | |
import zipfile | |
from datetime import * | |
from pathlib import Path | |
from typing import List | |
import numpy as np | |
import pandas as pd | |
import stockstats | |
import talib | |
from finnlp.utils.config import BINANCE_BASE_URL | |
from finnlp.utils.config import TIME_ZONE_BERLIN | |
from finnlp.utils.config import TIME_ZONE_JAKARTA | |
from finnlp.utils.config import TIME_ZONE_PARIS | |
from finnlp.utils.config import TIME_ZONE_SELFDEFINED | |
from finnlp.utils.config import TIME_ZONE_SHANGHAI | |
from finnlp.utils.config import TIME_ZONE_USEASTERN | |
from finnlp.utils.config import USE_TIME_ZONE_SELFDEFINED | |
from finnlp.utils.config_tickers import CAC_40_TICKER | |
from finnlp.utils.config_tickers import CSI_300_TICKER | |
from finnlp.utils.config_tickers import DAX_30_TICKER | |
from finnlp.utils.config_tickers import DOW_30_TICKER | |
from finnlp.utils.config_tickers import HSI_50_TICKER | |
from finnlp.utils.config_tickers import LQ45_TICKER | |
from finnlp.utils.config_tickers import MDAX_50_TICKER | |
from finnlp.utils.config_tickers import NAS_100_TICKER | |
from finnlp.utils.config_tickers import SDAX_50_TICKER | |
from finnlp.utils.config_tickers import SP_500_TICKER | |
from finnlp.utils.config_tickers import SSE_50_TICKER | |
from finnlp.utils.config_tickers import TECDAX_TICKER | |
class _Base: | |
def __init__( | |
self, | |
data_source: str, | |
start_date: str, | |
end_date: str, | |
time_interval: str, | |
**kwargs, | |
): | |
self.data_source: str = data_source | |
self.start_date: str = start_date | |
self.end_date: str = end_date | |
self.time_interval: str = time_interval # standard time_interval | |
# transferred_time_interval will be supported in the future. | |
# self.nonstandard_time_interval: str = self.calc_nonstandard_time_interval() # transferred time_interval of this processor | |
self.time_zone: str = "" | |
self.dataframe: pd.DataFrame = pd.DataFrame() | |
self.dictnumpy: dict = ( | |
{} | |
) # e.g., self.dictnumpy["open"] = np.array([1, 2, 3]), self.dictnumpy["close"] = np.array([1, 2, 3]) | |
    def download_data(self, ticker_list: List[str]):
        """Fetch raw data for *ticker_list* into ``self.dataframe``.

        Hook for subclasses; the base implementation is a no-op.
        """
        pass
    def clean_data(self):
        """Normalize ``self.dataframe`` to the common long format:
        columns [tic, time, open, high, low, close, adjusted_close, volume],
        sorted by (time, tic), with NaN rows dropped."""
        # unify the timestamp column name across sources
        if "date" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={"date": "time"}, inplace=True)
        if "datetime" in self.dataframe.columns.values.tolist():
            self.dataframe.rename(columns={"datetime": "time"}, inplace=True)
        if self.data_source == "ccxt":
            self.dataframe.rename(columns={"index": "time"}, inplace=True)
        if self.data_source == "ricequant":
            """RiceQuant data is already cleaned, we only need to transform data format here.
            No need for filling NaN data"""
            self.dataframe.rename(columns={"order_book_id": "tic"}, inplace=True)
            # raw df uses multi-index (tic,time), reset it to single index (time)
            self.dataframe.reset_index(level=[0, 1], inplace=True)
            # check if there is NaN values
            assert not self.dataframe.isnull().values.any()
        elif self.data_source == "baostock":
            self.dataframe.rename(columns={"code": "tic"}, inplace=True)
        self.dataframe.dropna(inplace=True)
        # adjusted_close: fall back to the raw close when the source provides no adjusted price
        if "adjusted_close" not in self.dataframe.columns.values.tolist():
            self.dataframe["adjusted_close"] = self.dataframe["close"]
        self.dataframe.sort_values(by=["time", "tic"], inplace=True)
        self.dataframe = self.dataframe[
            [
                "tic",
                "time",
                "open",
                "high",
                "low",
                "close",
                "adjusted_close",
                "volume",
            ]
        ]
def fillna(self): | |
df = self.dataframe | |
dfcode = pd.DataFrame(columns=["tic"]) | |
dfdate = pd.DataFrame(columns=["time"]) | |
dfcode.tic = df.tic.unique() | |
dfdate.time = df.time.unique() | |
dfdate.sort_values(by="time", ascending=False, ignore_index=True, inplace=True) | |
# the old pandas may not support pd.merge(how="cross") | |
try: | |
df1 = pd.merge(dfcode, dfdate, how="cross") | |
except: | |
print("Please wait for a few seconds...") | |
df1 = pd.DataFrame(columns=["tic", "time"]) | |
for i in range(dfcode.shape[0]): | |
for j in range(dfdate.shape[0]): | |
df1 = df1.append( | |
pd.DataFrame( | |
data={ | |
"tic": dfcode.iat[i, 0], | |
"time": dfdate.iat[j, 0], | |
}, | |
index=[(i + 1) * (j + 1) - 1], | |
) | |
) | |
df = pd.merge(df1, df, how="left", on=["tic", "time"]) | |
# back fill missing data then front fill | |
df_new = pd.DataFrame(columns=df.columns) | |
for i in df.tic.unique(): | |
df_tmp = df[df.tic == i].fillna(method="bfill").fillna(method="ffill") | |
df_new = pd.concat([df_new, df_tmp], ignore_index=True) | |
df_new = df_new.fillna(0) | |
# reshape dataframe | |
df_new = df_new.sort_values(by=["time", "tic"]).reset_index(drop=True) | |
print("Shape of DataFrame: ", df_new.shape) | |
self.dataframe = df_new | |
def get_trading_days(self, start: str, end: str) -> List[str]: | |
if self.data_source in [ | |
"binance", | |
"ccxt", | |
"quantconnect", | |
"ricequant", | |
"tushare", | |
]: | |
print( | |
f"Calculate get_trading_days not supported for {self.data_source} yet." | |
) | |
return None | |
# select_stockstats_talib: 0 (stockstats, default), or 1 (use talib). Users can choose the method. | |
# drop_na_timestep: 0 (not dropping timesteps that contain nan), or 1 (dropping timesteps that contain nan, default). Users can choose the method. | |
def add_technical_indicator( | |
self, | |
tech_indicator_list: List[str], | |
select_stockstats_talib: int = 0, | |
drop_na_timesteps: int = 1, | |
): | |
""" | |
calculate technical indicators | |
use stockstats/talib package to add technical inidactors | |
:param data: (df) pandas dataframe | |
:return: (df) pandas dataframe | |
""" | |
if "date" in self.dataframe.columns.values.tolist(): | |
self.dataframe.rename(columns={"date": "time"}, inplace=True) | |
if self.data_source == "ccxt": | |
self.dataframe.rename(columns={"index": "time"}, inplace=True) | |
self.dataframe.reset_index(drop=False, inplace=True) | |
if "level_1" in self.dataframe.columns: | |
self.dataframe.drop(columns=["level_1"], inplace=True) | |
if "level_0" in self.dataframe.columns and "tic" not in self.dataframe.columns: | |
self.dataframe.rename(columns={"level_0": "tic"}, inplace=True) | |
assert select_stockstats_talib in {0, 1} | |
print("tech_indicator_list: ", tech_indicator_list) | |
if select_stockstats_talib == 0: # use stockstats | |
stock = stockstats.StockDataFrame.retype(self.dataframe) | |
unique_ticker = stock.tic.unique() | |
for indicator in tech_indicator_list: | |
print("indicator: ", indicator) | |
indicator_df = pd.DataFrame() | |
for i in range(len(unique_ticker)): | |
try: | |
temp_indicator = stock[stock.tic == unique_ticker[i]][indicator] | |
temp_indicator = pd.DataFrame(temp_indicator) | |
temp_indicator["tic"] = unique_ticker[i] | |
temp_indicator["time"] = self.dataframe[ | |
self.dataframe.tic == unique_ticker[i] | |
]["time"].to_list() | |
indicator_df = pd.concat( | |
[indicator_df, temp_indicator], | |
axis=0, | |
join="outer", | |
ignore_index=True, | |
) | |
except Exception as e: | |
print(e) | |
if not indicator_df.empty: | |
self.dataframe = self.dataframe.merge( | |
indicator_df[["tic", "time", indicator]], | |
on=["tic", "time"], | |
how="left", | |
) | |
else: # use talib | |
final_df = pd.DataFrame() | |
for i in self.dataframe.tic.unique(): | |
tic_df = self.dataframe[self.dataframe.tic == i] | |
( | |
tic_df.loc["macd"], | |
tic_df.loc["macd_signal"], | |
tic_df.loc["macd_hist"], | |
) = talib.MACD( | |
tic_df["close"], | |
fastperiod=12, | |
slowperiod=26, | |
signalperiod=9, | |
) | |
tic_df.loc["rsi"] = talib.RSI(tic_df["close"], timeperiod=14) | |
tic_df.loc["cci"] = talib.CCI( | |
tic_df["high"], | |
tic_df["low"], | |
tic_df["close"], | |
timeperiod=14, | |
) | |
tic_df.loc["dx"] = talib.DX( | |
tic_df["high"], | |
tic_df["low"], | |
tic_df["close"], | |
timeperiod=14, | |
) | |
final_df = pd.concat([final_df, tic_df], axis=0, join="outer") | |
self.dataframe = final_df | |
self.dataframe.sort_values(by=["time", "tic"], inplace=True) | |
if drop_na_timesteps: | |
time_to_drop = self.dataframe[ | |
self.dataframe.isna().any(axis=1) | |
].time.unique() | |
self.dataframe = self.dataframe[~self.dataframe.time.isin(time_to_drop)] | |
print("Succesfully add technical indicators") | |
def add_turbulence(self): | |
""" | |
add turbulence index from a precalcualted dataframe | |
:param data: (df) pandas dataframe | |
:return: (df) pandas dataframe | |
""" | |
# df = data.copy() | |
# turbulence_index = self.calculate_turbulence(df) | |
# df = df.merge(turbulence_index, on="time") | |
# df = df.sort_values(["time", "tic"]).reset_index(drop=True) | |
# return df | |
if self.data_source in [ | |
"binance", | |
"ccxt", | |
"iexcloud", | |
"joinquant", | |
"quantconnect", | |
]: | |
print( | |
f"Turbulence not supported for {self.data_source} yet. Return original DataFrame." | |
) | |
if self.data_source in [ | |
"alpaca", | |
"ricequant", | |
"tushare", | |
"wrds", | |
"yahoofinance", | |
]: | |
turbulence_index = self.calculate_turbulence() | |
self.dataframe = self.dataframe.merge(turbulence_index, on="time") | |
self.dataframe.sort_values(["time", "tic"], inplace=True) | |
self.dataframe.reset_index(drop=True, inplace=True) | |
    def calculate_turbulence(self, time_period: int = 252) -> pd.DataFrame:
        """Calculate a turbulence index (Mahalanobis-style distance of the
        current returns from the rolling-window return distribution),
        e.g. based on the Dow 30.

        :param time_period: rolling-window length in timesteps (252 ~ one
            trading year); the first *time_period* entries are set to 0
        :return: DataFrame with columns ["time", "turbulence"]
        """
        # can add other market assets
        df_price_pivot = self.dataframe.pivot(
            index="time", columns="tic", values="close"
        )
        # use returns to calculate turbulence
        df_price_pivot = df_price_pivot.pct_change()
        unique_date = self.dataframe["time"].unique()
        # start after a year
        start = time_period
        turbulence_index = [0] * start
        # turbulence_index = [0]
        count = 0
        for i in range(start, len(unique_date)):
            current_price = df_price_pivot[df_price_pivot.index == unique_date[i]]
            # use one-year rolling window to calculate covariance
            hist_price = df_price_pivot[
                (df_price_pivot.index < unique_date[i])
                & (df_price_pivot.index >= unique_date[i - time_period])
            ]
            # Drop tickers which have more missing values than the "oldest" ticker
            filtered_hist_price = hist_price.iloc[
                hist_price.isna().sum().min() :
            ].dropna(axis=1)
            cov_temp = filtered_hist_price.cov()
            current_temp = current_price[list(filtered_hist_price)] - np.mean(
                filtered_hist_price, axis=0
            )
            # cov_temp = hist_price.cov()
            # current_temp=(current_price - np.mean(hist_price,axis=0))
            # pseudo-inverse keeps this defined even when cov_temp is singular
            temp = current_temp.values.dot(np.linalg.pinv(cov_temp)).dot(
                current_temp.values.T
            )
            if temp > 0:
                count += 1
                # avoid large outlier because the calculation just begins: else turbulence_temp = 0
                turbulence_temp = temp[0][0] if count > 2 else 0
            else:
                turbulence_temp = 0
            turbulence_index.append(turbulence_temp)
        turbulence_index = pd.DataFrame(
            {"time": df_price_pivot.index, "turbulence": turbulence_index}
        )
        return turbulence_index
def add_vix(self): | |
""" | |
add vix from processors | |
:param data: (df) pandas dataframe | |
:return: (df) pandas dataframe | |
""" | |
if self.data_source in [ | |
"binance", | |
"ccxt", | |
"iexcloud", | |
"joinquant", | |
"quantconnect", | |
"ricequant", | |
"tushare", | |
]: | |
print( | |
f"VIX is not applicable for {self.data_source}. Return original DataFrame" | |
) | |
return None | |
# if self.data_source == 'yahoofinance': | |
# df = data.copy() | |
# df_vix = self.download_data( | |
# start_date=df.time.min(), | |
# end_date=df.time.max(), | |
# ticker_list=["^VIX"], | |
# time_interval=self.time_interval, | |
# ) | |
# df_vix = self.clean_data(df_vix) | |
# vix = df_vix[["time", "adjusted_close"]] | |
# vix.columns = ["time", "vix"] | |
# | |
# df = df.merge(vix, on="time") | |
# df = df.sort_values(["time", "tic"]).reset_index(drop=True) | |
# elif self.data_source == 'alpaca': | |
# vix_df = self.download_data(["VIXY"], self.start, self.end, self.time_interval) | |
# cleaned_vix = self.clean_data(vix_df) | |
# vix = cleaned_vix[["time", "close"]] | |
# vix = vix.rename(columns={"close": "VIXY"}) | |
# | |
# df = data.copy() | |
# df = df.merge(vix, on="time") | |
# df = df.sort_values(["time", "tic"]).reset_index(drop=True) | |
# elif self.data_source == 'wrds': | |
# vix_df = self.download_data(['vix'], self.start, self.end_date, self.time_interval) | |
# cleaned_vix = self.clean_data(vix_df) | |
# vix = cleaned_vix[['date', 'close']] | |
# | |
# df = data.copy() | |
# df = df.merge(vix, on="date") | |
# df = df.sort_values(["date", "tic"]).reset_index(drop=True) | |
elif self.data_source == "yahoofinance": | |
ticker = "^VIX" | |
elif self.data_source == "alpaca": | |
ticker = "VIXY" | |
elif self.data_source == "wrds": | |
ticker = "vix" | |
else: | |
pass | |
df = self.dataframe.copy() | |
self.dataframe = [ticker] | |
# self.download_data(self.start_date, self.end_date, self.time_interval) | |
self.download_data([ticker], save_path="./data/vix.csv") | |
self.clean_data() | |
cleaned_vix = self.dataframe | |
# .rename(columns={ticker: "vix"}) | |
vix = cleaned_vix[["time", "close"]] | |
cleaned_vix = vix.rename(columns={"close": "vix"}) | |
df = df.merge(cleaned_vix, on="time") | |
df = df.sort_values(["time", "tic"]).reset_index(drop=True) | |
self.dataframe = df | |
def df_to_array(self, tech_indicator_list: List[str], if_vix: bool): | |
unique_ticker = self.dataframe.tic.unique() | |
price_array = np.column_stack( | |
[self.dataframe[self.dataframe.tic == tic].close for tic in unique_ticker] | |
) | |
common_tech_indicator_list = [ | |
i | |
for i in tech_indicator_list | |
if i in self.dataframe.columns.values.tolist() | |
] | |
tech_array = np.hstack( | |
[ | |
self.dataframe.loc[ | |
(self.dataframe.tic == tic), common_tech_indicator_list | |
] | |
for tic in unique_ticker | |
] | |
) | |
if if_vix: | |
risk_array = np.column_stack( | |
[self.dataframe[self.dataframe.tic == tic].vix for tic in unique_ticker] | |
) | |
else: | |
risk_array = ( | |
np.column_stack( | |
[ | |
self.dataframe[self.dataframe.tic == tic].turbulence | |
for tic in unique_ticker | |
] | |
) | |
if "turbulence" in self.dataframe.columns | |
else None | |
) | |
print("Successfully transformed into array") | |
return price_array, tech_array, risk_array | |
    # standard_time_interval s: second, m: minute, h: hour, d: day, w: week, M: month, q: quarter, y: year
    # output time_interval of the processor
    def calc_nonstandard_time_interval(self) -> str:
        """Translate the standard ``self.time_interval`` into the current
        data source's own interval code.

        Sources whose branch is just ``pass`` are not supported yet and fall
        through returning None; an unknown source raises ValueError.
        """
        if self.data_source == "alpaca":
            pass
        elif self.data_source == "baostock":
            # baostock intervals: default "d" (daily bars); d=daily, w=weekly,
            # m=monthly, 5/15/30/60 = minute bars (case-insensitive). Indices
            # have no minute bars; weekly/monthly bars are only available on
            # the last trading day of the week/month.
            pass
            time_intervals = ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if (
                "d" in self.time_interval
                or "w" in self.time_interval
                or "M" in self.time_interval
            ):
                # daily/weekly/monthly: baostock wants just the lowercase letter
                return self.time_interval[-1:].lower()
            elif "m" in self.time_interval:
                # minute bars: baostock wants just the number of minutes
                return self.time_interval[:-1]
        elif self.data_source == "binance":
            # binance accepts the standard codes directly:
            # 1m,3m,5m,15m,30m,1h,2h,4h,6h,8h,12h,1d,3d,1w,1M
            time_intervals = [
                "1m",
                "3m",
                "5m",
                "15m",
                "30m",
                "1h",
                "2h",
                "4h",
                "6h",
                "8h",
                "12h",
                "1d",
                "3d",
                "1w",
                "1M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "ccxt":
            pass
        elif self.data_source == "iexcloud":
            # iexcloud: daily only, uppercased
            time_intervals = ["1d"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval.upper()
        elif self.data_source == "joinquant":
            # '1m', '5m', '15m', '30m', '60m', '120m', '1d', '1w', '1M'
            time_intervals = [
                "1m",
                "5m",
                "15m",
                "30m",
                "60m",
                "120m",
                "1d",
                "1w",
                "1M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "quantconnect":
            pass
        elif self.data_source == "ricequant":
            # ricequant intervals: 'd' - day, 'w' - week, 'm' - month,
            # 'q' - quarter, 'y' - year
            time_intervals = ["d", "w", "M", "q", "y"]
            assert self.time_interval[-1] in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if "M" in self.time_interval:
                # ricequant uses lowercase "m" for month
                return self.time_interval.lower()
            else:
                return self.time_interval
        elif self.data_source == "tushare":
            # Minute frequencies include 1, 5, 15, 30, 60 minutes. Not supported currently.
            # time_intervals = ["1m", "5m", "15m", "30m", "60m", "1d"]
            time_intervals = ["1d"]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            return self.time_interval
        elif self.data_source == "wrds":
            pass
        elif self.data_source == "yahoofinance":
            # nonstandard_time_interval: ["1m", "2m", "5m", "15m", "30m", "60m", "90m", "1h", "1d", "5d","1wk", "1mo", "3mo"]
            time_intervals = [
                "1m",
                "2m",
                "5m",
                "15m",
                "30m",
                "60m",
                "90m",
                "1h",
                "1d",
                "5d",
                "1w",
                "1M",
                "3M",
            ]
            assert self.time_interval in time_intervals, (
                "This time interval is not supported. Supported time intervals: "
                + ",".join(time_intervals)
            )
            if "w" in self.time_interval:
                # yahoo spells weeks "1wk"
                return self.time_interval + "k"
            elif "M" in self.time_interval:
                # yahoo spells months "1mo"/"3mo"
                return self.time_interval[:-1] + "mo"
            else:
                return self.time_interval
        else:
            raise ValueError(
                f"Not support transfer_standard_time_interval for {self.data_source}"
            )
# "600000.XSHG" -> "sh.600000" | |
# "000612.XSHE" -> "sz.000612" | |
def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str: | |
return ticker | |
def save_data(self, path): | |
if ".csv" in path: | |
path = path.split("/") | |
filename = path[-1] | |
path = "/".join(path[:-1] + [""]) | |
else: | |
if path[-1] == "/": | |
filename = "dataset.csv" | |
else: | |
filename = "/dataset.csv" | |
os.makedirs(path, exist_ok=True) | |
self.dataframe.to_csv(path + filename, index=False) | |
def load_data(self, path): | |
assert ".csv" in path # only support csv format now | |
self.dataframe = pd.read_csv(path) | |
columns = self.dataframe.columns | |
print(f"{path} loaded") | |
# # check loaded file | |
# assert "date" in columns or "time" in columns | |
# assert "close" in columns | |
def calc_time_zone(
    ticker_list: List[str],
    time_zone_selfdefined: str,
    use_time_zone_selfdefined: int,
) -> str:
    """Infer the market time zone from the first ticker in *ticker_list*.

    :param ticker_list: non-empty list of tickers; only the first is examined
    :param time_zone_selfdefined: override returned when the next flag is 1
    :param use_time_zone_selfdefined: 1 -> return the override unconditionally
    :return: a time-zone constant from finnlp.utils.config
    """
    assert isinstance(ticker_list, list)
    first_ticker = ticker_list[0]
    if use_time_zone_selfdefined == 1:
        return time_zone_selfdefined
    # BUG FIX: the CAC 40 and LQ45 branches used `==`, comparing a single
    # ticker string against a whole list (always False, inconsistent with the
    # `in` used by the other branches); membership tests are intended.
    if first_ticker in HSI_50_TICKER + SSE_50_TICKER + CSI_300_TICKER:
        return TIME_ZONE_SHANGHAI
    if first_ticker in DOW_30_TICKER + NAS_100_TICKER + SP_500_TICKER:
        return TIME_ZONE_USEASTERN
    if first_ticker in CAC_40_TICKER:
        return TIME_ZONE_PARIS
    if first_ticker in DAX_30_TICKER + TECDAX_TICKER + MDAX_50_TICKER + SDAX_50_TICKER:
        return TIME_ZONE_BERLIN
    if first_ticker in LQ45_TICKER:
        return TIME_ZONE_JAKARTA
    # hack: fall back to US/Eastern so the VIX indicator path keeps working
    # (a self-defined TZ cannot currently be set from the top-level processor)
    # raise ValueError("Time zone is wrong.")
    return TIME_ZONE_USEASTERN
def check_date(d: str) -> bool:
    """Validate that *d* looks like an ISO date string ("YYYY-MM-DD").

    Raises AssertionError when the length is not 10 and ValueError when any
    digit position does not hold a digit; returns True otherwise.
    """
    assert (
        len(d) == 10
    ), "Please check the length of date and use the correct date like 2020-01-01."
    # positions of the eight digits in "YYYY-MM-DD"
    digit_positions = (0, 1, 2, 3, 5, 6, 8, 9)
    if not all(d[pos].isdigit() for pos in digit_positions):
        raise ValueError("Please use the correct date like 2020-01-01.")
    return True