# kengboon's picture
# Upload 39 files
# c09670c
# raw
# history blame
# No virus
# 3.03 kB
from super_gradients.common.object_names import Models
from super_gradients.training import Trainer
from super_gradients.training import models
from super_gradients.training.dataloaders import coco2017_train_yolo_nas, coco2017_val_yolo_nas
from super_gradients.training.losses import PPYoloELoss
from super_gradients.training.metrics import DetectionMetrics_050
from super_gradients.training.models.detection_models.pp_yolo_e import PPYoloEPostPredictionCallback
from super_gradients.training.utils.distributed_training_utils import setup_device
# Number of object classes the detector is configured for. The model, the loss
# and the validation metric must all use the same value, so it is defined once.
NUM_CLASSES = 2

# Dataset location shared by the train and validation loaders.
DATA_DIR = 'dataset'
ANNOTATION_FILE = '_annotations.coco.json'
# NOTE(review): train and validation currently read the SAME annotation file
# from the SAME directory, so validation mAP is measured on the training images
# and will overstate how well the model generalises. If the dataset has a
# separate held-out split, point VALID_SUBDIR (and TRAIN_SUBDIR) at it.
TRAIN_SUBDIR = '.'
VALID_SUBDIR = '.'

# Loader settings shared by both splits (num_workers=0 loads in the main process).
BATCH_SIZE = 2
NUM_WORKERS = 0


def build_dataset_params(subdir):
    """Return a fresh ``dataset_params`` override for the split in ``subdir``.

    A new dict is built on every call so the train and validation loaders
    never share one mutable object.
    """
    return {
        'data_dir': DATA_DIR,
        'subdir': subdir,
        'json_file': ANNOTATION_FILE,
    }


def build_dataloader_params():
    """Return a fresh ``dataloader_params`` override (batch size and workers)."""
    return {
        'batch_size': BATCH_SIZE,
        'num_workers': NUM_WORKERS,
    }


def build_train_params():
    """Return the ``training_params`` dict passed to ``Trainer.train``.

    The loss and the validation metric are sized with NUM_CLASSES so they stay
    consistent with the model created in the ``__main__`` section.
    """
    return {
        'average_best_models': True,
        'max_epochs': 100,
        # Learning-rate schedule: linear warm-up, then cosine.
        'initial_lr': 5e-4,
        'warmup_mode': 'LinearEpochLRWarmup',
        'warmup_initial_lr': 1e-6,
        'lr_warmup_epochs': 3,
        'lr_mode': 'cosine',
        'cosine_final_lr_ratio': 0.1,
        # Optimizer configuration.
        'optimizer': 'Adam',
        'optimizer_params': {
            'weight_decay': 0.0001,
        },
        'zero_weight_decay_on_bias_and_bn': True,
        # Exponential moving average of the weights.
        'ema': True,
        'ema_params': {
            'decay': 0.9,
            'decay_type': 'threshold',
        },
        'mixed_precision': False,
        'loss': PPYoloELoss(
            num_classes=NUM_CLASSES,
            reg_max=16,
        ),
        'valid_metrics_list': [
            DetectionMetrics_050(
                score_thres=0.1,
                top_k_predictions=100,
                # Must equal the model's num_classes. NOTE(review): this count
                # was previously annotated "Include background" — confirm
                # whether the annotation JSON really lists 2 categories.
                num_cls=NUM_CLASSES,
                normalize_targets=True,
                post_prediction_callback=PPYoloEPostPredictionCallback(
                    score_threshold=0.01,
                    nms_top_k=1000,
                    max_predictions=100,
                    nms_threshold=0.7,
                ),
            ),
        ],
        # Metric the trainer watches; presumably the name DetectionMetrics_050
        # reports for mAP at IoU 0.50 — confirm if the metric above changes.
        'metric_to_watch': 'mAP@0.50',
    }


if __name__ == '__main__':
    # Load YOLO-NAS-S starting from COCO-pretrained weights, sized for NUM_CLASSES.
    model = models.get(
        Models.YOLO_NAS_S,
        pretrained_weights='coco',
        num_classes=NUM_CLASSES,
    )

    # Checkpoints for this experiment are written under ckpt_root_dir.
    trainer = Trainer(
        experiment_name='yolo-nas-plastic',
        ckpt_root_dir='checkpoints/sg_yolonas',
    )

    # coco2017_*_yolo_nas loader factories with the dataset location and
    # batching overridden for this custom dataset.
    train_dataloader = coco2017_train_yolo_nas(
        dataset_params=build_dataset_params(TRAIN_SUBDIR),
        dataloader_params=build_dataloader_params(),
    )
    valid_dataloader = coco2017_val_yolo_nas(
        dataset_params=build_dataset_params(VALID_SUBDIR),
        dataloader_params=build_dataloader_params(),
    )

    trainer.train(
        model=model,
        training_params=build_train_params(),
        train_loader=train_dataloader,
        valid_loader=valid_dataloader,
    )