{
  "name": "alien_iris_world_model",
  "env": "AlienNoFrameskip-v4",
  "model_type": "iris",
  "metadata": {
    "latent_dim": [1, 16],
    "num_tokens": 340
  },
  "util_folders": {
    "models": "../../src/models"
  },
  "requirements": {
    "-r": "requirements.txt"
  },
  "models": [
    {
      "name": "world_model",
      "framework": null,
      "format": "state_dict",
      "source": {
        "weights_path": "world_model.pt",
        "class_path": "../../src/world_model.py",
        "class_name": "WorldModel",
        "class_args": [
          {
            "vocab_size": 512,
            "act_vocab_size": 18,
            "tokens_per_block": 17,
            "max_blocks": 20,
            "attention": "causal",
            "num_layers": 10,
            "num_heads": 4,
            "embed_dim": 256,
            "embed_pdrop": 0.1,
            "resid_pdrop": 0.1,
            "attn_pdrop": 0.1
          }
        ]
      },
      "signature": {
        "inputs": ["tokens", "past_keys_values"],
        "call_mode": "positional"
      },
      "sub_models": [
        {
          "name": "transformer",
          "sub_model_name": "transformer",
          "signature": {
            "inputs": ["sequences", "past_keys_values"],
            "call_mode": "positional"
          }
        }
      ],
      "methods": [
        {
          "name": "generate_empty_keys_values",
          "method_name": "generate_empty_keys_values",
          "signature": {
            "inputs": ["n"]
          }
        }
      ]
    },
    {
      "name": "tokenizer",
      "framework": null,
      "format": "state_dict",
      "source": {
        "weights_path": "tokenizer.pt",
        "class_path": "../../src/tokenizer.py",
        "class_name": "Tokenizer",
        "class_args": [
          {
            "vocab_size": 512,
            "embed_dim": 512,
            "encoder": {
              "resolution": 64,
              "in_channels": 3,
              "z_channels": 512,
              "ch": 64,
              "ch_mult": [1, 1, 1, 1, 1],
              "num_res_blocks": 2,
              "attn_resolutions": [8, 16],
              "out_ch": 3,
              "dropout": 0.0
            },
            "decoder": {
              "resolution": 64,
              "in_channels": 3,
              "z_channels": 512,
              "ch": 64,
              "ch_mult": [1, 1, 1, 1, 1],
              "num_res_blocks": 2,
              "attn_resolutions": [8, 16],
              "out_ch": 3,
              "dropout": 0.0
            }
          }
        ]
      },
      "signature": {
        "inputs": ["x", "should_preprocess", "should_postprocess"],
        "call_mode": "positional"
      },
      "sub_models": [
        {
          "name": "embedding",
          "sub_model_name": "embedding",
          "signature": {
            "call_mode": "auto"
          }
        }
      ],
      "methods": [
        {
          "name": "decode",
          "method_name": "decode",
          "signature": {
            "inputs": ["z", "should_postprocess"]
          }
        },
        {
          "name": "decode_obs_tokens",
          "method_name": "decode_obs_tokens",
          "signature": {
            "inputs": ["obs_tokens", "num_observations_tokens"]
          }
        },
        {
          "name": "encode",
          "method_name": "encode",
          "signature": {
            "inputs": ["observations", "should_preprocess"]
          }
        }
      ]
    }
  ]
}