Spaces:
Runtime error
Runtime error
import copy | |
import os | |
import time | |
import warnings | |
warnings.filterwarnings("ignore") | |
from typing import List | |
import pandas as pd | |
from tqdm import tqdm | |
from matplotlib import pyplot as plt | |
import stockstats | |
import talib | |
from meta.data_processors._base import _Base | |
import tushare as ts | |
class Tushare(_Base): | |
""" | |
key-value in kwargs | |
---------- | |
token : str | |
get from https://waditu.com/ after registration | |
adj: str | |
Whether to use adjusted closing price. Default is None. | |
If you want to use forward adjusted closing price or 前复权. pleses use 'qfq' | |
If you want to use backward adjusted closing price or 后复权. pleses use 'hfq' | |
""" | |
def __init__( | |
self, | |
data_source: str, | |
start_date: str, | |
end_date: str, | |
time_interval: str, | |
**kwargs, | |
): | |
super().__init__(data_source, start_date, end_date, time_interval, **kwargs) | |
assert "token" in kwargs.keys(), "Please input token!" | |
self.token = kwargs["token"] | |
if "adj" in kwargs.keys(): | |
self.adj = kwargs["adj"] | |
print(f"Using {self.adj} method.") | |
else: | |
self.adj = None | |
def get_data(self, id) -> pd.DataFrame: | |
# df1 = ts.pro_bar(ts_code=id, start_date=self.start_date,end_date='20180101') | |
# dfb=pd.concat([df, df1], ignore_index=True) | |
# print(dfb.shape) | |
return ts.pro_bar( | |
ts_code=id, | |
start_date=self.start_date, | |
end_date=self.end_date, | |
adj=self.adj, | |
) | |
def download_data( | |
self, ticker_list: List[str], save_path: str = "./data/dataset.csv" | |
): | |
""" | |
`pd.DataFrame` | |
7 columns: A tick symbol, time, open, high, low, close and volume | |
for the specified stock ticker | |
""" | |
assert self.time_interval == "1d", "Not supported currently" | |
self.ticker_list = ticker_list | |
ts.set_token(self.token) | |
self.dataframe = pd.DataFrame() | |
for i in tqdm(ticker_list, total=len(ticker_list)): | |
# nonstandard_id = self.transfer_standard_ticker_to_nonstandard(i) | |
# df_temp = self.get_data(nonstandard_id) | |
df_temp = self.get_data(i) | |
self.dataframe = self.dataframe.append(df_temp) | |
# print("{} ok".format(i)) | |
time.sleep(0.25) | |
self.dataframe.columns = [ | |
"tic", | |
"time", | |
"open", | |
"high", | |
"low", | |
"close", | |
"pre_close", | |
"change", | |
"pct_chg", | |
"volume", | |
"amount", | |
] | |
self.dataframe.sort_values(by=["time", "tic"], inplace=True) | |
self.dataframe.reset_index(drop=True, inplace=True) | |
self.dataframe = self.dataframe[ | |
["tic", "time", "open", "high", "low", "close", "volume"] | |
] | |
# self.dataframe.loc[:, 'tic'] = pd.DataFrame((self.dataframe['tic'].tolist())) | |
self.dataframe["time"] = pd.to_datetime(self.dataframe["time"], format="%Y%m%d") | |
self.dataframe["day"] = self.dataframe["time"].dt.dayofweek | |
self.dataframe["time"] = self.dataframe.time.apply( | |
lambda x: x.strftime("%Y-%m-%d") | |
) | |
self.dataframe.dropna(inplace=True) | |
self.dataframe.sort_values(by=["time", "tic"], inplace=True) | |
self.dataframe.reset_index(drop=True, inplace=True) | |
self.save_data(save_path) | |
print( | |
f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}" | |
) | |
def data_split(self, df, start, end, target_date_col="time"): | |
""" | |
split the dataset into training or testing using time | |
:param data: (df) pandas dataframe, start, end | |
:return: (df) pandas dataframe | |
""" | |
data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)] | |
data = data.sort_values([target_date_col, "tic"], ignore_index=True) | |
data.index = data[target_date_col].factorize()[0] | |
return data | |
def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str: | |
# "600000.XSHG" -> "600000.SH" | |
# "000612.XSHE" -> "000612.SZ" | |
n, alpha = ticker.split(".") | |
assert alpha in ["XSHG", "XSHE"], "Wrong alpha" | |
if alpha == "XSHG": | |
nonstandard_ticker = n + ".SH" | |
elif alpha == "XSHE": | |
nonstandard_ticker = n + ".SZ" | |
return nonstandard_ticker | |
def save_data(self, path): | |
if ".csv" in path: | |
path = path.split("/") | |
filename = path[-1] | |
path = "/".join(path[:-1] + [""]) | |
else: | |
if path[-1] == "/": | |
filename = "dataset.csv" | |
else: | |
filename = "/dataset.csv" | |
os.makedirs(path, exist_ok=True) | |
self.dataframe.to_csv(path + filename, index=False) | |
def load_data(self, path): | |
assert ".csv" in path # only support csv format now | |
self.dataframe = pd.read_csv(path) | |
columns = self.dataframe.columns | |
assert ( | |
"tic" in columns and "time" in columns and "close" in columns | |
) # input file must have "tic","time" and "close" columns | |
class ReturnPlotter: | |
""" | |
An easy-to-use plotting tool to plot cumulative returns over time. | |
Baseline supports equal weighting(default) and any stocks you want to use for comparison. | |
""" | |
def __init__(self, df_account_value, df_trade, start_date, end_date): | |
self.start = start_date | |
self.end = end_date | |
self.trade = df_trade | |
self.df_account_value = df_account_value | |
def get_baseline(self, ticket): | |
df = ts.get_hist_data(ticket, start=self.start, end=self.end) | |
df.loc[:, "dt"] = df.index | |
df.index = range(len(df)) | |
df.sort_values(axis=0, by="dt", ascending=True, inplace=True) | |
df["time"] = pd.to_datetime(df["dt"], format="%Y-%m-%d") | |
return df | |
def plot(self, baseline_ticket=None): | |
""" | |
Plot cumulative returns over time. | |
use baseline_ticket to specify stock you want to use for comparison | |
(default: equal weighted returns) | |
""" | |
baseline_label = "Equal-weight portfolio" | |
tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"} | |
if baseline_ticket: | |
# 使用指定ticket作为baseline | |
baseline_df = self.get_baseline(baseline_ticket) | |
baseline_date_list = baseline_df.time.dt.strftime("%Y-%m-%d").tolist() | |
df_date_list = self.df_account_value.time.tolist() | |
df_account_value = self.df_account_value[ | |
self.df_account_value.time.isin(baseline_date_list) | |
] | |
baseline_df = baseline_df[baseline_df.time.isin(df_date_list)] | |
baseline = baseline_df.close.tolist() | |
baseline_label = tic2label.get(baseline_ticket, baseline_ticket) | |
ours = df_account_value.account_value.tolist() | |
else: | |
# 均等权重 | |
all_date = self.trade.time.unique().tolist() | |
baseline = [] | |
for day in all_date: | |
day_close = self.trade[self.trade["time"] == day].close.tolist() | |
avg_close = sum(day_close) / len(day_close) | |
baseline.append(avg_close) | |
ours = self.df_account_value.account_value.tolist() | |
ours = self.pct(ours) | |
baseline = self.pct(baseline) | |
days_per_tick = ( | |
60 # you should scale this variable accroding to the total trading days | |
) | |
time = list(range(len(ours))) | |
datetimes = self.df_account_value.time.tolist() | |
ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0] | |
plt.title("Cumulative Returns") | |
plt.plot(time, ours, label="DDPG Agent", color="green") | |
plt.plot(time, baseline, label=baseline_label, color="grey") | |
plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7) | |
plt.xlabel("Date") | |
plt.ylabel("Cumulative Return") | |
plt.legend() | |
plt.show() | |
plt.savefig(f"plot_{baseline_ticket}.png") | |
def plot_all(self): | |
baseline_label = "Equal-weight portfolio" | |
tic2label = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"} | |
# time lists | |
# algorithm time list | |
df_date_list = self.df_account_value.time.tolist() | |
# 399300 time list | |
csi300_df = self.get_baseline("399300") | |
csi300_date_list = csi300_df.time.dt.strftime("%Y-%m-%d").tolist() | |
# 000016 time list | |
sh50_df = self.get_baseline("000016") | |
sh50_date_list = sh50_df.time.dt.strftime("%Y-%m-%d").tolist() | |
# find intersection | |
all_date = sorted( | |
list(set(df_date_list) & set(csi300_date_list) & set(sh50_date_list)) | |
) | |
# filter data | |
csi300_df = csi300_df[csi300_df.time.isin(all_date)] | |
baseline_300 = csi300_df.close.tolist() | |
baseline_label_300 = tic2label["399300"] | |
sh50_df = sh50_df[sh50_df.time.isin(all_date)] | |
baseline_50 = sh50_df.close.tolist() | |
baseline_label_50 = tic2label["000016"] | |
# 均等权重 | |
baseline_equal_weight = [] | |
for day in all_date: | |
day_close = self.trade[self.trade["time"] == day].close.tolist() | |
avg_close = sum(day_close) / len(day_close) | |
baseline_equal_weight.append(avg_close) | |
df_account_value = self.df_account_value[ | |
self.df_account_value.time.isin(all_date) | |
] | |
ours = df_account_value.account_value.tolist() | |
ours = self.pct(ours) | |
baseline_300 = self.pct(baseline_300) | |
baseline_50 = self.pct(baseline_50) | |
baseline_equal_weight = self.pct(baseline_equal_weight) | |
days_per_tick = ( | |
60 # you should scale this variable accroding to the total trading days | |
) | |
time = list(range(len(ours))) | |
datetimes = self.df_account_value.time.tolist() | |
ticks = [tick for t, tick in zip(time, datetimes) if t % days_per_tick == 0] | |
plt.title("Cumulative Returns") | |
plt.plot(time, ours, label="DDPG Agent", color="darkorange") | |
plt.plot( | |
time, | |
baseline_equal_weight, | |
label=baseline_label, | |
color="cornflowerblue", | |
) # equal weight | |
plt.plot( | |
time, baseline_300, label=baseline_label_300, color="lightgreen" | |
) # 399300 | |
plt.plot(time, baseline_50, label=baseline_label_50, color="silver") # 000016 | |
plt.xlabel("Date") | |
plt.ylabel("Cumulative Return") | |
plt.xticks([i * days_per_tick for i in range(len(ticks))], ticks, fontsize=7) | |
plt.legend() | |
plt.show() | |
plt.savefig("./plot_all.png") | |
def pct(self, l): | |
"""Get percentage""" | |
base = l[0] | |
return [x / base for x in l] | |
def get_return(self, df, value_col_name="account_value"): | |
df = copy.deepcopy(df) | |
df["daily_return"] = df[value_col_name].pct_change(1) | |
df["time"] = pd.to_datetime(df["time"], format="%Y-%m-%d") | |
df.set_index("time", inplace=True, drop=True) | |
df.index = df.index.tz_localize("UTC") | |
return pd.Series(df["daily_return"], index=df.index) | |