sgoodfriend's picture
A2C playing LunarLander-v2 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
de6a584
raw
history blame contribute delete
No virus
1.03 kB
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Distribution
class PiForward(NamedTuple):
pi: Distribution
logp_a: Optional[torch.Tensor]
entropy: Optional[torch.Tensor]
class Actor(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: torch.Tensor,
actions: Optional[torch.Tensor] = None,
action_masks: Optional[torch.Tensor] = None,
) -> PiForward:
...
def sample_weights(self, batch_size: int = 1) -> None:
pass
@property
@abstractmethod
def action_shape(self) -> Tuple[int, ...]:
...
def pi_forward(
distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
logp_a = None
entropy = None
if actions is not None:
logp_a = distribution.log_prob(actions)
entropy = distribution.entropy()
return PiForward(distribution, logp_a, entropy)