# qdqn-FrozenLake-v1 / example.py
# Last commit e7650e8 ("refactor") by Arnas.
# Example: load the trained model for FrozenLake-v1 and wrap it in an Agent.
import gym
import yaml
import torch
from helpers.qnn import QuantumNet
from helpers.wrappers import BinaryWrapper
from helpers.agent import Agent
# Environment: FrozenLake-v1 from gym, wrapped with the project's BinaryWrapper
# (see helpers.wrappers for how observations are transformed).
env_name = 'FrozenLake-v1'
env = gym.make(env_name)
env = BinaryWrapper(env)

# Network: rebuild the model from the saved hyperparameters, then load the
# trained weights. map_location forces all tensors onto the CPU so the example
# runs on machines without a GPU, even if the checkpoint was saved from one.
with open('config.yaml', 'r', encoding='utf-8') as f:
    # safe_load only builds plain YAML types (no arbitrary Python objects).
    hparams = yaml.safe_load(f)
net = QuantumNet(hparams['n_layers'])
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints. Since this is a plain state_dict, passing
# weights_only=True (torch >= 1.13) would be safer — confirm the torch version.
state_dict = torch.load('qdqn-FrozenLake-v1.pt', map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent: wraps the trained network (see helpers.agent for the action-selection API).
agent = Agent(net)