# Uploaded by: ellendan
# Commit message: "(fix) fix"
# Commit: 7164c21
from datasets import load_dataset
import pandas as pd
from huggingface_hub import HfApi
# Hugging Face dataset repository read by this module: it provides
# calendar.csv, all-prices.csv and .last_update_date.
huggingface_repo_name = "ellendan/a-share-prices"
def load_calendar():
    """Load the trading calendar from ``calendar.csv`` in the dataset repo.

    Returns:
        A ``pd.DatetimeIndex`` built from the ``cal_date`` column, in file order.
        NOTE(review): ``remove_trade_data_overdate`` indexes this by position,
        so it presumably must be sorted ascending and unique -- confirm
        against the CSV.
    """
    calendar_frame = load_dataset(
        huggingface_repo_name, data_files="calendar.csv"
    )['train'].to_pandas()
    trade_dates = pd.to_datetime(calendar_frame['cal_date'])
    return pd.DatetimeIndex(trade_dates)
# Trading calendar, loaded once when this module is imported (so importing the
# module runs load_dataset). remove_trade_data_overdate looks up positions in it.
calendar_list_index = load_calendar()
def read_last_update_date():
    """Return the stripped contents of ``.last_update_date`` from the dataset repo.

    ``hf_hub_download`` fetches the file and returns a local filesystem path,
    which is then read as text. The content is presumably the last trade
    date the dataset was refreshed for (not used elsewhere in this file).

    Returns:
        The file's text with surrounding whitespace removed.
    """
    hf_api = HfApi()
    local_path = hf_api.hf_hub_download(
        repo_id=huggingface_repo_name,
        filename='.last_update_date',
        repo_type="dataset",
    )
    # Explicit encoding: the platform default (e.g. a GBK code page on Chinese
    # Windows) is not guaranteed to be UTF-8.
    with open(local_path, "r", encoding="utf-8") as f:
        return f.read().strip()
def load_stock_data():
    """Load every price row from ``all-prices.csv`` in the dataset repo.

    Returns:
        A ``DataFrame`` indexed by ``(code, date)`` (``date`` parsed to
        datetime), sorted by that index, with the remaining CSV columns.
    """
    prices = load_dataset(
        huggingface_repo_name,
        data_files="all-prices.csv",
        download_mode="reuse_dataset_if_exists",
    )['train'].to_pandas()
    prices['date'] = pd.to_datetime(prices['date'])
    # Sorting keeps each stock's rows contiguous and in date order.
    return prices.set_index(['code', 'date']).sort_index()
def prepare_features(source_data_frame, last_n_days=400):
    """Summarise consecutive limit-up ("high limit") streaks for one stock.

    Only the last ``last_n_days`` rows are considered. A day counts as a
    limit-up day when ``close == high_limit`` and ``quote_rate > 6``. The
    window is split into runs of identical limit-up status ("segments").
    For each segment the function reports the length of its limit-up streak
    and the date the streak reached that length.

    Args:
        source_data_frame: rows for one stock with columns ``close``,
            ``high_limit`` and ``quote_rate`` and an index that has a level
            named ``date`` (e.g. the ``(code, date)`` index sorted by
            ``load_stock_data``); rows are assumed to be in date order.
            The frame is not modified.
        last_n_days: number of trailing rows to analyse.

    Returns:
        An empty ``DataFrame`` if no limit-up day occurs in the window.
        Otherwise a ``DataFrame`` indexed by ``segment_index`` (1, 2, ...) with:

        * ``max_high_limit_days``: consecutive limit-up days in the segment
          (0 for a segment without limit-up days);
        * ``date``: the first day the count reached that maximum (the
          segment's first day when the count is 0).
    """
    # tail() rather than iloc[-n:]: iloc[-0:] would return the whole frame.
    window = source_data_frame.tail(last_n_days)
    is_high_limit = (
        (window['close'] == window['high_limit']) & (window['quote_rate'] > 6)
    ).astype('int64')
    if not is_high_limit.any():
        return pd.DataFrame()

    # A new segment starts whenever the 0/1 status changes. The first row's
    # diff is NaN (!= 0), so it always opens segment 1.
    segment_index = (is_high_limit.diff() != 0).astype('int64').cumsum().to_numpy()
    # Running count of limit-up days inside each segment.
    streak = is_high_limit.groupby(segment_index).cumsum()

    # Key the running count by date so idxmax() yields the date directly,
    # instead of picking it out of a MultiIndex tuple by position.
    streak_by_date = pd.Series(
        streak.to_numpy(), index=window.index.get_level_values('date')
    )
    per_segment = streak_by_date.groupby(segment_index)
    summary = pd.DataFrame({
        'max_high_limit_days': per_segment.max(),
        'date': per_segment.idxmax(),
    })
    summary.index.name = 'segment_index'
    return summary
# Drop data left over from delisted (or otherwise halted) stocks: rows dated before the analysis time window.
def remove_trade_data_overdate(data_frame, last_n_days=400, end_date='2024-12-31', trade_calendar=None):
    """Keep only rows whose ``date`` lies within the trailing trading-day window.

    The window starts ``last_n_days`` trading days before ``end_date``. Rows
    dated earlier are dropped. Such rows come from stocks that stopped
    trading (e.g. were delisted) before ``end_date``, so their "last N rows"
    reach back past the window.

    Args:
        data_frame: frame with a ``date`` column.
        last_n_days: window length in trading days, counted back from ``end_date``.
        end_date: anchor trading day of the window; must exist in the calendar.
            Defaults to the previously hard-coded ``'2024-12-31'``.
        trade_calendar: trading-day ``DatetimeIndex``, assumed sorted ascending
            and unique. Defaults to the module-level ``calendar_list_index``.

    Returns:
        The rows of ``data_frame`` with ``date >= start_day``, or ``data_frame``
        itself when it is empty.

    Raises:
        KeyError: if ``end_date`` is not in the calendar.
    """
    if trade_calendar is None:
        trade_calendar = calendar_list_index
    end_index = trade_calendar.get_loc(end_date)
    # Clamp at 0: a negative position would silently wrap around to the END
    # of the calendar and pick a start day after end_date.
    start_index = max(end_index - last_n_days, 0)
    start_day = trade_calendar[start_index]
    if data_frame.empty:
        # Nothing to filter; an empty frame may not even have a 'date' column.
        return data_frame
    return data_frame[data_frame['date'] >= start_day]
def serie_high_limit(source_data_frame, last_n_days=400):
    """Compute limit-up streak summaries per stock and drop out-of-window rows.

    Args:
        source_data_frame: price frame whose first index level identifies the
            stock (``code`` in ``load_stock_data``), with the columns
            ``prepare_features`` needs.
        last_n_days: trailing window length, passed to both
            ``prepare_features`` and ``remove_trade_data_overdate``.

    Returns:
        The concatenated ``prepare_features`` results, keyed by the first index
        level (groupby-apply prepends the group key), filtered by
        ``remove_trade_data_overdate``. Returns an empty frame unchanged when
        no stock has a limit-up day.
    """
    data_frame = source_data_frame.groupby(level=0).apply(
        lambda stock_rows: prepare_features(stock_rows, last_n_days)
    )
    if data_frame.empty:
        # Every stock returned an empty frame, so there is no 'date' column to
        # filter on; the date filter would raise KeyError.
        return data_frame
    return remove_trade_data_overdate(data_frame, last_n_days)
if __name__ == "__main__":
    # Smoke run: load all price rows, then print streak summaries over the
    # last 5 trading days.
    all_prices = load_stock_data()
    print(serie_high_limit(all_prices, last_n_days=5))