DQN playing SpaceInvadersNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/1d4094fbcc9082de7f53f4348dd4c7c354152907
0016a0a
import gym | |
import torch | |
from abc import ABC, abstractmethod | |
from stable_baselines3.common.vec_env.base_vec_env import VecEnv | |
from torch.utils.tensorboard.writer import SummaryWriter | |
from typing import List, Optional, TypeVar | |
from shared.callbacks.callback import Callback | |
from shared.policy.policy import Policy | |
from shared.stats import EpisodesStats | |
# Self-type variable: lets `learn` be annotated as returning the same concrete
# subclass of `Algorithm` it was called on, rather than the base class.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")


class Algorithm(ABC):
    """Abstract base class for training algorithms.

    Stores the components shared by every algorithm (policy, vectorized
    environment, torch device, TensorBoard writer) and declares the abstract
    ``learn`` entry point that concrete algorithms must implement.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the shared components on the instance.

        Args:
            policy: The policy this algorithm operates on; stored as
                ``self.policy``.
            env: Vectorized environment (stable-baselines3 ``VecEnv``);
                stored as ``self.env``.
            device: Torch device; stored as ``self.device``.
            tb_writer: TensorBoard ``SummaryWriter``; stored as
                ``self.tb_writer``.
            **kwargs: Accepted and ignored by the base class. Presumably this
                lets callers pass algorithm-specific hyperparameters through
                a uniform constructor call — TODO confirm against callers.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.device = device
        self.tb_writer = tb_writer

    # Abstract: the base class has no training loop of its own. Without this
    # decorator a subclass that forgot to implement `learn` could still be
    # instantiated, and `learn()` would silently return None despite the
    # `-> AlgorithmSelf` annotation.
    @abstractmethod
    def learn(
        self: AlgorithmSelf, total_timesteps: int, callback: Optional[Callback] = None
    ) -> AlgorithmSelf:
        """Train the algorithm for a budget of ``total_timesteps``.

        Args:
            total_timesteps: Training budget. The unit is defined by the
                implementing subclass (presumably environment steps — confirm
                per algorithm).
            callback: Optional callback passed to the training loop; how and
                when it is invoked is up to the subclass.

        Returns:
            An instance of the same type as ``self``, per the self-type
            annotation (conventionally ``self``, allowing call chaining).
        """
        ...