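"""Train a model with Hydra-configured components on Lightning Fabric.

Seeds everything, builds data/model/optimizer/scheduler from the config,
runs the epoch loop with optional validation and checkpointing, then
evaluates on each configured test set.
"""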
import datetime
import shutil
import time

import hydra
import lightning as L
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

from src.test.utils import evaluate
from src.tools.files import json_dump
from src.tools.utils import calculate_model_params


# NOTE: config_path/config_name are assumptions here; point them at the
# repo's actual Hydra config. Without @hydra.main, the bare main() call at
# the bottom of the file would fail, since main expects a cfg argument.
@hydra.main(version_base=None, config_path="configs", config_name="train")
def main(cfg: DictConfig):
    L.seed_everything(cfg.seed, workers=True)
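
    # Build the Fabric launcher (accelerator, strategy, precision, loggers)
    # from the config and spawn the distributed processes.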
    fabric = instantiate(cfg.trainer.fabric)
    fabric.launch()
    fabric.logger.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))
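    # Persist the resolved config once, from rank 0 only, to avoid duplicate writes.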
    if fabric.global_rank == 0:
        json_dump(OmegaConf.to_container(cfg, resolve=True), "hydra.json")
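
    # Instantiate the datamodule and let Fabric wrap the loaders
    # (device placement and distributed sampling).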
    data = instantiate(cfg.data)
    loader_train = fabric.setup_dataloaders(data.train_dataloader())
    if cfg.val:
        loader_val = fabric.setup_dataloaders(data.val_dataloader())
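
    # Build the model and report its parameter count.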
    model = instantiate(cfg.model)
    calculate_model_params(model)
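    # _partial_=False makes Hydra build the optimizer immediately, bound to the model's parameters.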
    optimizer = instantiate(
        cfg.model.optimizer, params=model.parameters(), _partial_=False
    )
    model, optimizer = fabric.setup(model, optimizer)
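    # The scheduler is a plain callable invoked once per epoch (see the loop below).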
    scheduler = instantiate(cfg.model.scheduler)

    fabric.print("Start training")
    start_time = time.time()
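    # Main loop: step the LR schedule, train one epoch, optionally validate, then checkpoint.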
    for epoch in range(cfg.trainer.max_epochs):
        scheduler(optimizer, epoch)
        columns = shutil.get_terminal_size().columns
        fabric.print("-" * columns)
        fabric.print(f"Epoch {epoch + 1}/{cfg.trainer.max_epochs}".center(columns))
        train(model, loader_train, optimizer, fabric, epoch, cfg)
        if cfg.val:
            fabric.print("Evaluate")
            evaluate(model, loader_val, fabric=fabric)
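        # Bundle everything needed to resume; fabric.save stores the state dicts of wrapped objects.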
        state = {
            "epoch": epoch,
            "model": model,
            "optimizer": optimizer,
            "scheduler": scheduler,
        }
        if cfg.trainer.save_ckpt == "all":
            fabric.save(f"ckpt_{epoch}.ckpt", state)
        elif cfg.trainer.save_ckpt == "last":
            fabric.save("ckpt_last.ckpt", state)
        fabric.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    fabric.print(f"Training time {total_time_str}")
    for dataset in cfg.test:
        columns = shutil.get_terminal_size().columns
        fabric.print("-" * columns)
        fabric.print(f"Testing on {cfg.test[dataset].dataname}".center(columns))
        data = instantiate(cfg.test[dataset])
        test_loader = fabric.setup_dataloaders(data.test_dataloader())
        test = instantiate(cfg.test[dataset].test)
        test(model, test_loader, fabric=fabric)
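
    # Close out the logger run and stamp the wall-clock finish time.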
    fabric.logger.finalize("success")
    fabric.print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))


def train(model, train_loader, optimizer, fabric, epoch, cfg):
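    """Train for one epoch: forward pass, fabric.backward, optimizer step, periodic logging."""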
    model.train()
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()
        loss = model(batch, fabric)
        fabric.backward(loss)
        optimizer.step()
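        # Console progress at print_interval; metrics go to the logger at log_interval.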
        if batch_idx % cfg.trainer.print_interval == 0:
            fabric.print(
                f"[{100.0 * batch_idx / len(train_loader):.0f}%]\tLoss: {loss.item():.6f}"
            )
        if batch_idx % cfg.trainer.log_interval == 0:
            fabric.log_dict(
                {
                    "loss": loss.item(),
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch": epoch,
                }
            )


if __name__ == "__main__":
    main()