Upload 3 files

- .gitattributes +1 -0
- agent.py +66 -0
- config.yaml +50 -0
- replay.mp4 +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+replay.mp4 filter=lfs diff=lfs merge=lfs -text
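With this rule in place, Git LFS intercepts replay.mp4: the repository itself stores only a small pointer file, while the video bytes live in LFS storage (the pointer appears at the end of this diff).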
agent.py ADDED
@@ -0,0 +1,66 @@
+import os
+import yaml
+import json
+import argparse
+from diambra.arena import Roles, SpaceTypes, load_settings_flat_dict
+from diambra.arena.stable_baselines3.make_sb3_env import make_sb3_env, EnvironmentSettings, WrappersSettings
+from stable_baselines3 import PPO
+
+def main(cfg_file, trained_model):
+    # Read the cfg file
+    yaml_file = open(cfg_file)
+    params = yaml.load(yaml_file, Loader=yaml.FullLoader)
+    print("Config parameters = ", json.dumps(params, sort_keys=True, indent=4))
+    yaml_file.close()
+
+    base_path = os.path.dirname(os.path.abspath(__file__))
+    model_folder = os.path.join(base_path, params["folders"]["parent_dir"], params["settings"]["game_id"],
+                                params["folders"]["model_name"], "model")
+
+    # Settings
+    params["settings"]["action_space"] = SpaceTypes.DISCRETE if params["settings"]["action_space"] == "discrete" else SpaceTypes.MULTI_DISCRETE
+    settings = load_settings_flat_dict(EnvironmentSettings, params["settings"])
+    settings.role = Roles.P1
+
+    # Wrappers Settings
+    wrappers_settings = load_settings_flat_dict(WrappersSettings, params["wrappers_settings"])
+    wrappers_settings.normalize_reward = False
+
+    # Create environment
+    env, num_envs = make_sb3_env(settings.game_id, settings, wrappers_settings, no_vec=True)
+    print("Activated {} environment(s)".format(num_envs))
+
+    # Load the trained agent
+    model_path = os.path.join(model_folder, trained_model)
+    agent = PPO.load(model_path)
+
+    # Print policy network architecture
+    print("Policy architecture:")
+    print(agent.policy)
+
+    obs, info = env.reset()
+
+    while True:
+        action, _ = agent.predict(obs, deterministic=False)
+
+        obs, reward, terminated, truncated, info = env.step(action.tolist())
+
+        if terminated or truncated:
+            obs, info = env.reset()
+            if info["env_done"]:
+                break
+
+    # Close the environment
+    env.close()
+
+    # Return success
+    return 0
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--cfgFile", type=str, required=True, help="Configuration file")
+    parser.add_argument("--trainedModel", type=str, default="model", help="Model checkpoint")
+    opt = parser.parse_args()
+    print(opt)
+
+    main(opt.cfgFile, opt.trainedModel)
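A usage note, not part of the commit: DIAMBRA Arena environments need the game engine running alongside the script, so this evaluation would typically be launched through the DIAMBRA CLI rather than plain Python, along the lines of `diambra run python agent.py --cfgFile config.yaml` (the `--cfgFile` and `--trainedModel` flags are the ones agent.py defines above; the exact checkpoint name depends on what was saved under the `model` folder). The loop runs episodes until the environment reports `info["env_done"]`, at which point it exits and closes the environment.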
config.yaml ADDED
@@ -0,0 +1,50 @@
+folders:
+  parent_dir: "./results/"
+  model_name: "sr6_128x4_das_nc"
+
+settings:
+  game_id: "umk3"
+  step_ratio: 6
+  frame_shape: !!python/tuple [128, 128, 1]
+  continue_game: 0.0
+  action_space: "discrete"
+  characters: "Skorpion"
+  difficulty: 5
+
+wrappers_settings:
+  normalize_reward: true
+  no_attack_buttons_combinations: true
+  stack_frames: 4
+  dilation: 1
+  add_last_action: true
+  stack_actions: 12
+  scale: true
+  exclude_image_scaling: true
+  role_relative: true
+  flatten: true
+  filter_keys: ["action", "own_health", "opp_health", "own_side", "opp_side", "opp_character", "stage", "timer"]
+
+# optuna results
+# Best hyperparameters: {'gamma': 0.05944028113410932, 'max_grad_norm': 3.5407661656818026,
+# 'exponent_n_steps': 5, 'n_epochs': 14, 'batch_size': 512, 'lr': 0.014638860976621421,
+# 'ent_coef': 2.361611947920214e-06, 'clip_range': 0.3, 'gae_lambda': 0.9520674913500098,
+# 'vf_coef': 0.6420316461542878, 'net_arch': 'medium', 'activation_fn': 'leaky_relu'}
+
+policy_kwargs:
+  #net_arch: [{ pi: [64, 64], vf: [32, 32] }]
+  net_arch: [256, 256]
+  activation_fn: "leaky_relu"
+
+ppo_settings:
+  gamma: 0.94
+  model_checkpoint: "660000" # 0: No checkpoint, else: Load checkpoint (if previously trained)
+  learning_rate: [1.0e-3, 2.5e-6] # To start
+  clip_range: [0.3, 0.015] # To start
+  batch_size: 512 # nminibatches gave a different batch size depending on the number of environments: batch_size = (n_steps * n_envs) // nminibatches
+  n_epochs: 14
+  n_steps: 512
+  gae_lambda: 0.9520674913500098
+  ent_coef: 2.361611947920214e-06
+  vf_coef: 0.6420316461542878
+  autosave_freq: 50000
+  time_steps: 1000000
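The two-element lists under ppo_settings (learning_rate, clip_range) suggest scheduled values rather than constants. The training script is not part of this commit, so the following is only a minimal sketch of how such [start, end] pairs are commonly turned into the callable schedules Stable-Baselines3 accepts (a function of the remaining training progress, which goes from 1.0 down to 0.0); the linear_schedule helper is hypothetical:

def linear_schedule(start: float, end: float):
    # Stable-Baselines3 calls the schedule with the remaining progress,
    # which decreases linearly from 1.0 (start of training) to 0.0 (end)
    def schedule(progress_remaining: float) -> float:
        return end + progress_remaining * (start - end)
    return schedule

learning_rate = linear_schedule(1.0e-3, 2.5e-6)  # ppo_settings["learning_rate"]
clip_range = linear_schedule(0.3, 0.015)         # ppo_settings["clip_range"]

# These callables would then be passed to the PPO constructor,
# e.g. PPO(..., learning_rate=learning_rate, clip_range=clip_range)
print(learning_rate(1.0), learning_rate(0.0))  # 0.001 2.5e-06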
replay.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01dbbd2ee0288f38b4d7c41e2438878fa5d2aedb77ec251c0fc97eadfd852dbc
+size 7377696
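As is standard for LFS-tracked files, the diff records only the three-line pointer: the LFS spec version, the SHA-256 object id of the video, and its size in bytes (7377696, about 7.4 MB). The actual replay video is fetched from LFS storage on checkout.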