|
from super_gradients.common.object_names import Models |
|
from super_gradients.training import Trainer |
|
from super_gradients.training import models |
|
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, coco2017_val_yolo_nas |
|
from super_gradients.training.losses import PPYoloELoss |
|
from super_gradients.training.metrics import DetectionMetrics_050 |
|
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback |
|
from super_gradients.training.utils.distributed_training_utils import setup_device |
|
|
|
# Number of object classes in the dataset. The model head, the loss and the
# validation metric must all be built with the same value, so it is defined
# once here instead of being repeated as a literal in three places.
NUM_CLASSES = 2


def _dataset_params():
    """Return a fresh dict locating the COCO-format dataset on disk.

    A new dict is built on every call so the train and validation loader
    factories never receive (and can never mutate) the same object.
    """
    # NOTE(review): train and validation both use this exact directory and
    # annotation file, so validation mAP appears to be measured on the
    # training images. TODO confirm the intended split (e.g. separate
    # train/ and valid/ annotation files) before trusting the metric.
    return {
        'data_dir': 'dataset',
        'subdir': '.',
        'json_file': '_annotations.coco.json',
    }


def _dataloader_params():
    """Return a fresh dict of DataLoader settings shared by train and val."""
    return {
        'batch_size': 2,
        'num_workers': 0,
    }


def _build_train_params():
    """Return the ``training_params`` dict passed to ``Trainer.train``.

    The loss and validation metric are sized with ``NUM_CLASSES`` so they
    stay consistent with the model head built in ``main``.
    """
    return {
        'average_best_models': True,
        'max_epochs': 100,
        'initial_lr': 5e-4,
        'warmup_mode': 'LinearEpochLRWarmup',
        'warmup_initial_lr': 1e-6,
        'lr_warmup_epochs': 3,
        'lr_mode': 'cosine',
        'cosine_final_lr_ratio': 0.1,
        'optimizer': 'Adam',
        'optimizer_params': {
            'weight_decay': 0.0001,
        },
        'zero_weight_decay_on_bias_and_bn': True,
        'ema': True,
        'ema_params': {
            'decay': 0.9,
            'decay_type': 'threshold',
        },
        'mixed_precision': False,
        'loss': PPYoloELoss(
            num_classes=NUM_CLASSES,
            reg_max=16,
        ),
        'valid_metrics_list': [
            DetectionMetrics_050(
                score_thres=0.1,
                top_k_predictions=100,
                num_cls=NUM_CLASSES,
                normalize_targets=True,
                post_prediction_callback=PPYoloEPostPredictionCallback(
                    score_threshold=0.01,
                    nms_top_k=1000,
                    max_predictions=100,
                    nms_threshold=0.7,
                ),
            ),
        ],
        # Must name a metric reported by one of valid_metrics_list; presumably
        # 'mAP@0.50' is the name DetectionMetrics_050 reports — confirm if the
        # metric class is changed.
        'metric_to_watch': 'mAP@0.50',
    }


def main():
    """Fine-tune a COCO-pretrained YOLO-NAS-S for the 'yolo-nas-plastic' run."""
    model = models.get(
        Models.YOLO_NAS_S,
        pretrained_weights='coco',
        num_classes=NUM_CLASSES,
    )

    trainer = Trainer(
        experiment_name='yolo-nas-plastic',
        ckpt_root_dir='checkpoints/sg_yolonas',
    )

    train_dataloader = coco2017_train_yolo_nas(
        dataset_params=_dataset_params(),
        dataloader_params=_dataloader_params(),
    )

    valid_dataloader = coco2017_val_yolo_nas(
        dataset_params=_dataset_params(),
        dataloader_params=_dataloader_params(),
    )

    trainer.train(
        model=model,
        training_params=_build_train_params(),
        train_loader=train_dataloader,
        valid_loader=valid_dataloader,
    )


if __name__ == '__main__':
    main()