# NOTE: scraped Hugging Face Spaces page header (read "Spaces: / Sleeping / Sleeping");
# kept as a comment so the file parses as Python.
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
class ResNetBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a skip connection.

    Channel count is preserved (num_hidden -> num_hidden), so the input can
    be added to the transformed output before the final ReLU.
    """

    def __init__(self, num_hidden: int):
        super().__init__()
        # bias=False: each conv is immediately followed by a BatchNorm,
        # which provides its own learnable shift.
        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv-bn-relu, conv-bn, add the skip, then a final ReLU."""
        skip = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + skip
        return F.relu(out)
class DropoutBlock(nn.Module):
    """Fully connected stage: Linear -> BatchNorm1d -> ReLU -> Dropout."""

    def __init__(self, in_units: int, out_units: int, rate: float):
        super().__init__()
        stages = [
            nn.Linear(in_units, out_units),
            nn.BatchNorm1d(out_units),
            nn.ReLU(),
            nn.Dropout(rate),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """Run the linear / batch-norm / ReLU / dropout pipeline on x."""
        return self.model(x)
| class Model(nn.Module): | |
| def __init__(self, n_action: int, num_hidden: int, num_resblock:int, | |
| rate:float, row:int, col: int, device: str): | |
| super(Model, self).__init__() | |
| # Bottom layer | |
| self.initial_block = nn.Sequential( | |
| nn.Conv2d(4, num_hidden, kernel_size=3, padding=1), | |
| nn.BatchNorm2d(num_hidden), | |
| nn.ReLU() | |
| ).to(device) | |
| self.res_blocks = nn.Sequential( | |
| *[ResNetBlock(num_hidden) for _ in range(num_resblock)] | |
| ).to(device) | |
| self.dropout_model = nn.Sequential( | |
| DropoutBlock(num_hidden*row*col, 200, rate), | |
| DropoutBlock(200, 100, rate) | |
| ) | |
| self.model = nn.Sequential( | |
| self.initial_block, | |
| self.res_blocks, | |
| nn.Flatten(), | |
| self.dropout_model | |
| ) | |
| self.policy_head = nn.Sequential( | |
| nn.Linear(100, 100), | |
| nn.ReLU(), | |
| nn.Linear(100, n_action), | |
| ).to(device) | |
| self.value_head = nn.Sequential( | |
| nn.Linear(100, 100), | |
| nn.ReLU(), | |
| nn.Linear(100, 1), | |
| nn.Tanh() | |
| ).to(device) | |
| self.to(device) | |
| self.device = device | |
| # Losses | |
| # Mean Square Error for minimizing the difference between estimated value and target value | |
| self.mse_loss = nn.MSELoss() | |
| # Cross entropy loss to evaluate the correct policy as compared to target policy | |
| self.ce_loss = nn.CrossEntropyLoss() | |
| def forward(self, x): | |
| x = self.model(x) | |
| value = self.value_head(x) | |
| policy = self.policy_head(x) | |
| return value, policy | |
| # Perform the loss calculation | |
| def get_loss(self, pred_val, pred_policy, true_val, true_policy): | |
| val_loss = self.mse_loss(pred_val, true_val) | |
| policy_loss = self.ce_loss(pred_policy, true_policy) | |
| final_loss = val_loss + policy_loss | |
| return { | |
| 'total_loss': final_loss, | |
| 'value_loss': val_loss, | |
| 'policy_loss': policy_loss | |
| } |