File size: 3,231 Bytes
0829957
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import pandas as pd
import random
from sklearn.utils import shuffle
import torch
import torch.nn as nn
import torch.autograd as autograd
from torchcontrib.optim import SWA
from collections import deque

from preprocess import *

class DQN(nn.Module):
    """Small fully connected Q-network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Indexable whose first element is the number of input
            features (only ``input_dim[0]`` is read).
        output_dim: Width of the final linear layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Layers are created in the same order as they appear in the
        # Sequential so parameter initialisation consumes the RNG identically
        # and the state-dict keys stay ``fc.0.*`` / ``fc.2.*``.
        n_features = input_dim[0]
        hidden = nn.Linear(n_features, 32)
        activation = nn.ReLU()
        head = nn.Linear(32, output_dim)
        self.fc = nn.Sequential(hidden, activation, head)

    def forward(self, state):
        """Run ``state`` through the network and return its output."""
        q_values = self.fc(state)
        return q_values
    
class DQNAgent:
    """Agent that scores (state, action) pairs with a Q-network and trains it
    by regressing the prediction onto a one-step TD target.

    Attributes set by ``__init__``: ``learning_rate``, ``gamma``, ``tau``,
    ``model``, ``swa``, ``dataset``, ``MSE_loss``, ``replay_buffer`` and
    ``optimizer``.
    """

    def __init__(self, input_dim, dataset,
                 learning_rate=3e-4,
                 gamma=0.99,
                 buffer=None,
                 buffer_size=10000,
                 tau=0.999,
                 swa=False,
                 pre_trained_model=None):
        """Build the network, optimiser and loss.

        Args:
            input_dim: Passed to ``DQN`` when no pre-trained model is given.
            dataset: Default dataset handed to the feature helpers.
            learning_rate: Step size of the Adam optimiser.
            gamma: Discount factor used in the TD target.
            buffer: Replay buffer exposing ``sample(batch_size)``; required
                before ``update`` is called.
            buffer_size: Accepted for interface compatibility; not used in
                this class.
            tau: Stored on the instance; not used in this class.
            swa: When true, wrap the optimiser in torchcontrib's ``SWA``.
            pre_trained_model: Existing module to use instead of a new ``DQN``.
        """
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.tau = tau

        # Only build a fresh network when one was not supplied, and test
        # against None so a module defining __len__ is not judged by length.
        if pre_trained_model is not None:
            self.model = pre_trained_model
        else:
            self.model = DQN(input_dim, 1)

        # The configured learning rate must reach the optimiser; otherwise
        # Adam silently falls back to its own default.
        base_opt = torch.optim.Adam(self.model.parameters(),
                                    lr=self.learning_rate)
        self.swa = swa
        self.dataset = dataset
        self.MSE_loss = nn.MSELoss()
        self.replay_buffer = buffer
        if swa:
            self.optimizer = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
        else:
            self.optimizer = base_opt

    @staticmethod
    def _to_float_tensor(values):
        """Convert a list / array-like of numbers to a float32 tensor."""
        return torch.as_tensor(np.asarray(values), dtype=torch.float32)

    def get_action(self, state, dataset=None):
        """Return the element of ``state.remaining`` with the highest score.

        Args:
            state: Object exposing ``remaining`` (the candidate actions).
            dataset: Optional override for ``self.dataset``.
        """
        if dataset is None:
            dataset = self.dataset
        inputs = get_multiple_model_inputs(state, state.remaining, dataset)
        model_inputs = torch.from_numpy(inputs).float().unsqueeze(0)
        # Action selection needs no gradients.
        with torch.no_grad():
            expected_returns = self.model(model_inputs)
        # Reduce over dim 1, the axis that enumerates the candidates after
        # the unsqueeze above.
        value, index = expected_returns.max(1)
        return state.remaining[index[0]]

    def compute_loss(self, batch, dataset, verbose=False):
        """Mean-squared TD error for one sampled batch.

        Args:
            batch: Tuple ``(states, actions, rewards, next_states, dones)``.
                ``rewards`` and ``dones`` may be shaped ``(B,)`` or ``(B, 1)``.
            dataset: Dataset handed to ``get_model_inputs``.
            verbose: When true, print the predictions and the targets.

        Returns:
            Scalar loss tensor attached to the autograd graph.
        """
        states, actions, rewards, next_states, dones = batch

        curr_inputs = self._to_float_tensor(
            [get_model_inputs(s, a, dataset) for s, a in zip(states, actions)])
        # Flatten so that both (B,) and (B, 1) layouts become (B,).
        rewards = self._to_float_tensor(rewards).reshape(-1)
        dones = self._to_float_tensor(dones).reshape(-1)

        curr_Q = self.model(curr_inputs)

        # NOTE(review): the next-state value is evaluated with the action
        # taken in the *current* state, so the max below runs over a single
        # column. Confirm whether the intent was to maximise over the
        # actions still available in each next state.
        next_inputs = self._to_float_tensor(
            [get_model_inputs(ns, a, dataset)
             for ns, a in zip(next_states, actions)])
        with torch.no_grad():  # the target must not receive gradients
            next_Q = self.model(next_inputs)
            max_next_Q = torch.max(next_Q, 1)[0]
            # (1 - dones) removes the bootstrapped term on terminal steps.
            expected_Q = rewards + (1 - dones) * self.gamma * max_next_Q

        if verbose:
            print(curr_Q, expected_Q)

        # Drop the trailing singleton so prediction and target are both (B,).
        # Squeezing dim 0 left a (B, 1) prediction for B > 1, and MSELoss
        # then broadcast it against the (B,) target into a (B, B) matrix.
        loss = self.MSE_loss(curr_Q.squeeze(-1), expected_Q)
        return loss

    def update(self, batch_size, verbose=False):
        """Sample a batch from the replay buffer and take one optimiser step.

        Returns:
            The batch loss as a tensor detached from the autograd graph.
        """
        batch = self.replay_buffer.sample(batch_size)
        loss = self.compute_loss(batch, self.dataset, verbose)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.swa:
            # NOTE(review): torchcontrib documents swap_swa_sgd() as a call
            # made once at the end of training; calling it after every step
            # keeps swapping the live weights with the running average.
            # Left unchanged because callers would need a finalise hook.
            self.optimizer.swap_swa_sgd()
        # Detach so the returned value does not keep the graph alive.
        return loss.detach().float()