"""Fine-tune a COCO-pretrained YOLO-NAS-S detector on a custom COCO-format dataset.

Annotations are read with SuperGradients' ``coco2017_train_yolo_nas`` /
``coco2017_val_yolo_nas`` dataloader factories, overriding the data location,
and the model is trained with the PP-YOLOE loss and mAP@0.50 validation metric.
"""
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, coco2017_val_yolo_nas
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training.utils.distributed_training_utils import setup_device

# Number of object classes. The model head, the loss and the validation metric
# must all agree on this value, so it is defined exactly once here.
# NOTE(review): if this count includes a dataset-level parent/"background"
# category (as some COCO exports do), confirm that is intended.
NUM_CLASSES = 2

EXPERIMENT_NAME = 'yolo-nas-plastic'
CHECKPOINT_ROOT_DIR = 'checkpoints/sg_yolonas'

# Dataset location overrides passed to the COCO YOLO-NAS dataloader factories.
DATA_DIR = 'dataset'
ANNOTATIONS_FILE = '_annotations.coco.json'
TRAIN_SUBDIR = '.'
# NOTE(review): validation currently reads the same subdir/annotations as
# training, so validation metrics are computed on training images. Point this
# at a held-out split (e.g. a separate 'valid' subdir) if one exists.
VALID_SUBDIR = '.'

BATCH_SIZE = 2
NUM_WORKERS = 0  # 0 = load data in the main process (no worker subprocesses)


def _dataset_params(subdir):
    """Return a fresh dataset-params override dict for the given annotation subdir."""
    return {
        'data_dir': DATA_DIR,
        'subdir': subdir,
        'json_file': ANNOTATIONS_FILE,
    }


def _dataloader_params():
    """Return a fresh dataloader-params override dict (batch size, workers)."""
    return {
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
    }


def main():
    """Build the model, dataloaders and training recipe, then run training."""
    # Download model (COCO-pretrained weights, head resized to NUM_CLASSES)
    model = models.get(Models.YOLO_NAS_S, pretrained_weights='coco', num_classes=NUM_CLASSES)

    # Initialize trainer
    trainer = Trainer(experiment_name=EXPERIMENT_NAME, ckpt_root_dir=CHECKPOINT_ROOT_DIR)

    # Datasets or dataloaders
    train_dataloader = coco2017_train_yolo_nas(
        dataset_params=_dataset_params(TRAIN_SUBDIR),
        dataloader_params=_dataloader_params(),
    )
    valid_dataloader = coco2017_val_yolo_nas(
        dataset_params=_dataset_params(VALID_SUBDIR),
        dataloader_params=_dataloader_params(),
    )

    # Training parameters
    train_params = {
        'average_best_models': True,
        'max_epochs': 100,
        'initial_lr': 5e-4,
        'warmup_mode': 'LinearEpochLRWarmup',
        'warmup_initial_lr': 1e-6,
        'lr_warmup_epochs': 3,
        'lr_mode': 'cosine',
        'cosine_final_lr_ratio': 0.1,
        'optimizer': 'Adam',
        'optimizer_params': {'weight_decay': 0.0001},
        'zero_weight_decay_on_bias_and_bn': True,
        'ema': True,
        'ema_params': {'decay': 0.9, 'decay_type': 'threshold'},
        'mixed_precision': False,
        'loss': PPYoloELoss(num_classes=NUM_CLASSES, reg_max=16),
        'valid_metrics_list': [
            DetectionMetrics_050(
                score_thres=0.1,
                top_k_predictions=100,
                num_cls=NUM_CLASSES,  # must match the model's num_classes
                normalize_targets=True,
                post_prediction_callback=PPYoloEPostPredictionCallback(
                    score_threshold=0.01,
                    nms_top_k=1000,
                    max_predictions=100,
                    nms_threshold=0.7,
                ),
            )
        ],
        'metric_to_watch': 'mAP@0.50',
    }

    trainer.train(
        model=model,
        training_params=train_params,
        train_loader=train_dataloader,
        valid_loader=valid_dataloader,
    )


if __name__ == '__main__':
    main()