File size: 5,375 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
from typing import Union
import numpy as np
import torch
class DiscreteSupport(object):
    """
    Overview:
        Describes the discrete support set ``{min, min+delta, ..., max}`` used for the
        categorical value/reward representation in MuZero-style algorithms.
    Arguments:
        - min (:obj:`int`): smallest support value (inclusive).
        - max (:obj:`int`): largest support value (inclusive for integer bounds with unit delta).
        - delta (:obj:`float`): spacing between adjacent support atoms, defaults to 1.
    """

    def __init__(self, min: int, max: int, delta: float = 1.) -> None:
        # Explicit exception instead of `assert` so the check survives `python -O`.
        if not min < max:
            raise ValueError(f'DiscreteSupport requires min < max, got min={min}, max={max}')
        self.min = min
        self.max = max
        # np.arange is half-open, hence `max + 1` so that `max` itself is included
        # when delta divides the span.
        self.range = np.arange(min, max + 1, delta)
        # `size` and `set_size` are intentionally kept as aliases of each other:
        # both names are referenced by downstream code.
        self.size = len(self.range)
        self.set_size = len(self.range)
        self.delta = delta
def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
    """
    Overview:
        Apply the invertible value-scaling function h(.) to ``x``: values are shrunk
        via a signed square root plus a small linear term so that large returns map
        to a bounded range while remaining invertible.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf (Page-11) Appendix A : Proposition A.2
    """
    if delta == 1:
        # Fast path: unit support spacing needs no rescaling division.
        shrunk = torch.sqrt(torch.abs(x) + 1) - 1
        return torch.sign(x) * shrunk + epsilon * x
    # General case: rescale by the support spacing before shrinking.
    shrunk = torch.sqrt(torch.abs(x / delta) + 1) - 1
    return torch.sign(x) * shrunk + epsilon * x / delta
def inverse_scalar_transform(
        logits: torch.Tensor,
        support_size: int,
        epsilon: float = 0.001,
        categorical_distribution: bool = True
) -> torch.Tensor:
    """
    Overview:
        Map a scaled value — or its categorical logits over the discrete support —
        back to the original scalar, i.e. the h^(-1)(.) function in paper
        https://arxiv.org/pdf/1805.11593.pdf.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """
    if categorical_distribution:
        # Expected value of the support atoms under the softmax distribution.
        support = DiscreteSupport(-support_size, support_size, delta=1)
        probs = torch.softmax(logits, dim=1)
        atoms = torch.from_numpy(support.range).unsqueeze(0)
        atoms = atoms.to(device=probs.device)
        scaled = (atoms * probs).sum(1, keepdim=True)
    else:
        scaled = logits
    # Closed-form inverse of h(.): undo the epsilon-linear term and the sqrt shrinkage.
    root = (torch.sqrt(1 + 4 * epsilon * (torch.abs(scaled) + 1 + epsilon)) - 1) / (2 * epsilon)
    result = torch.sign(scaled) * (root ** 2 - 1)
    # TODO(pu): zeroing entries with |result| < epsilon is skipped to save time
    # result[torch.abs(result) < epsilon] = 0.
    return result
class InverseScalarTransform:
    """
    Overview:
        Stateful counterpart of ``inverse_scalar_transform``: the support tensor is
        built once in ``__init__`` so repeated calls avoid reconstructing it. Maps a
        scaled value or its categorical representation back to the original value,
        i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """

    def __init__(
            self,
            support_size: int,
            device: Union[str, torch.device] = 'cpu',
            categorical_distribution: bool = True
    ) -> None:
        """
        Arguments:
            - support_size (:obj:`int`): half-width of the symmetric support [-S, S].
            - device: device the cached support tensor is moved to.
            - categorical_distribution (:obj:`bool`): if True, inputs are logits over
              the support; otherwise inputs are already scaled scalars.
        """
        scalar_support = DiscreteSupport(-support_size, support_size, delta=1)
        # Shape (1, 2 * support_size + 1) so it broadcasts over a batch of rows.
        self.value_support = torch.from_numpy(scalar_support.range).unsqueeze(0)
        self.value_support = self.value_support.to(device)
        self.categorical_distribution = categorical_distribution

    def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
        if self.categorical_distribution:
            value_probs = torch.softmax(logits, dim=1)
            # NOTE(fix): use out-of-place `mul` here. The previous in-place `mul_`
            # overwrote the softmax output, which autograd needs for its backward
            # pass, so any call with `logits.requires_grad` raised a runtime error.
            value = value_probs.mul(self.value_support).sum(1, keepdim=True)
        else:
            value = logits
        # Closed-form inverse of h(.); see Proposition A.2 of arXiv:1805.11593.
        tmp = ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon))
        # t * t is faster than t ** 2
        output = torch.sign(value) * (tmp * tmp - 1)
        return output
def visit_count_temperature(
        manual_temperature_decay: bool, fixed_temperature_value: float,
        threshold_training_steps_for_final_lr_temperature: int, trained_steps: int
) -> float:
    """
    Overview:
        Return the softmax temperature applied to MCTS visit counts when sampling
        actions. With ``manual_temperature_decay`` enabled the temperature steps
        down 1.0 -> 0.5 -> 0.25 as training crosses 50% and 75% of the step budget;
        otherwise a fixed temperature is used throughout.
    """
    if not manual_temperature_decay:
        return fixed_temperature_value
    threshold = threshold_training_steps_for_final_lr_temperature
    # Piecewise-constant decay schedule over the training-step budget.
    if trained_steps < 0.5 * threshold:
        return 1.0
    if trained_steps < 0.75 * threshold:
        return 0.5
    return 0.25
def phi_transform(discrete_support: "DiscreteSupport", x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        We then apply a transformation ``phi`` to the scalar in order to obtain an
        equivalent categorical representation: each scalar in ``x`` is clamped into
        [min, max] and expressed as a linear combination of its two adjacent support
        atoms (the two weights sum to 1).
    Arguments:
        - discrete_support: object exposing ``min``, ``max``, ``set_size``, ``delta``.
        - x (:obj:`torch.Tensor`): scalar targets of shape (B, T).
          NOTE: ``x`` is clamped IN PLACE as a side effect.
    Returns:
        - target (:obj:`torch.Tensor`): categorical targets of shape (B, T, set_size).
    Reference:
        - MuZero paper Appendix F: Network Architecture.
    """
    vmin = discrete_support.min
    vmax = discrete_support.max
    set_size = discrete_support.set_size
    delta = discrete_support.delta
    # Clamp in place so out-of-range scalars collapse onto the boundary atoms.
    x.clamp_(vmin, vmax)
    # floor/ceil land on integer grid points, so this projection is only exact for
    # delta == 1 (the only spacing constructed in this module).
    x_low = x.floor()
    x_high = x.ceil()
    p_high = x - x_low
    p_low = 1 - p_high
    target = torch.zeros(x.shape[0], x.shape[1], set_size).to(x.device)
    # Convert support values to bucket indices. FIX: the original wrote
    # `x_high - min / delta`, which parses as `x_high - (min / delta)` — correct
    # only for delta == 1. Parenthesize so the intent holds for any delta.
    x_high_idx, x_low_idx = (x_high - vmin) / delta, (x_low - vmin) / delta
    # Write the two adjacent-atom weights; scatter order matters when low == high
    # (integer scalars): the second scatter (p_low == 1) overwrites the first (p_high == 0).
    target.scatter_(2, x_high_idx.long().unsqueeze(-1), p_high.unsqueeze(-1))
    target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1))
    return target
def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Soft-target cross entropy: ``-sum_k target_k * log_softmax(prediction)_k``,
        reduced over the class dimension (dim=1), yielding one loss per row.
    """
    log_probs = torch.log_softmax(prediction, dim=1)
    return -(log_probs * target).sum(1)
|