File size: 4,218 Bytes
eaf2e33
 
3582c8a
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3582c8a
eaf2e33
 
 
 
 
 
3582c8a
eaf2e33
3582c8a
eaf2e33
 
 
 
 
 
3582c8a
 
 
eaf2e33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import glob
import random
from abc import abstractmethod
import numpy as np
import torch
from src.utils.filesys import getpath
from src.gan.gans import nz
from src.gan.gankits import sample_latvec
from src.drl.drl_uses import load_cfgs
from src.drl.sunrise.sunrise_adaption import SunriseProxyAgent


def process_obs(obs, device='cpu'):
    """Convert an observation into a float32 tensor on ``device``.

    A 1-D observation gets a leading dimension of size 1 prepended; inputs of
    any other rank are returned with their shape unchanged.

    :param obs: data accepted by ``torch.tensor`` (e.g. list or numpy array).
    :param device: device on which the resulting tensor is created.
    :return: a ``torch.float32`` tensor.
    """
    tensor = torch.tensor(obs, dtype=torch.float32, device=device)
    # Only a flat vector is promoted to a one-row batch.
    return tensor.unsqueeze(0) if tensor.dim() == 1 else tensor


class GenPolicy:
    """Common interface of the generation policies defined in this module.

    Subclasses override :meth:`step` and :meth:`from_path`. Note that this
    class does not use ``ABCMeta``, so the ``abstractmethod`` markers below are
    not enforced: the class can be instantiated and the default methods simply
    return ``None``.
    """

    def __init__(self, n=5):
        # Number of segments in an observation
        self.n = n

    @abstractmethod
    def step(self, obs):
        """Produce the policy output for ``obs``; the base version returns None."""

    @staticmethod
    @abstractmethod
    def from_path(path, **kwargs):
        """Build a policy from ``path``; the base version returns None."""

    def reset(self):
        """Reset any internal state; the base version does nothing."""
        return None


class RLGenPolicy(GenPolicy):
    """Generation policy that queries a single trained model.

    On construction the model is switched to evaluation mode and moved to
    ``device``.
    """

    def __init__(self, model, n, device='cuda:0'):
        """
        :param model: callable module returning a pair whose first item is the
            action tensor (see :meth:`step`).
        :param n: number of segments in an observation.
        :param device: device the model and observations are placed on.
        """
        super().__init__(n)
        self.device = device
        self.meregs = []  # not used anywhere inside this class
        self.model = model
        self.model.eval()
        self.model.to(device)

    def step(self, obs):
        """Compute the clamped model output for ``obs`` as a numpy array.

        The observation is turned into a 2-D float32 tensor by ``process_obs``.
        When its width is smaller than ``nz * self.n`` it is left-padded with
        zeros up to that width. The first element of the model's output is
        clamped to [-1, 1], squeezed and returned on the CPU as numpy.
        """
        batch = process_obs(obs, device=self.device)
        rows, width = batch.shape
        missing = nz * self.n - width
        if missing > 0:
            # Zeros go in front, so the given values stay at the end.
            padding = torch.zeros([rows, missing], device=self.device)
            batch = torch.cat([padding, batch], dim=-1)
        with torch.no_grad():
            action, _ = self.model(batch)
        clipped = torch.clamp(action, -1, 1)
        return clipped.squeeze().cpu().numpy()

    @staticmethod
    def from_path(path, device='cuda:0'):
        """Load ``policy.pth`` and the ``'N'`` config entry found under ``path``.

        :param path: directory passed to ``getpath`` / ``load_cfgs``.
        :param device: ``map_location`` for loading and target device.
        :return: a new :class:`RLGenPolicy`.
        """
        checkpoint = getpath(f'{path}/policy.pth')
        loaded = torch.load(checkpoint, map_location=device)
        segments = load_cfgs(path, 'N')
        return RLGenPolicy(loaded, segments, device)



class EnsembleGenPolicy(GenPolicy):
    """Generation policy that acts with a randomly chosen ensemble member.

    Every model is evaluated on each call to :meth:`step`; for each
    observation one model's action is picked uniformly at random.
    """

    def __init__(self, models, n, device='cpu'):
        """
        :param models: list of modules; each is moved to ``device`` and put in
            evaluation mode.
        :param n: number of segments in an observation.
        :param device: device the models and observations are placed on.
        """
        super().__init__(n)
        for model in models:
            model.to(device)
            # Inference only: match RLGenPolicy, which also calls eval().
            model.eval()
        self.device = device
        self.models = models
        # Number of models, equal to len(self.models).
        self.m = len(models)

    def step(self, obs):
        """Return one clamped action per observation as a numpy array.

        :param obs: a single observation (1-D) or a batch (first axis is the
            batch); anything ``torch.tensor`` accepts, including plain lists.
        :return: for a 1-D observation, the action of one random model; for a
            batch, an array whose i-th entry is the action that a randomly
            chosen model produced for the i-th observation.
        """
        o = torch.tensor(obs, device=self.device, dtype=torch.float32)
        actions = []
        with torch.no_grad():
            for model in self.models:
                out = model(o)  # action predicted by this model
                # Some models return (action, extra...); keep the action only.
                if isinstance(out, tuple):
                    out = out[0]
                actions.append(torch.clamp(out, -1, 1).cpu().numpy())
        # Inspect the converted tensor, not the raw input, so list
        # observations work as well as arrays.
        if o.dim() == 1:
            return random.choice(actions)
        # Every model produced an action for every observation; pick one
        # model independently for each observation.
        batch = o.shape[0]
        stacked = np.array(actions)
        picks = np.array(
            [random.randrange(self.m) for _ in range(batch)], dtype=np.intp
        )
        return stacked[picks, np.arange(batch)]

    @staticmethod
    def from_path(path, device='cpu'):
        """Load every ``policy*.pth`` model found under ``path``.

        Checkpoint paths are sorted so the ensemble order does not depend on
        the file-system listing order.

        :param path: directory passed to ``getpath`` / ``load_cfgs``.
        :param device: ``map_location`` for loading and target device.
        :return: a new :class:`EnsembleGenPolicy`.
        """
        # NOTE: torch.load unpickles the files; only load trusted checkpoints.
        models = [
            torch.load(p, map_location=device)
            for p in sorted(glob.glob(getpath(path, 'policy*.pth')))
        ]
        n = load_cfgs(path, 'N')
        return EnsembleGenPolicy(models, n, device)


class RandGenPolicy(GenPolicy):
    """Policy that ignores observation contents and samples latent vectors."""

    def __init__(self):
        super().__init__(1)

    def step(self, obs):
        """Return ``sample_latvec`` output for ``obs.shape[0]`` samples.

        Only the size of the first axis of ``obs`` is used. The sampled tensor
        is squeezed and converted to a numpy array.
        """
        count = obs.shape[0]
        latent = sample_latvec(count)
        return latent.squeeze().numpy()

    @staticmethod
    def from_path(path, **kwargs):
        """Return a fresh :class:`RandGenPolicy`; all arguments are ignored."""
        return RandGenPolicy()


class DvDGenPolicy(GenPolicy):
    """Policy that delegates to one agent held by a DvD learner object."""

    def __init__(self, learner, n, rand_switch=False):
        """
        :param learner: object exposing ``agents`` (indexable) and ``agent``
            (index of the current agent).
        :param n: number of segments in an observation.
        :param rand_switch: if true, every :meth:`reset` picks a random agent;
            otherwise the learner's current agent is used.
        """
        super().__init__(n)
        self.master = learner
        self.rand_switch = rand_switch
        self.working_policy = None
        # Select the initial working policy.
        self.reset()

    def step(self, obs):
        """Return the working policy's ``forward(obs)`` cast to float32."""
        output = self.working_policy.forward(obs)
        return output.astype(np.float32)

    def reset(self):
        """Re-select the working policy according to ``rand_switch``."""
        agents = self.master.agents
        self.working_policy = (
            random.choice(agents) if self.rand_switch
            else agents[self.master.agent]
        )

    @staticmethod
    def from_path(path, device='cpu'):
        """Loading is not implemented; always returns None.

        No loading function was found in the DvD-ES code base, so there is
        currently no way to restore a learner from ``path``.
        """
        return None