# taim-gan / src /config.py
# Dmmc's picture
# three-model version
# c8ddb9b
# raw
# history blame contribute delete
# No virus
# 1.08 kB
"""Configurations for the project."""
from pathlib import Path
from typing import Any, Dict
import torch
# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory two levels above this file, as an absolute path.
repo_path: Path = Path(__file__).parent.parent.absolute()
# "models" directory under repo_path; exposed below as the "output_dir" entry.
output_path: Path = repo_path.joinpath("models")

# Default project settings, grouped into flat entries plus two nested
# sub-dictionaries ("lr_config" and "const_dict").
config_dict: Dict[str, Any] = {
    # NOTE(review): presumably model / embedding dimensions -- confirm
    # against the model definitions.
    "Ng": 32,
    "D": 256,
    "condition_dim": 100,
    "noise_dim": 100,
    # Learning rates, one entry per "*_lr" key.
    "lr_config": {
        "disc_lr": 2e-4,
        "gen_lr": 2e-4,
        "img_encoder_lr": 3e-3,
        "text_encoder_lr": 3e-3,
    },
    "batch_size": 64,
    "device": device,
    "epochs": 200,
    "output_dir": output_path,
    # NOTE(review): presumably a checkpoint interval -- confirm against the
    # training loop.
    "snapshot": 5,
    # NOTE(review): presumably loss weights and smoothing constants --
    # confirm against the loss code.
    "const_dict": {
        "smooth_val_gen": 0.999,
        "lambda1": 1,
        "lambda2": 1,
        "lambda3": 1,
        "lambda4": 1,
        "gamma1": 4,
        "gamma2": 5,
        "gamma3": 10,
    },
}
def update_config(cfg_dict: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    """Overwrite entries of a configuration dictionary in place.

    Each keyword argument is stored under its own name in ``cfg_dict``,
    adding the key if it is missing and replacing the value otherwise.
    The update is shallow: passing a nested entry (for example a whole
    sub-dictionary) replaces that entry outright rather than merging
    into it.

    Args:
        cfg_dict: Configuration dictionary to modify. It is mutated.
        **kwargs: Key/value pairs to write into ``cfg_dict``.

    Returns:
        The same ``cfg_dict`` object that was passed in, after the update.
    """
    # dict.update performs the same key-by-key assignment as a manual loop.
    cfg_dict.update(kwargs)
    return cfg_dict