sgoodfriend's picture
A2C playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
05b94c0
raw
history blame contribute delete
No virus
1.02 kB
from abc import ABC, abstractmethod
from typing import List, Optional, TypeVar
import gym
import torch
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
class Algorithm(ABC):
@abstractmethod
def __init__(
self,
policy: Policy,
env: VecEnv,
device: torch.device,
tb_writer: SummaryWriter,
**kwargs,
) -> None:
super().__init__()
self.policy = policy
self.env = env
self.device = device
self.tb_writer = tb_writer
@abstractmethod
def learn(
self: AlgorithmSelf,
train_timesteps: int,
callbacks: Optional[List[Callback]] = None,
total_timesteps: Optional[int] = None,
start_timesteps: int = 0,
) -> AlgorithmSelf:
...