from __future__ import annotations

import math
from typing import Dict, List, Sequence
|
|
|
|
# Bucket boundaries applied to a single horizon return.
# NOTE(review): the PUMP_50/100/300 names next to 0.5/1.0/3.0 suggest returns
# are fractional (0.5 == +50%) -- confirm against the producer of the returns.
MOVEMENT_STRONG_DOWN_THRESHOLD = -0.4
MOVEMENT_DOWN_THRESHOLD = -0.3
MOVEMENT_PUMP_50_THRESHOLD = 0.5
MOVEMENT_PUMP_100_THRESHOLD = 1.0
MOVEMENT_PUMP_300_THRESHOLD = 3.0


# Class names ordered from most negative to most positive movement; a name's
# position in this list is its integer class id.
MOVEMENT_CLASS_NAMES = [
    "strong_down",
    "down",
    "flat",
    "up",
    "strong_up",
    "extreme_up",
]
MOVEMENT_ID_TO_CLASS = dict(enumerate(MOVEMENT_CLASS_NAMES))
MOVEMENT_CLASS_TO_ID = {name: class_id for class_id, name in MOVEMENT_ID_TO_CLASS.items()}


# Default threshold set; callers may override individual keys.
DEFAULT_MOVEMENT_LABEL_CONFIG = dict(
    strong_down_threshold=MOVEMENT_STRONG_DOWN_THRESHOLD,
    down_threshold=MOVEMENT_DOWN_THRESHOLD,
    pump_50_threshold=MOVEMENT_PUMP_50_THRESHOLD,
    pump_100_threshold=MOVEMENT_PUMP_100_THRESHOLD,
    pump_300_threshold=MOVEMENT_PUMP_300_THRESHOLD,
)
|
|
|
|
def classify_movement_return(
    return_value: float,
    movement_label_config: Dict[str, float] | None = None,
) -> int:
    """Map a single return value to its movement class id.

    Thresholds start from ``DEFAULT_MOVEMENT_LABEL_CONFIG`` and are optionally
    overridden by ``movement_label_config``. Only keys already present in the
    defaults are applied (others are ignored), and override values are coerced
    with ``float``.

    Buckets are tested in this order:

        return_value <= strong_down_threshold -> "strong_down"
        return_value <  down_threshold        -> "down"
        return_value <  pump_50_threshold     -> "flat"
        return_value <  pump_100_threshold    -> "up"
        return_value <  pump_300_threshold    -> "strong_up"
        otherwise                             -> "extreme_up"

    ``strong_down_threshold`` is clamped to at most ``down_threshold``.

    Args:
        return_value: The return to classify.
        movement_label_config: Optional partial override of the thresholds.

    Returns:
        The id from ``MOVEMENT_CLASS_TO_ID`` of the selected class.

    Raises:
        ValueError: If ``return_value`` is NaN.
    """
    # NaN is unordered: every comparison below would be False and the value
    # would silently land in the "extreme_up" bucket, mislabelling bad data.
    if math.isnan(return_value):
        raise ValueError("return_value is NaN and cannot be classified")

    cfg = dict(DEFAULT_MOVEMENT_LABEL_CONFIG)
    if movement_label_config:
        cfg.update(
            {k: float(v) for k, v in movement_label_config.items() if k in cfg}
        )

    # Clamp so the strong-down boundary never sits above the down boundary.
    strong_down_threshold = min(cfg["strong_down_threshold"], cfg["down_threshold"])
    down_threshold = cfg["down_threshold"]
    pump_50_threshold = cfg["pump_50_threshold"]
    pump_100_threshold = cfg["pump_100_threshold"]
    pump_300_threshold = cfg["pump_300_threshold"]

    if return_value <= strong_down_threshold:
        return MOVEMENT_CLASS_TO_ID["strong_down"]
    if return_value < down_threshold:
        return MOVEMENT_CLASS_TO_ID["down"]
    if return_value < pump_50_threshold:
        return MOVEMENT_CLASS_TO_ID["flat"]
    if return_value < pump_100_threshold:
        return MOVEMENT_CLASS_TO_ID["up"]
    if return_value < pump_300_threshold:
        return MOVEMENT_CLASS_TO_ID["strong_up"]
    return MOVEMENT_CLASS_TO_ID["extreme_up"]
|
|
|
|
def derive_movement_targets(
    horizon_returns: Sequence[float],
    horizon_mask: Sequence[float],
    movement_label_config: Dict[str, float] | None = None,
) -> Dict[str, List[int]]:
    """Build per-horizon movement class targets, mask flags and class names.

    ``horizon_returns`` and ``horizon_mask`` are walked in lockstep; trailing
    entries of the longer sequence are ignored.

    A horizon is masked (target = the "flat" id, mask = 0, name = "masked")
    when its mask value is not strictly positive -- a NaN mask included -- or
    when its return is NaN. Otherwise the return is classified with
    ``classify_movement_return`` and the mask is 1. The return value is only
    converted with ``float`` for horizons whose mask is positive.

    Args:
        horizon_returns: One return per horizon.
        horizon_mask: One mask value per horizon; > 0 means usable.
        movement_label_config: Optional threshold overrides forwarded to
            ``classify_movement_return``.

    Returns:
        A dict with equal-length lists under "movement_class_targets",
        "movement_class_mask" and "movement_class_names".
    """
    class_targets: List[int] = []
    class_mask: List[int] = []
    class_names: List[str] = []
    flat_id = MOVEMENT_CLASS_TO_ID["flat"]

    for raw_return, raw_mask in zip(horizon_returns, horizon_mask):
        # `> 0` rather than `not <= 0` so a NaN mask is treated as masked.
        if float(raw_mask) > 0:
            return_value = float(raw_return)
            # A NaN return has no meaningful class; mask it instead.
            if not math.isnan(return_value):
                class_id = classify_movement_return(
                    return_value,
                    movement_label_config=movement_label_config,
                )
                class_targets.append(class_id)
                class_mask.append(1)
                class_names.append(MOVEMENT_ID_TO_CLASS[class_id])
                continue

        class_targets.append(flat_id)
        class_mask.append(0)
        class_names.append("masked")

    return {
        "movement_class_targets": class_targets,
        "movement_class_mask": class_mask,
        "movement_class_names": class_names,
    }
|
|
|
|
def compute_movement_label_config(valid_returns: Sequence[float]) -> Dict[str, float]:
    """Return the movement-label threshold config to use.

    ``valid_returns`` is accepted but never inspected: the result is always a
    new dict holding the entries of ``DEFAULT_MOVEMENT_LABEL_CONFIG``, so the
    caller may mutate it without affecting the module defaults.
    """
    _ = valid_returns  # intentionally unused; thresholds are fixed
    return {**DEFAULT_MOVEMENT_LABEL_CONFIG}
|
|