kristada673's picture
Upload 19 files
de6e775
import copy
import os
import time
import warnings
warnings.filterwarnings("ignore")
from typing import List
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt
import stockstats
import talib
from meta.data_processors._base import _Base
import tushare as ts
class Tushare(_Base):
    """
    Data processor that downloads daily bars through the Tushare Pro API.

    Key-value pairs read from ``kwargs``
    ----------
    token : str
        API token, obtained from https://waditu.com/ after registration. Required.
    adj : str
        Whether to use adjusted closing prices. Default is None (unadjusted).
        Use 'qfq' for forward-adjusted prices.
        Use 'hfq' for backward-adjusted prices.
    """

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)
        assert "token" in kwargs, "Please input token!"
        self.token = kwargs["token"]
        # Adjustment mode forwarded to ts.pro_bar; None means unadjusted prices.
        self.adj = kwargs.get("adj")
        if "adj" in kwargs:
            print(f"Using {self.adj} method.")

    def get_data(self, id) -> pd.DataFrame:
        """Fetch daily bars for one Tushare code (e.g. ``"600000.SH"``).

        Returns the result of ``ts.pro_bar`` for the configured date range and
        adjustment mode. NOTE(review): tushare may return None on a failed
        request; ``download_data`` skips such results.
        """
        # ``id`` shadows the builtin but is kept so keyword callers keep working.
        return ts.pro_bar(
            ts_code=id,
            start_date=self.start_date,
            end_date=self.end_date,
            adj=self.adj,
        )

    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        """Download daily bars for every ticker and save them as CSV.

        The resulting ``self.dataframe`` has 8 columns: tic, time (``YYYY-MM-DD``
        string), open, high, low, close, volume and day (day of week, Monday=0).
        Rows with missing values are dropped; rows are sorted by time, then tic.

        :param ticker_list: Tushare codes such as ``"600000.SH"``
        :param save_path: CSV file path (or directory) passed to ``save_data``
        :raises ValueError: if no data could be downloaded for any ticker
        """
        assert self.time_interval == "1d", "Not supported currently"
        self.ticker_list = ticker_list
        ts.set_token(self.token)

        frames = []
        for tic in tqdm(ticker_list, total=len(ticker_list)):
            frame = self.get_data(tic)
            if frame is not None:
                frames.append(frame)
            # Pause between requests, presumably to respect the API rate limit.
            time.sleep(0.25)
        if not frames:
            raise ValueError("No data downloaded for the given tickers.")

        # One concat instead of repeated DataFrame.append (removed in pandas 2.0).
        df = pd.concat(frames, ignore_index=True)
        # Positional rename: assumes pro_bar's column order
        # (ts_code, trade_date, open, high, low, close, pre_close, change,
        #  pct_chg, vol, amount).
        df.columns = [
            "tic",
            "time",
            "open",
            "high",
            "low",
            "close",
            "pre_close",
            "change",
            "pct_chg",
            "volume",
            "amount",
        ]
        # Explicit copy so the column assignments below act on an owned frame.
        df = df[["tic", "time", "open", "high", "low", "close", "volume"]].copy()
        dates = pd.to_datetime(df["time"], format="%Y%m%d")
        df["day"] = dates.dt.dayofweek
        df["time"] = dates.dt.strftime("%Y-%m-%d")
        df.dropna(inplace=True)
        df.sort_values(by=["time", "tic"], inplace=True)
        df.reset_index(drop=True, inplace=True)
        self.dataframe = df

        self.save_data(save_path)
        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def data_split(self, df, start, end, target_date_col="time"):
        """
        Select the rows of ``df`` whose ``target_date_col`` lies in ``[start, end)``.

        :param df: (pd.DataFrame) data containing ``target_date_col`` and ``tic``
        :param start: inclusive lower bound on ``target_date_col``
        :param end: exclusive upper bound on ``target_date_col``
        :param target_date_col: (str) name of the date column, default ``"time"``
        :return: (pd.DataFrame) rows sorted by date then ticker; the index is the
            0-based date ordinal, so all tickers on the same day share an index
        """
        data = df[(df[target_date_col] >= start) & (df[target_date_col] < end)]
        data = data.sort_values([target_date_col, "tic"], ignore_index=True)
        data.index = data[target_date_col].factorize()[0]
        return data

    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        """Convert an exchange-suffixed ticker to Tushare's suffix convention.

        "600000.XSHG" -> "600000.SH"
        "000612.XSHE" -> "000612.SZ"
        """
        code, exchange = ticker.split(".")
        suffixes = {"XSHG": ".SH", "XSHE": ".SZ"}
        assert exchange in suffixes, "Wrong alpha"
        return code + suffixes[exchange]

    def save_data(self, path):
        """Write ``self.dataframe`` to CSV, creating the directory if needed.

        If ``path`` contains ".csv" it is treated as a file path; otherwise it is
        treated as a directory and the file is named ``dataset.csv``.
        """
        if ".csv" in path:
            directory, filename = os.path.split(path)
        else:
            directory, filename = path, "dataset.csv"
        # A bare file name has no directory component; os.makedirs("") would fail.
        if directory:
            os.makedirs(directory, exist_ok=True)
        self.dataframe.to_csv(os.path.join(directory, filename), index=False)

    def load_data(self, path):
        """Load a previously saved CSV into ``self.dataframe``.

        Only CSV files are supported. The file must have "tic", "time" and
        "close" columns.
        """
        assert ".csv" in path  # only support csv format now
        self.dataframe = pd.read_csv(path)
        columns = self.dataframe.columns
        assert (
            "tic" in columns and "time" in columns and "close" in columns
        )  # input file must have "tic", "time" and "close" columns
class ReturnPlotter:
    """
    An easy-to-use plotting tool for cumulative returns over time.

    The baseline is either an equal-weight portfolio (default) or any index or
    stock fetched through ``ts.get_hist_data``.

    :param df_account_value: frame with ``time`` and ``account_value`` columns
    :param df_trade: frame with ``time`` and ``close`` columns (one row per
        ticker per day), used for the equal-weight baseline
    :param start_date: start of the range used when fetching baselines
    :param end_date: end of the range used when fetching baselines
    """

    # Display names for the index baselines used by plot/plot_all.
    TIC2LABEL = {"399300": "CSI 300 Index", "000016": "SSE 50 Index"}

    def __init__(self, df_account_value, df_trade, start_date, end_date):
        self.start = start_date
        self.end = end_date
        self.trade = df_trade
        self.df_account_value = df_account_value

    def get_baseline(self, ticket):
        """Fetch daily history for ``ticket`` via ``ts.get_hist_data``.

        Rows are sorted by date. The original date index is stored in ``dt``, and
        ``time`` holds it parsed as datetime.

        :raises ValueError: if tushare returns no data
        """
        df = ts.get_hist_data(ticket, start=self.start, end=self.end)
        if df is None:
            raise ValueError(f"No history returned by tushare for {ticket!r}.")
        df.loc[:, "dt"] = df.index
        df.index = range(len(df))
        df.sort_values(axis=0, by="dt", ascending=True, inplace=True)
        df["time"] = pd.to_datetime(df["dt"], format="%Y-%m-%d")
        return df

    def _account_dates(self):
        """Account-value dates as ``YYYY-MM-DD`` strings, aligned row-by-row."""
        return pd.to_datetime(self.df_account_value["time"]).dt.strftime("%Y-%m-%d")

    def _equal_weight_baseline(self, dates=None):
        """Mean close across tickers for each trading day, as a list.

        Without ``dates``, days follow their first appearance in ``self.trade``.
        With ``dates``, values follow that order; days with no trades give NaN.
        NOTE(review): assumes ``self.trade['time']`` uses the same representation
        as ``dates`` (``YYYY-MM-DD`` strings).
        """
        daily_mean = self.trade.groupby("time", sort=False)["close"].mean()
        if dates is not None:
            daily_mean = daily_mean.reindex(list(dates))
        return daily_mean.tolist()

    @staticmethod
    def _draw(dates, curves, days_per_tick, filename):
        """Plot (values, label, color) curves against the trading-day index and save.

        ``dates`` must be aligned with the curve values; every ``days_per_tick``-th
        date is used as an x-tick label.
        """
        if days_per_tick < 1:
            raise ValueError("days_per_tick must be a positive integer")
        steps = list(range(len(curves[0][0])))
        plt.title("Cumulative Returns")
        for values, label, color in curves:
            plt.plot(steps, values, label=label, color=color)
        tick_positions = list(range(0, min(len(steps), len(dates)), days_per_tick))
        plt.xticks(tick_positions, [dates[i] for i in tick_positions], fontsize=7)
        plt.xlabel("Date")
        plt.ylabel("Cumulative Return")
        plt.legend()
        # Save before show(): show() can close the figure, leaving savefig blank.
        plt.savefig(filename)
        plt.show()

    def plot(self, baseline_ticket=None, agent_label="DDPG Agent", days_per_tick=60):
        """
        Plot cumulative returns over time.

        :param baseline_ticket: stock/index code to compare against
            (default: equal-weighted returns of ``df_trade``)
        :param agent_label: legend label for the account-value curve
        :param days_per_tick: trading days between x-tick labels; scale it with
            the total number of trading days
        """
        baseline_label = "Equal-weight portfolio"
        if baseline_ticket:
            # Use the given ticket as the baseline, restricted to common dates.
            baseline_df = self.get_baseline(baseline_ticket)
            baseline_dates = baseline_df["time"].dt.strftime("%Y-%m-%d")
            account_dates = self._account_dates()
            common = list(set(baseline_dates) & set(account_dates))
            account = self.df_account_value[account_dates.isin(common)]
            baseline = baseline_df.loc[baseline_dates.isin(common), "close"].tolist()
            baseline_label = self.TIC2LABEL.get(baseline_ticket, baseline_ticket)
        else:
            # Equal weighting across all traded tickers.
            account = self.df_account_value
            baseline = self._equal_weight_baseline()
        ours = self.pct(account["account_value"].tolist())
        baseline = self.pct(baseline)
        self._draw(
            account["time"].tolist(),  # labels aligned with the plotted points
            [(ours, agent_label, "green"), (baseline, baseline_label, "grey")],
            days_per_tick,
            f"plot_{baseline_ticket}.png",
        )

    def plot_all(self, agent_label="DDPG Agent", days_per_tick=60):
        """Plot the agent against equal-weight, CSI 300 and SSE 50 baselines.

        Only dates present in the account values and in both index histories
        are plotted.
        """
        account_dates = self._account_dates()
        csi300_df = self.get_baseline("399300")
        csi300_dates = csi300_df["time"].dt.strftime("%Y-%m-%d")
        sh50_df = self.get_baseline("000016")
        sh50_dates = sh50_df["time"].dt.strftime("%Y-%m-%d")
        all_date = sorted(set(account_dates) & set(csi300_dates) & set(sh50_dates))

        baseline_300 = csi300_df.loc[csi300_dates.isin(all_date), "close"].tolist()
        baseline_50 = sh50_df.loc[sh50_dates.isin(all_date), "close"].tolist()
        baseline_equal_weight = self._equal_weight_baseline(all_date)
        account = self.df_account_value[account_dates.isin(all_date)]

        self._draw(
            account["time"].tolist(),  # labels aligned with the filtered curves
            [
                (self.pct(account["account_value"].tolist()), agent_label, "darkorange"),
                (
                    self.pct(baseline_equal_weight),
                    "Equal-weight portfolio",
                    "cornflowerblue",
                ),
                (self.pct(baseline_300), self.TIC2LABEL["399300"], "lightgreen"),
                (self.pct(baseline_50), self.TIC2LABEL["000016"], "silver"),
            ],
            days_per_tick,
            "./plot_all.png",
        )

    def pct(self, l):
        """Return ``l`` divided by its first element (a cumulative-return curve).

        An empty input yields an empty list.
        """
        if len(l) == 0:
            return []
        base = l[0]
        return [x / base for x in l]

    def get_return(self, df, value_col_name="account_value"):
        """Daily percentage returns of ``value_col_name``, indexed by UTC dates.

        ``df`` must have a ``time`` column formatted ``YYYY-MM-DD``; it is not
        modified. The first return is NaN.
        """
        df = df.copy()
        df["daily_return"] = df[value_col_name].pct_change(1)
        df["time"] = pd.to_datetime(df["time"], format="%Y-%m-%d")
        df.set_index("time", inplace=True, drop=True)
        df.index = df.index.tz_localize("UTC")
        return df["daily_return"]