"""Find runs of consecutive limit-up ("high limit") days for A-share stocks.

Price and trading-calendar data are loaded from the Hugging Face dataset
repository named in ``huggingface_repo_name``.
"""
from datasets import load_dataset
from huggingface_hub import HfApi
import pandas as pd

huggingface_repo_name = "ellendan/a-share-prices"

# Default upper bound of the analysis window. It must be a date present in the
# trading calendar, otherwise ``DatetimeIndex.get_loc`` raises KeyError.
DEFAULT_END_DATE = "2024-12-31"


def load_calendar() -> pd.DatetimeIndex:
    """Load the trading calendar (``calendar.csv``, column ``cal_date``)."""
    dataset = load_dataset(huggingface_repo_name, data_files="calendar.csv")
    data_frame = dataset['train'].to_pandas()
    return pd.DatetimeIndex(pd.to_datetime(data_frame['cal_date']))


# Loaded once at import time; this performs a dataset download/cache lookup.
calendar_list_index = load_calendar()


def read_last_update_date() -> str:
    """Return the stripped contents of the repo's ``.last_update_date`` file."""
    hf_api = HfApi()
    # hf_hub_download returns the local path of the downloaded (or cached) file.
    local_path = hf_api.hf_hub_download(
        repo_id=huggingface_repo_name,
        filename='.last_update_date',
        repo_type="dataset",
    )
    with open(local_path, "r", encoding="utf-8") as f:
        return f.read().strip()


def load_stock_data() -> pd.DataFrame:
    """Load all daily price rows, indexed by (code, date) and sorted."""
    dataset = load_dataset(
        huggingface_repo_name,
        data_files="all-prices.csv",
        download_mode="reuse_dataset_if_exists",
    )
    data_frame = dataset['train'].to_pandas()
    data_frame['date'] = pd.to_datetime(data_frame['date'])
    return data_frame.set_index(['code', 'date']).sort_index()


def prepare_features(source_data_frame: pd.DataFrame, last_n_days: int = 400) -> pd.DataFrame:
    """Summarise runs of limit-up days for a single stock.

    A day counts as limit-up when ``close`` equals ``high_limit`` and
    ``quote_rate`` is greater than 6. Consecutive days sharing the same
    limit-up flag form one segment.

    Args:
        source_data_frame: price rows of one stock, indexed by (code, date)
            with ``close``, ``high_limit`` and ``quote_rate`` columns.
            Presumably sorted by date (``load_stock_data`` sorts the index).
        last_n_days: only the last ``last_n_days`` rows are examined.

    Returns:
        A frame indexed by ``segment_index`` with columns
        ``max_high_limit_days`` (length of the limit-up run; 0 for segments
        without limit-up days) and ``date`` (first day the run reached that
        length). An empty frame if no limit-up day occurs in the window.
    """
    # Work on derived Series only, so the caller's frame (or a slice of it)
    # is never written to.
    window = source_data_frame.iloc[-last_n_days:]

    is_high_limit = (
        (window['close'] == window['high_limit']) * 1 * (window['quote_rate'] > 6)
    )
    if not is_high_limit.any():
        return pd.DataFrame()

    # A new segment starts whenever the flag changes. The first row's diff is
    # NaN, which compares != 0, so it always opens segment 1.
    segment_index = ((is_high_limit.diff() != 0) * 1).cumsum().rename('segment_index')

    # Running count of limit-up days within each segment (stays 0 in
    # segments that contain no limit-up day).
    high_limit_days = is_high_limit.groupby(segment_index).cumsum()

    # idxmax returns the (code, date) label of the first row holding the
    # maximum; the date is the innermost index level.
    return high_limit_days.groupby(segment_index).agg(
        max_high_limit_days='max',
        date=lambda days: days.idxmax()[-1],
    )


def remove_trade_data_overdate(data_frame: pd.DataFrame, last_n_days: int = 400,
                               end_date=DEFAULT_END_DATE) -> pd.DataFrame:
    """Drop rows whose ``date`` falls before the analysis window.

    Stocks that stopped trading (e.g. after delisting) still contribute their
    own last ``last_n_days`` rows, which can lie well before the window; this
    filter removes them. The window starts ``last_n_days`` trading days
    before ``end_date``. Only the lower bound is enforced.

    Raises:
        KeyError: if ``end_date`` is not in the trading calendar.
    """
    if 'date' not in data_frame.columns:
        # No stock hit the limit, so the combined frame has nothing to filter.
        return data_frame
    # Assumes calendar dates are unique, so get_loc returns an integer position.
    end_index = calendar_list_index.get_loc(end_date)
    # Clamp: a negative position would wrap around to the end of the calendar
    # and produce a start day after end_date.
    start_index = max(end_index - last_n_days, 0)
    start_day = calendar_list_index[start_index]
    return data_frame[data_frame['date'] >= start_day]


def serie_high_limit(source_data_frame: pd.DataFrame, last_n_days: int = 400,
                     end_date=DEFAULT_END_DATE) -> pd.DataFrame:
    """Compute per-stock limit-up runs within the window ending at ``end_date``.

    Args:
        source_data_frame: all price rows, indexed by (code, date).
        last_n_days: window length in rows/trading days.
        end_date: last trading day of the window.

    Returns:
        A frame indexed by (code, segment_index) with columns
        ``max_high_limit_days`` and ``date``.
    """
    data_frame = source_data_frame.groupby(level=0).apply(
        lambda stock: prepare_features(stock, last_n_days)
    )
    return remove_trade_data_overdate(data_frame, last_n_days, end_date)


if __name__ == "__main__":
    print(serie_high_limit(load_stock_data(), 5))