from cmath import inf |
from typing import Any, List |
from easydict import EasyDict |
from abc import abstractmethod |
from gym import spaces |
from gym.utils import seeding |
from enum import Enum |
import os |
import gym |
import copy |
import pandas as pd |
import numpy as np |
from ding.envs import BaseEnv, BaseEnvTimestep |
from ding.utils import ENV_REGISTRY |
from ding.torch_utils import to_ndarray |
def load_dataset(name, index_name):
    """
    Overview:
        Read ``data/<name>.csv``, located next to this module, into a pandas DataFrame.
    Arguments:
        - name (:obj:`str`): File name of the dataset without the ``.csv`` extension.
        - index_name (:obj:`str`): Column used as the DataFrame index (passed to ``index_col``).
    Returns:
        - df (:obj:`pd.DataFrame`): The parsed dataset; ``parse_dates=True`` asks pandas to parse the index as dates.
    Raises:
        - FileNotFoundError: If the csv file is not present in the ``data`` folder.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(base_dir, 'data', name + '.csv')
    # Explicit raise instead of ``assert``: assertions are removed under ``python -O``,
    # which would drop the download hint below.
    if not os.path.exists(path):
        raise FileNotFoundError(
            "You need to put the stock data under the 'DI-engine/dizoo/gym_anytrading/envs/data' folder.\n "
            "if using StocksEnv, you can download Google stocks data at "
            "https://github.com/AminHP/gym-anytrading/blob/master/gym_anytrading/datasets/data/STOCKS_GOOGL.csv"
        )
    return pd.read_csv(path, parse_dates=True, index_col=index_name)
class Actions(int, Enum):
    """
    Discrete trading actions understood by ``transform``.

    The values are contiguous and start at 0 so that every integer drawn from
    ``spaces.Discrete(len(Actions))`` corresponds to exactly one member.
    ``DOUBLE_SELL`` and ``DOUBLE_BUY`` are the two actions ``transform`` uses to
    jump straight to Short / Long.
    """
    DOUBLE_SELL = 0
    SELL = 1
    HOLD = 2
    BUY = 3
    DOUBLE_BUY = 4
class Positions(int, Enum):
    """
    Market position held by the agent.

    The ``int`` mixin makes each member an integer, so members compare equal to
    plain numbers (``Positions.LONG == 1``) and ``.value`` is an ``int``.
    """
    SHORT = -1
    FLAT = 0
    LONG = 1
def transform(position: Positions, action: int) -> Any:
    """
    Overview:
        Position state machine used by ``env.step()``. Maps a ``(position, action)`` pair to
        the next position and reports whether the transition enters Long or Short.
    Arguments:
        - position (:obj:`Positions`): Current position: Long, Short or Flat.
        - action (:obj:`int`): One of Double_Sell, Sell, Hold, Buy, Double_Buy.
    Returns:
        - next_position (:obj:`Positions`): The position after the transition.
        - trade (:obj:`bool`): ``True`` when the transition ends in Long or Short; \
            ``False`` when it ends in Flat or leaves the position unchanged.
    """
    holds_long = position == Positions.LONG
    holds_short = position == Positions.SHORT
    is_flat = position == Positions.FLAT

    # Single-step actions move one state along SHORT <-> FLAT <-> LONG.
    if action == Actions.SELL:
        if holds_long:
            return Positions.FLAT, False
        if is_flat:
            return Positions.SHORT, True
    elif action == Actions.BUY:
        if holds_short:
            return Positions.FLAT, False
        if is_flat:
            return Positions.LONG, True

    # Double actions jump directly to the opposite extreme, from either other state.
    if action == Actions.DOUBLE_SELL and (holds_long or is_flat):
        return Positions.SHORT, True
    if action == Actions.DOUBLE_BUY and (holds_short or is_flat):
        return Positions.LONG, True

    # Hold, or an action that has no effect in the current position.
    return position, False
@ENV_REGISTRY.register('base_trading')
class TradingEnv(BaseEnv):
    """
    Overview:
        Abstract base class for trading environments. It owns the episode bookkeeping
        (current tick, position, accumulated reward, profit history) and the plotting,
        while subclasses supply ``_process_data``, ``_calculate_reward`` and
        ``max_possible_profit``.
    """

    def __init__(self, cfg: EasyDict) -> None:
        """
        Arguments:
            - cfg (:obj:`EasyDict`): Env config. Reads ``env_id``, ``train_range``, ``test_range`` and \
                ``window_size``; ``plot_freq`` (default 10) and ``save_path`` (default ``'./'``) are optional. \
                ``eps_length`` is read later by ``_get_observation``.
        """
        self._cfg = cfg
        self._env_id = cfg.env_id
        # Number of reset() calls so far; step() uses it to render every ``plot_freq`` episodes.
        self.cnt = 0

        self.plot_freq = self._cfg.get('plot_freq', 10)
        self.save_path = self._cfg.get('save_path', './')

        # ======== data bounds and window ========
        self.train_range = cfg.train_range
        self.test_range = cfg.test_range
        self.window_size = cfg.window_size
        self.prices = None
        self.signal_features = None
        self.feature_dim_len = None
        # Placeholder shape; reset() replaces it once ``feature_dim_len`` is known.
        self.shape = (cfg.window_size, 3)

        # ======== episode state ========
        # ``_start_tick`` / ``_end_tick`` are not assigned anywhere else in this class;
        # presumably the subclass sets them in ``_process_data`` -- TODO confirm.
        self._start_tick = 0
        self._end_tick = 0
        self._done = None
        self._current_tick = None
        self._last_trade_tick = None
        self._position = None
        self._position_history = None
        self._profit_history = None
        self._total_reward = None

        # ======== spaces, built lazily on the first reset() ========
        self._init_flag = True
        self._action_space = None
        self._observation_space = None
        self._reward_space = None

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """
        Overview:
            Store the seed settings, seed numpy's global RNG and create ``self.np_random``.
        """
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)
        self.np_random, _ = seeding.np_random(seed)

    def reset(self, start_idx: int = None) -> Any:
        """
        Overview:
            Start a new episode: reload the data through ``_process_data``, build the spaces on
            the first call, and reset position, tick counters and reward.
        Arguments:
            - start_idx (:obj:`int`): Forwarded unchanged to ``_process_data``.
        Returns:
            - obs (:obj:`np.ndarray`): The first observation, see ``_get_observation``.
        """
        self.cnt += 1
        self.prices, self.signal_features, self.feature_dim_len = self._process_data(start_idx)

        if self._init_flag:
            self.shape = (self.window_size, self.feature_dim_len)
            self._action_space = spaces.Discrete(len(Actions))
            # NOTE(review): _get_observation() returns a flattened float32 vector with two extra
            # entries, which does not match this (window_size, feature_dim_len) float64 Box.
            # Confirm before relying on ``observation_space.shape``.
            self._observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float64)
            self._reward_space = gym.spaces.Box(-inf, inf, shape=(1, ), dtype=np.float32)
            self._init_flag = False

        self._done = False
        self._current_tick = self._start_tick
        self._last_trade_tick = self._current_tick - 1
        self._position = Positions.FLAT
        self._position_history = [self._position]
        # step() appends exp(total_reward); the history starts at 1.0.
        self._profit_history = [1.]
        self._total_reward = 0.

        return self._get_observation()

    def random_action(self) -> Any:
        """Sample one action from ``action_space`` and wrap it in a length-1 array."""
        return np.array([self.action_space.sample()])

    def step(self, action: np.ndarray) -> BaseEnvTimestep:
        """
        Overview:
            Advance one tick: compute the reward for ``action``, update the position through
            ``transform`` and return the new timestep.
        Arguments:
            - action (:obj:`np.ndarray`): The chosen action; a shape ``(1, )`` array is unwrapped to a scalar.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): ``(obs, reward, done, info)``. ``reward`` is a float32 array of \
                shape ``(1, )``. ``info`` carries ``total_reward`` and ``position``, plus \
                ``max_possible_profit`` and ``eval_episode_return`` on the final step.
        """
        assert isinstance(action, np.ndarray), type(action)
        if action.shape == (1, ):
            action = action.item()

        self._current_tick += 1
        self._done = bool(self._current_tick >= self._end_tick)

        # The reward is computed before the position is updated below.
        step_reward = self._calculate_reward(action)
        self._total_reward += step_reward

        self._position, trade = transform(self._position, action)
        if trade:
            self._last_trade_tick = self._current_tick

        self._position_history.append(self._position)
        self._profit_history.append(float(np.exp(self._total_reward)))
        observation = self._get_observation()
        info = dict(
            total_reward=self._total_reward,
            position=self._position.value,
        )

        if self._done:
            # Only env ids ending in 'e' are plotted, once every ``plot_freq`` resets.
            if self._env_id[-1] == 'e' and self.cnt % self.plot_freq == 0:
                self.render()
            info['max_possible_profit'] = np.log(self.max_possible_profit())
            info['eval_episode_return'] = self._total_reward

        step_reward = to_ndarray([step_reward]).astype(np.float32)
        return BaseEnvTimestep(observation, step_reward, self._done, info)

    def _get_observation(self) -> np.ndarray:
        """
        Overview:
            Build the observation for the current tick.
        Returns:
            - obs (:obj:`np.ndarray`): Flat float32 vector: the last ``window_size`` rows of \
                ``signal_features`` (flattened), then the current position value, then the number of ticks \
                since the last trade divided by ``cfg.eps_length``.
        """
        window_start = self._current_tick - self.window_size + 1
        window = to_ndarray(self.signal_features[window_start:self._current_tick + 1])
        features = window.reshape(-1).astype(np.float32)

        ticks_since_trade = (self._current_tick - self._last_trade_tick) / self._cfg.eps_length
        return np.hstack([features, to_ndarray([self._position.value]),
                          to_ndarray([ticks_since_trade])]).astype(np.float32)

    def render(self) -> None:
        """
        Overview:
            Save two figures under ``save_path``: ``<env_id>-profit.png`` (profit history) and
            ``<env_id>-price.png`` (episode prices with one marker per tick for the held position).
        """
        import matplotlib.pyplot as plt

        # os.path.join keeps the files inside ``save_path`` even without a trailing separator;
        # for the default './' the resulting paths are the same as plain concatenation.
        file_prefix = os.path.join(self.save_path, str(self._env_id))

        # ---- profit curve ----
        plt.clf()
        plt.xlabel('trading days')
        plt.ylabel('profit')
        plt.plot(self._profit_history)
        plt.savefig(file_prefix + "-profit.png")

        # ---- price curve with position markers ----
        plt.clf()
        plt.xlabel('trading days')
        plt.ylabel('close price')
        # ``raw_prices`` is not set in this class; presumably the subclass provides it -- TODO confirm.
        eps_price = self.raw_prices[self._start_tick:self._end_tick + 1]
        plt.plot(eps_price)

        short_ticks, long_ticks, flat_ticks = [], [], []
        for tick, held in enumerate(self._position_history):
            if held == Positions.SHORT:
                short_ticks.append(tick)
            elif held == Positions.LONG:
                long_ticks.append(tick)
            else:
                flat_ticks.append(tick)

        plt.plot(long_ticks, eps_price[long_ticks], 'g^', markersize=3, label="Long")
        plt.plot(flat_ticks, eps_price[flat_ticks], 'bo', markersize=3, label="Flat")
        plt.plot(short_ticks, eps_price[short_ticks], 'rv', markersize=3, label="Short")
        plt.legend(loc='upper left', bbox_to_anchor=(0.05, 0.95))
        plt.savefig(file_prefix + '-price.png')

    def close(self):
        """Close the current matplotlib figure."""
        import matplotlib.pyplot as plt
        plt.close()

    # Declared static: the function takes no ``self``, so without the decorator it could only be
    # called on the class, never on an instance.
    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Build one config per collector env from the input config. ``collector_env_num`` is
            popped from ``cfg`` (the caller's dict is modified), and every copy gets a distinct
            ``env_id`` with the suffix ``-<i>e``.
        Arguments:
            - cfg (:obj:`dict`): Original env config; must contain ``collector_env_num`` and ``env_id``.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): Deep copies of ``cfg``, one per collector env.

        .. note::
            Elements(env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
        """
        collector_env_num = cfg.pop('collector_env_num')
        collector_env_cfg = [copy.deepcopy(cfg) for _ in range(collector_env_num)]
        # NOTE(review): collector ids get the same trailing 'e' as evaluator ids, and step()
        # renders whenever env_id ends in 'e' -- confirm this is intended for collector envs.
        for i, env_cfg in enumerate(collector_env_cfg):
            env_cfg['env_id'] += ('-' + str(i) + 'e')
        return collector_env_cfg

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        """
        Overview:
            Build one config per evaluator env from the input config. ``evaluator_env_num`` is
            popped from ``cfg`` (the caller's dict is modified), and every copy gets a distinct
            ``env_id`` with the suffix ``-<i>e``.
        Arguments:
            - cfg (:obj:`dict`): Original env config; must contain ``evaluator_env_num`` and ``env_id``.
        Returns:
            - env_cfg_list (:obj:`List[dict]`): Deep copies of ``cfg``, one per evaluator env.
        """
        evaluator_env_num = cfg.pop('evaluator_env_num')
        evaluator_env_cfg = [copy.deepcopy(cfg) for _ in range(evaluator_env_num)]
        for i, env_cfg in enumerate(evaluator_env_cfg):
            env_cfg['env_id'] += ('-' + str(i) + 'e')
        return evaluator_env_cfg

    @abstractmethod
    def _process_data(self, start_idx: int = None):
        """
        Overview:
            Prepare the data for one episode. ``reset`` calls this with its ``start_idx`` and
            unpacks the result as ``(prices, signal_features, feature_dim_len)``.
        """
        raise NotImplementedError

    @abstractmethod
    def _calculate_reward(self, action):
        """Return the reward for taking ``action`` at the current tick; ``step`` adds it to the total reward."""
        raise NotImplementedError

    @abstractmethod
    def max_possible_profit(self):
        """Return the episode's maximum possible profit; ``step`` reports its natural log in ``info``."""
        raise NotImplementedError

    @property
    def observation_space(self) -> gym.spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> gym.spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> gym.spaces.Space:
        return self._reward_space

    def __repr__(self) -> str:
        return "DI-engine Trading Env"