Upload 2 files
Browse files- agent.py +66 -0
- config.cfg +43 -0
agent.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict
|
6 |
+
from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings
|
7 |
+
from stable_baselines3 import PPO
|
8 |
+
|
9 |
+
def main(cfg_file, trained_model):
|
10 |
+
# Read the cfg file
|
11 |
+
yaml_file = open(cfg_file)
|
12 |
+
params = yaml.load(yaml_file, Loader=yaml.FullLoader)
|
13 |
+
print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4))
|
14 |
+
yaml_file.close()
|
15 |
+
|
16 |
+
base_path = os.path.dirname(os.path.abspath(__file__))
|
17 |
+
model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"],
|
18 |
+
params["folders"]["model_name"], "model")
|
19 |
+
|
20 |
+
# Settings
|
21 |
+
params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE
|
22 |
+
settings = load_settings_flat_dict(EnvironmentSettings, params["settings"])
|
23 |
+
settings.role = Roles.P1
|
24 |
+
|
25 |
+
# Wrappers Settings
|
26 |
+
wrappers_settings = load_settings_flat_dict(WrappersSettings, params["wrappers_settings"])
|
27 |
+
wrappers_settings.normalize_reward = False
|
28 |
+
|
29 |
+
# Create environment
|
30 |
+
env, num_envs = make_sb3_env(settings.game_id, settings, wrappers_settings, no_vec=True)
|
31 |
+
print("Activated {} environment(s)".format(num_envs))
|
32 |
+
|
33 |
+
# Load the trained agent
|
34 |
+
model_path = os.path.join(model_folder, trained_model)
|
35 |
+
agent = PPO.load(model_path)
|
36 |
+
|
37 |
+
# Print policy network architecture
|
38 |
+
print("Policy architecture:")
|
39 |
+
print(agent.policy)
|
40 |
+
|
41 |
+
obs, info = env.reset()
|
42 |
+
|
43 |
+
while True:
|
44 |
+
action, _ = agent.predict(obs, deterministic=False)
|
45 |
+
|
46 |
+
obs, reward, terminated, truncated, info = env.step(action.tolist())
|
47 |
+
|
48 |
+
if terminated or truncated:
|
49 |
+
obs, info = env.reset()
|
50 |
+
if info["env_done"]:
|
51 |
+
break
|
52 |
+
|
53 |
+
# Close the environment
|
54 |
+
env.close()
|
55 |
+
|
56 |
+
# Return success
|
57 |
+
return 0
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
|
62 |
+
parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint")
|
63 |
+
opt = parser.parse_args()
|
64 |
+
print(opt)
|
65 |
+
|
66 |
+
main(opt.cfgFile, opt.trainedModel)
|
config.cfg
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
folders:
|
2 |
+
parent_dir: "./results/"
|
3 |
+
model_name: "sr6_128x4_das_nc"
|
4 |
+
|
5 |
+
settings:
|
6 |
+
game_id: "sfiii3n"
|
7 |
+
step_ratio: 6
|
8 |
+
frame_shape: !!python/tuple [128, 128, 1]
|
9 |
+
continue_game: 0.0
|
10 |
+
action_space: "discrete"
|
11 |
+
characters: "Ken"
|
12 |
+
difficulty: 6
|
13 |
+
outfits: 2
|
14 |
+
|
15 |
+
wrappers_settings:
|
16 |
+
normalize_reward: true
|
17 |
+
no_attack_buttons_combinations: true
|
18 |
+
stack_frames: 4
|
19 |
+
dilation: 1
|
20 |
+
add_last_action: true
|
21 |
+
stack_actions: 12
|
22 |
+
scale: true
|
23 |
+
exclude_image_scaling: true
|
24 |
+
role_relative: true
|
25 |
+
flatten: true
|
26 |
+
filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"]
|
27 |
+
|
28 |
+
policy_kwargs:
|
29 |
+
#net_arch: [{ pi: [64, 64], vf: [32, 32] }]
|
30 |
+
net_arch: [64, 64]
|
31 |
+
|
32 |
+
ppo_settings:
|
33 |
+
gamma: 0.94
|
34 |
+
model_checkpoint: "100000" # 0: No checkpoint, 100000: Load checkpoint (if previously trained for 100000 steps)
|
35 |
+
learning_rate: [2.5e-4, 2.5e-6] # To start
|
36 |
+
clip_range: [0.15, 0.025] # To start
|
37 |
+
#learning_rate: [5.0e-5, 2.5e-6] # Fine Tuning
|
38 |
+
#clip_range: [0.075, 0.025] # Fine Tuning
|
39 |
+
batch_size: 512 #8 #nminibatches gave different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches
|
40 |
+
n_epochs: 4
|
41 |
+
n_steps: 512
|
42 |
+
autosave_freq: 10000
|
43 |
+
time_steps: 10000
|