from typing import List

import baostock as bs
import numpy as np
import pandas as pd
import pytz
import yfinance as yf

# Reference: https://github.com/AI4Finance-LLC/FinRL

try:
    import exchange_calendars as tc
except ImportError:
    # Fall back to the older trading_calendars package when
    # exchange_calendars is not installed.
    print(
        "Cannot import exchange_calendars.",
        "If you are using python>=3.7, please install it.",
    )
    import trading_calendars as tc

    print("Using trading_calendars instead.")
# from basic_processor import _Base
from meta.data_processors._base import _Base
from meta.data_processors._base import calc_time_zone

from meta.config import (
    TIME_ZONE_SHANGHAI,
    TIME_ZONE_USEASTERN,
    TIME_ZONE_PARIS,
    TIME_ZONE_BERLIN,
    TIME_ZONE_JAKARTA,
    TIME_ZONE_SELFDEFINED,
    USE_TIME_ZONE_SELFDEFINED,
    BINANCE_BASE_URL,
)


class Baostock(_Base):
    def __init__(
        self,
        data_source: str,
        start_date: str,
        end_date: str,
        time_interval: str,
        **kwargs,
    ):
        super().__init__(data_source, start_date, end_date, time_interval, **kwargs)

    # Daily, weekly, and monthly K-line data, as well as 5-, 15-, 30-, and 60-minute K-line data
    # ["5m", "15m", "30m", "60m", "1d", "1w", "1M"]
    def download_data(
        self, ticker_list: List[str], save_path: str = "./data/dataset.csv"
    ):
        lg = bs.login()
        print("baostock login response error_code:" + lg.error_code)
        print("baostock login response  error_msg:" + lg.error_msg)

        self.time_zone = calc_time_zone(
            ticker_list, TIME_ZONE_SELFDEFINED, USE_TIME_ZONE_SELFDEFINED
        )
        self.dataframe = pd.DataFrame()
        for ticker in ticker_list:
            nonstandard_ticker = self.transfer_standard_ticker_to_nonstandard(ticker)
            # All supported: "date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,isST"
            rs = bs.query_history_k_data_plus(
                nonstandard_ticker,
                "date,code,open,high,low,close,volume",
                start_date=self.start_date,
                end_date=self.end_date,
                frequency=self.time_interval,
                adjustflag="3",
            )

            print("baostock download_data response error_code:" + rs.error_code)
            print("baostock download_data response  error_msg:" + rs.error_msg)

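            # Collect rows one at a time. `and` short-circuits, so rs.next()
            # is not called when the query reported an error.
            data_list = []
            while rs.error_code == "0" and rs.next():
                data_list.append(rs.get_row_data())
            df = pd.DataFrame(data_list, columns=rs.fields)
            # Store the standard ticker rather than the Baostock-specific code.
            df["code"] = ticker
            self.dataframe = pd.concat([self.dataframe, df])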
        self.dataframe = self.dataframe.sort_values(by=["date", "code"]).reset_index(
            drop=True
        )
        bs.logout()

        # Baostock returns every field as a string; convert prices to float.
        price_columns = ["open", "high", "low", "close"]
        self.dataframe[price_columns] = self.dataframe[price_columns].astype(float)
        self.save_data(save_path)

        print(
            f"Download complete! Dataset saved to {save_path}. \nShape of DataFrame: {self.dataframe.shape}"
        )

    def get_trading_days(self, start, end):
        lg = bs.login()
        print("baostock login response error_code:" + lg.error_code)
        print("baostock login response  error_msg:" + lg.error_msg)
        result = bs.query_trade_dates(start_date=start, end_date=end)
        bs.logout()
        return result

    # "600000.XSHG" -> "sh.600000"
    # "000612.XSHE" -> "sz.000612"
    def transfer_standard_ticker_to_nonstandard(self, ticker: str) -> str:
        n, alpha = ticker.split(".")
        assert alpha in ["XSHG", "XSHE"], f"Unsupported exchange suffix: {alpha}"
        if alpha == "XSHG":
            nonstandard_ticker = "sh." + n
        elif alpha == "XSHE":
            nonstandard_ticker = "sz." + n
        return nonstandard_ticker
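

# A minimal sketch of the reverse mapping, for illustration only.
# The helper name `nonstandard_to_standard_ticker` is hypothetical and is not
# part of the original module. It assumes Baostock codes take the form
# "sh.600000" or "sz.000612", mirroring the two cases handled above.
# "sh.600000" -> "600000.XSHG"
# "sz.000612" -> "000612.XSHE"
def nonstandard_to_standard_ticker(code: str) -> str:
    # Split the exchange prefix from the numeric symbol at the first dot.
    prefix, _, symbol = code.partition(".")
    suffix_by_prefix = {"sh": "XSHG", "sz": "XSHE"}
    if prefix not in suffix_by_prefix or not symbol:
        raise ValueError(f"Unsupported Baostock code: {code!r}")
    return f"{symbol}.{suffix_by_prefix[prefix]}"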