hishamcse committed
Commit
cf3c783
1 Parent(s): 285b79d

Upload 3 files

Files changed (4)
  1. .gitattributes +1 -0
  2. agent.py +66 -0
  3. config.yaml +50 -0
  4. replay.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
agent.py ADDED
@@ -0,0 +1,66 @@
+ import os
+ import yaml
+ import json
+ import argparse
+ from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict
+ from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings
+ from stable_baselines3 import PPO
+
+ def main(cfg_file, trained_model):
+     # Read the cfg file
+     yaml_file = open(cfg_file)
+     params = yaml.load(yaml_file, Loader=yaml.FullLoader)
+     print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4))
+     yaml_file.close()
+
+     base_path = os.path.dirname(os.path.abspath(__file__))
+     model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"],
+                                 params["folders"]["model_name"], "model")
+
+     # Settings
+     params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE
+     settings = load_settings_flat_dict(EnvironmentSettings, params["settings"])
+     settings.role = Roles.P1
+
+     # Wrappers Settings
+     wrappers_settings = load_settings_flat_dict(WrappersSettings, params["wrappers_settings"])
+     wrappers_settings.normalize_reward = False
+
+     # Create environment
+     env, num_envs = make_sb3_env(settings.game_id, settings, wrappers_settings, no_vec=True)
+     print("Activated {} environment(s)".format(num_envs))
+
+     # Load the trained agent
+     model_path = os.path.join(model_folder, trained_model)
+     agent = PPO.load(model_path)
+
+     # Print policy network architecture
+     print("Policy architecture:")
+     print(agent.policy)
+
+     obs, info = env.reset()
+
+     while True:
+         action, _ = agent.predict(obs, deterministic=False)
+
+         obs, reward, terminated, truncated, info = env.step(action.tolist())
+
+         if terminated or truncated:
+             obs, info = env.reset()
+             if info["env_done"]:
+                 break
+
+     # Close the environment
+     env.close()
+
+     # Return success
+     return 0
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
+     parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint")
+     opt = parser.parse_args()
+     print(opt)
+
+     main(opt.cfgFile, opt.trainedModel)
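
Note: agent.py talks to a running DIAMBRA engine, so it is normally launched through the DIAMBRA CLI rather than plain python, e.g. `diambra run python agent.py --cfgFile config.yaml`. For orientation, the checkpoint path that main() builds from the config.yaml added below resolves roughly as in this sketch (values hard-coded here for illustration; the script itself reads them from the YAML):

    import os

    # Folder values taken from config.yaml (added below in this commit)
    parent_dir = "./results/"
    game_id = "umk3"
    model_name = "sr6_128x4_das_nc"
    trained_model = "model"  # default of --trainedModel

    base_path = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(base_path, parent_dir, game_id, model_name, "model", trained_model)
    # -> <repo>/results/umk3/sr6_128x4_das_nc/model/model (SB3 appends .zip if missing)
    print(model_path)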
config.yaml ADDED
@@ -0,0 +1,50 @@
+ folders:
+   parent_dir: "./results/"
+   model_name: "sr6_128x4_das_nc"
+
+ settings:
+   game_id: "umk3"
+   step_ratio: 6
+   frame_shape: !!python/tuple [128, 128, 1]
+   continue_game: 0.0
+   action_space: "discrete"
+   characters: "Skorpion"
+   difficulty: 5
+
+ wrappers_settings:
+   normalize_reward: true
+   no_attack_buttons_combinations: true
+   stack_frames: 4
+   dilation: 1
+   add_last_action: true
+   stack_actions: 12
+   scale: true
+   exclude_image_scaling: true
+   role_relative: true
+   flatten: true
+   filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"]
+
+ # optuna results
+ # Best hyperparameters: {'gamma': 0.05944028113410932, 'max_grad_norm': 3.5407661656818026,
+ # 'exponent_n_steps': 5, 'n_epochs': 14, 'batch_size': 512, 'lr': 0.014638860976621421,
+ # 'ent_coef': 2.361611947920214e-06, 'clip_range': 0.3, 'gae_lambda': 0.9520674913500098,
+ # 'vf_coef': 0.6420316461542878, 'net_arch': 'medium', 'activation_fn': 'leaky_relu'}
+
+ policy_kwargs:
+   #net_arch: [{ pi: [64, 64], vf: [32, 32] }]
+   net_arch: [256, 256]
+   activation_fn: "leaky_relu"
+
+ ppo_settings:
+   gamma: 0.94
+   model_checkpoint: "660000" # 0: No checkpoint, else: Load checkpoint (if previously trained)
+   learning_rate: [1.0e-3, 2.5e-6] # To start
+   clip_range: [0.3, 0.015] # To start
+   batch_size: 512 #8 #nminibatches gave different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches
+   n_epochs: 14
+   n_steps: 512
+   gae_lambda: 0.9520674913500098
+   ent_coef: 2.361611947920214e-06
+   vf_coef: 0.6420316461542878
+   autosave_freq: 50000
+   time_steps: 1000000
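
Note: in ppo_settings, the two-element lists for learning_rate and clip_range read as (initial, final) pairs, per the "# To start" comments. The companion training script is not part of this commit, but a typical SB3 setup anneals such pairs linearly over training; a minimal sketch, assuming a helper named linear_schedule:

    # Minimal sketch of a linear annealing schedule for SB3 (assumed helper,
    # not shown in this commit). SB3 calls the schedule with progress_remaining,
    # which runs from 1.0 at the start of training down to 0.0 at the end.
    def linear_schedule(initial_value: float, final_value: float):
        def schedule(progress_remaining: float) -> float:
            return final_value + progress_remaining * (initial_value - final_value)
        return schedule

    learning_rate = linear_schedule(1.0e-3, 2.5e-6)  # config: learning_rate: [1.0e-3, 2.5e-6]
    clip_range = linear_schedule(0.3, 0.015)         # config: clip_range: [0.3, 0.015]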
replay.mp4 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:01dbbd2ee0288f38b4d7c41e2438878fa5d2aedb77ec251c0fc97eadfd852dbc
+ size 7377696
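
Note: the three lines above are the Git LFS pointer stored in the repository; together with the new .gitattributes rule, the actual video (about 7.4 MB) lives in LFS storage and is fetched on checkout.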