"""Load a trained quantum DQN agent for FrozenLake.

Builds the ``FrozenLake-v1`` environment wrapped in ``BinaryWrapper``, reads
network hyperparameters from ``config.yaml``, restores ``QuantumNet`` weights
from ``qdqn-FrozenLake-v1.pt`` onto the CPU, and wraps the network in an
``Agent``. The resulting objects (``env``, ``net``, ``agent``, ...) are exposed
as module-level globals.
"""
import gym
import yaml
import torch

from helpers.qnn import QuantumNet
from helpers.wrappers import BinaryWrapper
from helpers.agent import Agent

# Environment
env_name = 'FrozenLake-v1'
env = gym.make(env_name)
# NOTE(review): BinaryWrapper presumably re-encodes observations into the
# input format QuantumNet expects; confirm in helpers/wrappers.py.
env = BinaryWrapper(env)

# Network
# The hyperparameters must describe the same architecture the checkpoint was
# trained with; load_state_dict() is strict by default and raises on
# missing/unexpected keys or shape mismatches.
with open('config.yaml', 'r', encoding='utf-8') as f:
    hparams = yaml.safe_load(f)
net = QuantumNet(hparams['n_layers'])

# The checkpoint name is derived from env_name so the two cannot drift apart.
# map_location moves every tensor onto the CPU, so a checkpoint saved on a GPU
# machine still loads on a host without CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider passing weights_only=True on torch >= 1.13).
state_dict = torch.load(f'qdqn-{env_name}.pt', map_location=torch.device('cpu'))
net.load_state_dict(state_dict)

# Agent
agent = Agent(net)