File size: 2,779 Bytes
8facf52
 
b0da1ed
 
 
8facf52
b5622c0
b0da1ed
b5622c0
df0a4e7
b5622c0
 
 
b0da1ed
 
 
 
 
 
 
 
 
 
 
 
8facf52
b0da1ed
8facf52
df0a4e7
b5622c0
b6884a4
8facf52
 
 
 
7164c21
8facf52
 
 
 
 
 
 
 
b5622c0
8facf52
 
 
 
 
 
b5622c0
8facf52
 
 
b5622c0
 
 
 
 
 
 
 
8facf52
b5622c0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from datasets import load_dataset
import pandas as pd
from huggingface_hub import HfApi

# Hugging Face dataset repository that load_calendar, read_last_update_date and
# load_stock_data all read from.
huggingface_repo_name: str = "ellendan/a-share-prices"

def load_calendar():
    """Load ``calendar.csv`` from the dataset repo as a DatetimeIndex.

    Returns:
        The ``cal_date`` column parsed with ``pd.to_datetime``, in the same
        order as the rows of the CSV (no sorting is applied here).
    """
    train_split = load_dataset(huggingface_repo_name, data_files="calendar.csv")['train']
    cal_dates = pd.to_datetime(train_split.to_pandas()['cal_date'])
    return pd.DatetimeIndex(cal_dates)

# Trading calendar, built once at import time by load_calendar();
# remove_trade_data_overdate counts its day window against this index.
# NOTE(review): presumably this fetches calendar.csv from the Hugging Face Hub
# on first import — confirm that import-time loading is intended.
calendar_list_index: pd.DatetimeIndex = load_calendar()

def read_last_update_date():
    """Return the stripped contents of ``.last_update_date`` from the dataset repo.

    ``hf_hub_download`` is used to fetch the file, and the local path it
    returns is then opened and read as text.
    """
    local_path = HfApi().hf_hub_download(
        repo_id=huggingface_repo_name,
        filename='.last_update_date',
        repo_type="dataset",
    )
    with open(local_path, "r") as handle:
        return handle.read().strip()

def load_stock_data():
    """Load ``all-prices.csv`` as a DataFrame indexed by ``(code, date)``.

    The ``date`` column is parsed to timestamps before being moved into the
    index, and the resulting MultiIndex is sorted.
    """
    train_split = load_dataset(
        huggingface_repo_name,
        data_files="all-prices.csv",
        download_mode="reuse_dataset_if_exists",
    )['train']
    prices = train_split.to_pandas()
    prices['date'] = pd.to_datetime(prices['date'])
    return prices.set_index(['code', 'date']).sort_index()

def prepare_features(source_data_frame, last_n_days=400):
    source_data_frame = source_data_frame.iloc[-last_n_days:]
    is_high_limit = (source_data_frame['close'] == source_data_frame['high_limit']) * 1 * (source_data_frame['quote_rate'] > 6)
    if is_high_limit[is_high_limit > 0].empty:
        return pd.DataFrame()
    source_data_frame['is_high_limit'] = is_high_limit
    is_segment_start = (is_high_limit.diff() != 0) * 1
    source_data_frame['segment_start'] = is_segment_start
    source_data_frame['segment_index'] = is_segment_start.cumsum()
    serie_high_limit = source_data_frame.groupby(by='segment_index').apply(
        lambda x: pd.DataFrame({
            'date': x.index.get_level_values("date"), 
            'high_limit_days': x['is_high_limit'].cumsum()
            }),
            include_groups=False
        )
    serie_high_limit = serie_high_limit.groupby(by='segment_index').agg(
        max_high_limit_days=('high_limit_days', 'max'),
        date=("high_limit_days", lambda x: x.idxmax()[2])
    )
    return serie_high_limit

# Drop rows that fall outside the requested trading-day window, e.g. rows left
# behind by stocks that stopped trading (delisting and similar causes).
def remove_trade_data_overdate(data_frame, last_n_days=400, end_date='2024-12-31', calendar=None):
    """Keep only rows whose ``date`` is on or after the start of the day window.

    Args:
        data_frame: frame with a ``date`` column of timestamps.
        last_n_days: how many calendar entries to step back from ``end_date``
            to find the first day kept.
        end_date: anchor day of the window. It must be present in the calendar;
            ``get_loc`` raises ``KeyError`` otherwise. The default is the value
            previously hard-coded here.
        calendar: dates to count against. Defaults to the module-level
            ``calendar_list_index``. It is de-duplicated and sorted ascending
            before use.

    Returns:
        The rows of ``data_frame`` with ``date >= start_day``, or ``data_frame``
        itself when it is empty.
    """
    if calendar is None:
        calendar = calendar_list_index
    # The positional arithmetic below needs a unique, ascending calendar, and
    # load_calendar does not sort.
    trade_days = pd.DatetimeIndex(calendar).drop_duplicates().sort_values()
    end_index = trade_days.get_loc(pd.Timestamp(end_date))
    # Clamp the start index at 0. Without this, a window longer than the calendar
    # would wrap around to its end through a negative index.
    start_index = max(end_index - last_n_days, 0)
    # NOTE(review): this keeps last_n_days + 1 calendar entries (start_index through
    # end_index inclusive); confirm whether the extra day is intended.
    start_day = trade_days[start_index]
    if data_frame.empty:
        # An all-empty groupby result may not even have a ``date`` column.
        return data_frame
    return data_frame[data_frame['date'] >= start_day]

def serie_high_limit(source_data_frame, last_n_days=400):
    """Compute limit-up streak summaries for every stock, then trim stale rows.

    ``source_data_frame`` is grouped by its first index level (the stock code,
    for frames built by ``load_stock_data``). Each group is reduced with
    ``prepare_features``. Rows dated before the day window are then dropped by
    ``remove_trade_data_overdate``.
    """
    per_code = source_data_frame.groupby(level=0).apply(prepare_features, last_n_days)
    return remove_trade_data_overdate(per_code, last_n_days)

if __name__ == "__main__":
    # Smoke run: load every price row and print streak summaries over a 5-row window.
    prices = load_stock_data()
    print(serie_high_limit(prices, 5))