Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
""" | |
import datetime | |
import json | |
import time | |
from pathlib import Path | |
import torch | |
import torch.nn as nn | |
from ..misc import dist_utils | |
from ._solver import BaseSolver | |
from .clas_engine import evaluate, train_one_epoch | |
class ClasSolver(BaseSolver): | |
def fit( | |
self, | |
): | |
print("Start training") | |
self.train() | |
args = self.cfg | |
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad) | |
print("Number of params:", n_parameters) | |
output_dir = Path(args.output_dir) | |
output_dir.mkdir(exist_ok=True) | |
start_time = time.time() | |
start_epoch = self.last_epoch + 1 | |
for epoch in range(start_epoch, args.epochs): | |
if dist_utils.is_dist_available_and_initialized(): | |
self.train_dataloader.sampler.set_epoch(epoch) | |
train_stats = train_one_epoch( | |
self.model, | |
self.criterion, | |
self.train_dataloader, | |
self.optimizer, | |
self.ema, | |
epoch=epoch, | |
device=self.device, | |
) | |
self.lr_scheduler.step() | |
self.last_epoch += 1 | |
if output_dir: | |
checkpoint_paths = [output_dir / "checkpoint.pth"] | |
# extra checkpoint before LR drop and every 100 epochs | |
if (epoch + 1) % args.checkpoint_freq == 0: | |
checkpoint_paths.append(output_dir / f"checkpoint{epoch:04}.pth") | |
for checkpoint_path in checkpoint_paths: | |
dist_utils.save_on_master(self.state_dict(epoch), checkpoint_path) | |
module = self.ema.module if self.ema else self.model | |
test_stats = evaluate(module, self.criterion, self.val_dataloader, self.device) | |
log_stats = { | |
**{f"train_{k}": v for k, v in train_stats.items()}, | |
**{f"test_{k}": v for k, v in test_stats.items()}, | |
"epoch": epoch, | |
"n_parameters": n_parameters, | |
} | |
if output_dir and dist_utils.is_main_process(): | |
with (output_dir / "log.txt").open("a") as f: | |
f.write(json.dumps(log_stats) + "\n") | |
total_time = time.time() - start_time | |
total_time_str = str(datetime.timedelta(seconds=int(total_time))) | |
print("Training time {}".format(total_time_str)) | |