tiny_clip / src /utils.py
sachin's picture
successful local run
c6fe3c5
raw
history blame
No virus
974 Bytes
import datetime
import pytorch_lightning as pl
from pytorch_lightning import loggers
from src import config
def _get_wandb_logger(trainer_config: config.TrainerConfig):
    """Build the Weights & Biases logger for one training run.

    The run is named ``<MODEL_NAME>-<timestamp>``, where the timestamp is
    ``str(datetime.datetime.now())``. When ``trainer_config.debug`` is set,
    the name gets a ``debug-`` prefix so debug runs are easy to spot in the
    project.

    Args:
        trainer_config: Trainer settings. Its ``debug`` flag controls the
            name prefix. ``_model_config.to_dict()`` is passed to the logger
            as the run's ``config``.

    Returns:
        A ``pytorch_lightning.loggers.WandbLogger`` for entity
        ``config.WANDB_ENTITY`` and project ``config.MODEL_NAME``, saving
        under ``config.WANDB_LOG_PATH``.
    """
    base_name = f"{config.MODEL_NAME}-{datetime.datetime.now()}"
    run_name = f"debug-{base_name}" if trainer_config.debug else base_name

    # NOTE(review): reaches into a private attribute of TrainerConfig; only
    # the model config (not the trainer settings) is logged to W&B.
    logged_config = trainer_config._model_config.to_dict()

    return loggers.WandbLogger(
        entity=config.WANDB_ENTITY,
        project=config.MODEL_NAME,
        name=run_name,
        save_dir=config.WANDB_LOG_PATH,
        config=logged_config,
    )
def get_trainer(trainer_config: config.TrainerConfig):
    """Create the Lightning ``Trainer`` for a training run.

    In debug mode the run becomes a quick smoke test: one epoch, with
    training and validation each limited to 5 batches. Otherwise the
    trainer runs ``trainer_config.epochs`` epochs over the full train and
    validation sets.

    Args:
        trainer_config: Trainer settings. This function reads ``debug``,
            ``epochs`` and ``log_every_n_steps``, and the whole object is
            also passed to ``_get_wandb_logger``.

    Returns:
        A ``pytorch_lightning.Trainer`` that logs to W&B, clips gradients at
        1.0 and picks the accelerator automatically.
    """
    if trainer_config.debug:
        epochs = 1
        # int -> absolute number of batches per epoch
        batch_limit = 5
    else:
        epochs = trainer_config.epochs
        # float -> fraction of the dataloader (1.0 == all batches). Lightning
        # distinguishes int vs float here, so keep this a float.
        batch_limit = 1.0

    return pl.Trainer(
        max_epochs=epochs,
        logger=_get_wandb_logger(trainer_config),
        log_every_n_steps=trainer_config.log_every_n_steps,
        gradient_clip_val=1.0,
        limit_train_batches=batch_limit,
        limit_val_batches=batch_limit,
        accelerator="auto",
    )