# kristada673's picture
# Upload 19 files
# de6e775
from typing import List
import baostock as bs
import numpy as np
import pandas as pd
import pytz
import yfinance as yf
"""Reference: https://github.com/AI4Finance-LLC/FinRL"""
# Bind a calendar package to the name ``tc``: try exchange_calendars first and
# fall back to trading_calendars if that import fails.
try:
    import exchange_calendars as tc
except Exception:
    # Deliberately broad so the fallback still triggers on any ordinary
    # import-time failure, but no longer a bare ``except:`` which would also
    # swallow KeyboardInterrupt and SystemExit.
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Use trading_calendars instead for yahoofinance processor..")
# from basic_processor import _Base
from meta.data_processors._base import _Base
from meta.data_processors._base import calc_time_zone
from meta.config import (
TIME_ZONE_SHANGHAI,
TIME_ZONE_USEASTERN,
TIME_ZONE_PARIS,
TIME_ZONE_BERLIN,
TIME_ZONE_JAKARTA,
TIME_ZONE_SELFDEFINED,
USE_TIME_ZONE_SELFDEFINED,
BINANCE_BASE_URL,
)
class Baostock(_Base):
    """Data processor that downloads K-line data through the baostock API.

    Tickers are accepted in the ``<number>.XSHG`` / ``<number>.XSHE`` form and
    converted to baostock's ``sh.<number>`` / ``sz.<number>`` codes before
    each query.
    """

    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        """Forward every argument unchanged to ``_Base.__init__``."""
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    # Daily, weekly and monthly K-lines, as well as 5-, 15-, 30- and 60-minute
    # K-line data
    # ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        """Download K-line rows for every ticker and store them in ``self.dataframe``.

        The per-ticker results are stacked, sorted by ``date`` then ``code``,
        re-indexed from zero, and the ``open``/``high``/``low``/``close``
        columns are cast to float before ``self.save_data(save_path)`` is
        called.

        Args:
            ticker_list: Tickers in ``<number>.XSHG`` / ``<number>.XSHE`` form.
            save_path: Path handed to ``self.save_data``.

        Raises:
            AssertionError: If a ticker's exchange suffix is not supported.
            ValueError: If a ticker does not contain exactly one ``"."``.
        """
        lg = bs.login()
        print("baostock login respond error_code:" + lg.error_code)
        print("baostock login respond error_msg:" + lg.error_msg)

        self.dataframe = pd.DataFrame()
        frames = []
        try:
            self.time_zone = calc_time_zone(
                ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
            )
            for ticker in ticker_list:
                nonstandard_ticker = self.transfer_standard_ticker_to_nonstandard(
                    ticker
                )
                # All supported: "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST"
                # NOTE(review): self.time_interval is passed straight through as
                # ``frequency`` -- confirm _Base stores it in the form baostock
                # expects.
                rs = bs.query_history_k_data_plus(
                    nonstandard_ticker,
                    "date,code,open,high,low,close,volume",
                    start_date=self.start_date,
                    end_date=self.end_date,
                    frequency=self.time_interval,
                    adjustflag="3",  # "3" = no price adjustment (baostock docs)
                )
                print("baostock download_data respond error_code:" + rs.error_code)
                print("baostock download_data respond error_msg:" + rs.error_msg)

                rows = []
                # ``and`` short-circuits, so rs.next() is never called once the
                # query has reported an error.
                while rs.error_code == "0" and rs.next():
                    rows.append(rs.get_row_data())

                df = pd.DataFrame(rows, columns=rs.fields)
                # Report the caller's ticker spelling rather than baostock's.
                df["code"] = ticker
                frames.append(df)
        finally:
            # Always release the baostock session, even if a query or a ticker
            # conversion raised.
            bs.logout()

        # Concatenate once instead of inside the loop, which re-copied the
        # accumulated frame on every iteration.
        if frames:
            self.dataframe = pd.concat(frames)
        self.dataframe = self.dataframe.sort_values(by=["date", "code"]).reset_index(
            drop=True
        )

        # The rows arrive as strings; make the price columns numeric.
        # NOTE(review): ``volume`` is left uncast, as in the original -- confirm
        # downstream code does not expect it to be numeric.
        for column in ("open", "high", "low", "close"):
            self.dataframe[column] = self.dataframe[column].astype(float)

        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def get_trading_days(self, start, end):
        """Query baostock for trade dates between ``start`` and ``end``.

        Logs in, runs ``bs.query_trade_dates`` and logs out again.

        Returns:
            The object returned by ``bs.query_trade_dates``, unmodified.
        """
        lg = bs.login()
        print("baostock login respond error_code:" + lg.error_code)
        print("baostock login respond error_msg:" + lg.error_msg)
        try:
            return bs.query_trade_dates(start_date=start, end_date=end)
        finally:
            # Log out even when the query raises.
            bs.logout()

    # "600000.XSHG" -> "sh.600000"
    # "000612.XSHE" -> "sz.000612"
    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        """Convert ``<number>.<exchange>`` into baostock's ``<prefix>.<number>``.

        Args:
            ticker: Ticker whose suffix is ``XSHG`` or ``XSHE``.

        Returns:
            ``"sh." + number`` for ``XSHG`` and ``"sz." + number`` for ``XSHE``.

        Raises:
            AssertionError: If the suffix is neither ``XSHG`` nor ``XSHE``.
            ValueError: If ``ticker`` does not contain exactly one ``"."``.
        """
        n, alpha = ticker.split(".")
        prefix_by_exchange = {"XSHG": "sh.", "XSHE": "sz."}
        if alpha not in prefix_by_exchange:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; the exception type and message are unchanged.
            raise AssertionError("Wrong alpha")
        return prefix_by_exchange[alpha] + n