"""Abstract actor interface and a helper that packages policy outputs."""

from abc import ABC, abstractmethod
from typing import NamedTuple, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch.distributions import Distribution


class PiForward(NamedTuple):
    """Bundle returned by an actor's forward pass.

    When built by ``pi_forward``, ``logp_a`` and ``entropy`` are either both
    tensors (actions were supplied) or both ``None`` (no actions supplied).
    """

    # The action distribution itself.
    pi: Distribution
    # ``pi.log_prob(actions)`` when actions were given, otherwise None.
    logp_a: Optional[torch.Tensor]
    # ``pi.entropy()`` when actions were given, otherwise None.
    entropy: Optional[torch.Tensor]


class Actor(nn.Module, ABC):
    """Abstract base class for policy modules.

    Concrete subclasses must implement ``forward`` and the ``action_shape``
    property; ``sample_weights`` has a do-nothing default.
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Map observations to a ``PiForward``.

        Args:
            obs: Observation tensor (layout is defined by the subclass).
            actions: Optional actions to evaluate under the policy.
            action_masks: Optional mask tensor; presumably flags which
                actions are permitted -- confirm against implementations.

        Returns:
            A ``PiForward`` describing the policy output.
        """

    def sample_weights(self, batch_size: int = 1) -> None:
        """Do nothing by default; subclasses may override this hook.

        Args:
            batch_size: Unused by the base implementation.
        """
        return None

    @property
    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Shape of the actions handled by this actor, as a tuple of ints."""


def pi_forward(
    distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
    """Wrap ``distribution`` in a ``PiForward``.

    Args:
        distribution: The action distribution to package.
        actions: Optional actions. When given, their log-probability and the
            distribution's entropy are computed and included.

    Returns:
        ``PiForward(distribution, logp_a, entropy)``, where ``logp_a`` and
        ``entropy`` are both ``None`` if ``actions`` is ``None``.
    """
    # Without actions there is nothing to score, so both extras stay None.
    if actions is None:
        return PiForward(pi=distribution, logp_a=None, entropy=None)

    return PiForward(
        pi=distribution,
        logp_a=distribution.log_prob(actions),
        entropy=distribution.entropy(),
    )