Sadashiv's picture
Upload 146 files
17c5137 verified
import os
import sys
import shutil
from ultralytics import YOLO
from src.logger import logging
from src.exception import PlantException
from src.entity.config_entity import ModelTrainerConfig
from src.entity.artifact_entity import ModelTrainerArtifact
from src.entity.artifact_entity import DataIngestionArtifact
class ModelTrainer:
try:
def __init__(self, model_trainer_config:ModelTrainerConfig, data_ingestion_artifact:DataIngestionArtifact):
self.model_trainer_config = model_trainer_config
self.data_ingestion_artifact = data_ingestion_artifact
def initiate_model_trainer(self) -> ModelTrainerArtifact:
logging.info(f"Removing and existing runs directory from previous training")
os.system("rm -rf runs")
model_config_file_name = self.model_trainer_config.weight_name.split('.')[0]
print(model_config_file_name)
os.system(f"yolo task=classify mode=train model={self.model_trainer_config.weight_name} \
data={self.data_ingestion_artifact.dataset_path} epochs={self.model_trainer_config.no_epochs} \
imgsz=128 batch={self.model_trainer_config.batch_size} patience={self.model_trainer_config.patience}")
os.makedirs("custom_model_weights", exist_ok=True)
os.system("cp runs/classify/train/weights/best.pt custom_model_weights/")
os.makedirs(self.model_trainer_config.model_trainer_dir, exist_ok=True)
os.system(f"cp runs/classify/train/weights/best.pt {self.model_trainer_config.model_trainer_dir}/")
model_trainer_artifact = ModelTrainerArtifact(trained_model_file_path="custom_model_weights/best.pt",)
logging.info("Exited initiate_model_trainer method of ModelTrainer class")
logging.info(f"Model trainer artifact: {model_trainer_artifact}")
return model_trainer_artifact
except Exception as e:
raise PlantException(e, sys)