hishamcse committed on
Commit 608696b · verified · 1 Parent(s): a650346

Upload 2 files

Files changed (2)
  1. agent.py +66 -0
  2. config.cfg +43 -0
agent.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ import yaml
+ import json
+ import argparse
+ from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict
+ from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings
+ from stable_baselines3 import PPO
+
+ def main(cfg_file, trained_model):
+     # Read the cfg file
+     yaml_file = open(cfg_file)
+     params = yaml.load(yaml_file, Loader=yaml.FullLoader)
+     print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4))
+     yaml_file.close()
+
+     base_path = os.path.dirname(os.path.abspath(__file__))
+     model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"],
+                                 params["folders"]["model_name"], "model")
+
+     # Settings
+     params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE
+     settings = load_settings_flat_dict(EnvironmentSettings, params["settings"])
+     settings.role = Roles.P1
+
+     # Wrappers Settings
+     wrappers_settings = load_settings_flat_dict(WrappersSettings, params["wrappers_settings"])
+     wrappers_settings.normalize_reward = False
+
+     # Create environment
+     env, num_envs = make_sb3_env(settings.game_id, settings, wrappers_settings, no_vec=True)
+     print("Activated {} environment(s)".format(num_envs))
+
+     # Load the trained agent
+     model_path = os.path.join(model_folder, trained_model)
+     agent = PPO.load(model_path)
+
+     # Print policy network architecture
+     print("Policy architecture:")
+     print(agent.policy)
+
+     obs, info = env.reset()
+
+     while True:
+         action, _ = agent.predict(obs, deterministic=False)
+
+         obs, reward, terminated, truncated, info = env.step(action.tolist())
+
+         if terminated or truncated:
+             obs, info = env.reset()
+             if info["env_done"]:
+                 break
+
+     # Close the environment
+     env.close()
+
+     # Return success
+     return 0
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
+     parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint")
+     opt = parser.parse_args()
+     print(opt)
+
+     main(opt.cfgFile, opt.trainedModel)
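
Not part of the committed files, but for context: scripts like agent.py are normally launched through the DIAMBRA CLI, which starts the engine container for them. Assuming the CLI and the sfiii3n ROM are already set up, a typical invocation would be:

diambra run python agent.py --cfgFile config.cfg --trainedModel model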
config.cfg ADDED
@@ -0,0 +1,43 @@
+ folders:
+   parent_dir: "./results/"
+   model_name: "sr6_128x4_das_nc"
+
+ settings:
+   game_id: "sfiii3n"
+   step_ratio: 6
+   frame_shape: !!python/tuple [128, 128, 1]
+   continue_game: 0.0
+   action_space: "discrete"
+   characters: "Ken"
+   difficulty: 6
+   outfits: 2
+
+ wrappers_settings:
+   normalize_reward: true
+   no_attack_buttons_combinations: true
+   stack_frames: 4
+   dilation: 1
+   add_last_action: true
+   stack_actions: 12
+   scale: true
+   exclude_image_scaling: true
+   role_relative: true
+   flatten: true
+   filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"]
+
+ policy_kwargs:
+   #net_arch: [{ pi: [64, 64], vf: [32, 32] }]
+   net_arch: [64, 64]
+
+ ppo_settings:
+   gamma: 0.94
+   model_checkpoint: "100000" # 0: No checkpoint, 100000: Load checkpoint (if previously trained for 100000 steps)
+   learning_rate: [2.5e-4, 2.5e-6] # To start
+   clip_range: [0.15, 0.025] # To start
+   #learning_rate: [5.0e-5, 2.5e-6] # Fine Tuning
+   #clip_range: [0.075, 0.025] # Fine Tuning
+   batch_size: 512 #8 #nminibatches gave different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches
+   n_epochs: 4
+   n_steps: 512
+   autosave_freq: 10000
+   time_steps: 10000
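
The learning_rate and clip_range entries in ppo_settings are [start, end] pairs. The training script is not part of this commit, but in the usual DIAMBRA + Stable-Baselines3 setup such pairs are converted into linearly decaying schedules before being passed to PPO. A minimal sketch of that mapping, with the linear_schedule helper name being an assumption rather than code from this repository:

# Sketch only: maps a [start, end] pair from ppo_settings to an SB3 schedule.
# SB3 calls the schedule with progress_remaining, which goes from 1.0 at the
# start of training down to 0.0 at the end.
def linear_schedule(initial_value, final_value):
    def scheduler(progress_remaining):
        return final_value + progress_remaining * (initial_value - final_value)
    return scheduler

learning_rate = linear_schedule(2.5e-4, 2.5e-6)  # ppo_settings["learning_rate"]
clip_range = linear_schedule(0.15, 0.025)        # ppo_settings["clip_range"]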