import random

import numpy as np
import torch
import treetensor.torch as ttorch

from ding.framework import Context, OnlineRLContext, OfflineRLContext

# Dimensions and counts used by the fake data below; the specific values are arbitrary.
batch_size = 64
n_sample = 8
action_dim = 1
obs_dim = 4
logit_dim = 2

n_episodes = 2
n_episode_length = 16
update_per_collect = 4
collector_env_num = 8


# The value ranges below are meaningless; they are only used to generate test data.
def fake_train_data():
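    """Return a single fake transition (obs, action, reward, next_obs, done, ...) packed as a treetensor."""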
    train_data = ttorch.as_tensor(
        {
            'action': torch.randint(0, 2, size=(action_dim, )),
            'collect_train_iter': torch.randint(0, 100, size=(1, )),
            'done': torch.tensor(False),
            'env_data_id': torch.tensor([2]),
            'next_obs': torch.randn(obs_dim),
            'obs': torch.randn(obs_dim),
            'reward': torch.randint(0, 2, size=(1, )),
        }
    )
    return train_data


def fake_online_rl_context():
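    """Return an OnlineRLContext whose collect/train/eval fields are filled with random placeholder values."""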
    ctx = OnlineRLContext(
        env_step=random.randint(0, 100),
        env_episode=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=[{
            'cur_lr': 0.001,
            'total_loss': random.uniform(0, 2)
        } for _ in range(update_per_collect)],
        obs=torch.randn(collector_env_num, obs_dim),
        action=[np.random.randint(low=0, high=1, size=(action_dim, ), dtype=np.int64) for _ in range(collector_env_num)],
        inference_output={
            env_id: {
                'logit': torch.randn(logit_dim),
                'action': torch.randint(0, 2, size=(action_dim, ))
            }
            for env_id in range(collector_env_num)
        },
        collect_kwargs={'eps': random.uniform(0, 1)},
        trajectories=[fake_train_data() for _ in range(n_sample)],
        episodes=[[fake_train_data() for _ in range(n_episode_length)] for _ in range(n_episodes)],
        trajectory_end_idx=list(range(n_sample)),
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )
    return ctx


def fake_offline_rl_context():
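    """Return an OfflineRLContext whose train/eval fields are filled with random placeholder values."""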
    ctx = OfflineRLContext(
        train_epoch=random.randint(0, 100),
        train_iter=random.randint(0, 100),
        train_data=[fake_train_data() for _ in range(batch_size)],
        train_output=[{
            'cur_lr': 0.001,
            'total_loss': random.uniform(0, 2)
        } for _ in range(update_per_collect)],
        eval_value=random.uniform(-1.0, 1.0),
        last_eval_iter=random.randint(0, 100),
    )
    return ctx
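

# A minimal usage sketch: build both fake contexts and sanity-check a few of the fields
# set above. The __main__ guard is an illustrative addition; the real tests presumably
# import these helpers instead of running this module directly.
if __name__ == '__main__':
    online_ctx = fake_online_rl_context()
    assert len(online_ctx.train_data) == batch_size
    assert online_ctx.obs.shape == (collector_env_num, obs_dim)
    assert len(online_ctx.episodes) == n_episodes

    offline_ctx = fake_offline_rl_context()
    assert len(offline_ctx.train_output) == update_per_collect
    print('fake online/offline RL contexts constructed successfully')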