example-model / tmr /config.json
athn-nik's picture
fixes and retrieval model
2b84cc7
raw
history blame
No virus
3.85 kB
{
"data": {
"motion_loader": {
"_target_": "src.data.amass_motion.AMASSMotionLoader_custom",
"base_dir": "datasets/motions/amass_feats",
"normalizer": {
"_target_": "src.data.motion.Normalizer",
"base_dir": "stats/humanml3d/amass_feats",
"eps": 1e-12
},
"nfeats": 135
},
"_target_": "src.data.text_motion.TextMotionDataset",
"path": "datasets/annotations/humanml3d",
"text_to_token_emb": {
"_target_": "src.data.text.TokenEmbeddings",
"path": "datasets/annotations/humanml3d",
"modelname": "distilbert-base-uncased",
"preload": true
},
"text_to_sent_emb": {
"_target_": "src.data.text.SentenceEmbeddings",
"path": "datasets/annotations/humanml3d",
"modelname": "sentence-transformers/all-mpnet-base-v2",
"preload": true
},
"preload": true
},
"model": {
"_target_": "src.model.TMR",
"motion_encoder": {
"_target_": "src.model.ACTORStyleEncoder",
"nfeats": 135,
"vae": true,
"latent_dim": 256,
"ff_size": 1024,
"num_layers": 6,
"num_heads": 4,
"dropout": 0.1,
"activation": "gelu"
},
"text_encoder": {
"_target_": "src.model.ACTORStyleEncoder",
"nfeats": 768,
"vae": true,
"latent_dim": 256,
"ff_size": 1024,
"num_layers": 6,
"num_heads": 4,
"dropout": 0.1,
"activation": "gelu"
},
"motion_decoder": {
"_target_": "src.model.ACTORStyleDecoder",
"nfeats": 135,
"latent_dim": 256,
"ff_size": 1024,
"num_layers": 6,
"num_heads": 4,
"dropout": 0.1,
"activation": "gelu"
},
"vae": true,
"lmd": {
"recons": 1.0,
"latent": 1e-05,
"kl": 1e-05,
"contrastive": 0.1
},
"lr": 0.0001,
"temperature": 0.1,
"threshold_selfsim": 0.8,
"threshold_selfsim_metrics": 0.95
},
"trainer": {
"_target_": "pytorch_lightning.Trainer",
"max_epochs": 1000,
"log_every_n_steps": 50,
"num_sanity_val_steps": 0,
"check_val_every_n_epoch": 1,
"accelerator": "gpu",
"devices": 1,
"callbacks": [
{
"_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
"filename": "latest-{epoch}",
"every_n_epochs": 1,
"save_top_k": 1,
"save_last": true
},
{
"_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
"filename": "latest-{epoch}",
"monitor": "step",
"mode": "max",
"every_n_epochs": 100,
"save_top_k": -1,
"save_last": false
},
{
"_target_": "src.callback.progress.ProgressLogger",
"precision": 3
},
{
"_target_": "src.callback.tqdmbar.TQDMProgressBar"
}
],
"logger": {
"_target_": "src.logger.csv.CSVLogger",
"save_dir": "outputs/tmr_humanml3d_amass_feats",
"name": "logs"
}
},
"run_dir": "outputs/tmr_humanml3d_amass_feats",
"seed": 1234,
"logger_level": "INFO",
"ckpt": "last",
"resume_dir": null,
"dataloader": {
"_target_": "torch.utils.data.DataLoader",
"batch_size": 32,
"num_workers": 8
}
}