Sadashiv's picture
Upload 31 files
625ed08 verified
from src.entity.config_entity import ModelPusherConfig
from src.entity import artifact_entity
from src.predictor import ModelResolver
from src.exception import CropException
from src.logger import logging
from src.utils import load_object, save_object
from src.entity.artifact_entity import (
DataTransformationArtifact,
ModelTrainerArtifact,
ModelPusherArtifact,
)
import sys
import os
class ModelPusher:
    """Pipeline stage that re-saves the trained objects to their publish paths.

    The stage reads three previously saved objects (the feature transformer,
    the trained model and the target encoder) and writes each of them twice:
    once to the paths named in the pusher config and once to the paths
    returned by the ``ModelResolver`` built on ``saved_model_dir``.
    """

    def __init__(
        self,
        model_pusher_config: ModelPusherConfig,
        data_transformation_artifact: DataTransformationArtifact,
        model_trainer_artifact: ModelTrainerArtifact,
    ):
        """Keep the config and upstream artifacts and create the resolver.

        Args:
            model_pusher_config: Supplies the pusher output paths and
                ``saved_model_dir``.
            data_transformation_artifact: Supplies the transformer and
                target-encoder file paths.
            model_trainer_artifact: Supplies the trained model file path.

        Raises:
            CropException: Wraps any exception raised while setting up.
        """
        try:
            logging.info(f"{'>'*20} Model Pusher Initiated {'<'*30}")
            self.model_pusher_config = model_pusher_config
            self.data_transformation_artifact = data_transformation_artifact
            self.model_trainer_artifact = model_trainer_artifact
            # The resolver is rooted at the saved-model directory from the config.
            registry_dir = self.model_pusher_config.saved_model_dir
            self.model_resolver = ModelResolver(model_registry=registry_dir)
        except Exception as e:
            raise CropException(e, sys)

    def initiate_model_pusher(self) -> ModelPusherArtifact:
        """Load the trained objects and save them to both destinations.

        Returns:
            ModelPusherArtifact: Carries ``pusher_model_dir`` and
            ``saved_model_dir`` taken from the pusher config.

        Raises:
            CropException: Wraps any exception raised while loading or saving.
        """
        try:
            # Step 1: read back the three objects written by earlier stages.
            logging.info("Loading transformer model and target encoder")
            transformation = self.data_transformation_artifact
            fitted_transformer = load_object(
                file_path=transformation.transform_object_path
            )
            trained_model = load_object(
                file_path=self.model_trainer_artifact.model_path
            )
            fitted_target_encoder = load_object(
                file_path=transformation.target_encoder_path
            )

            # Step 2: write them to the paths named in the pusher config.
            logging.info("Saving model into model pusher directory")
            config = self.model_pusher_config
            save_object(
                file_path=config.pusher_transformer_path,
                obj=fitted_transformer,
            )
            save_object(file_path=config.pusher_model_path, obj=trained_model)
            save_object(
                file_path=config.pusher_target_encoder_path,
                obj=fitted_target_encoder,
            )

            # Step 3: write them to the paths the resolver hands out.
            # All three paths are requested before the first write.
            # NOTE(review): presumably these point into a new versioned
            # folder under saved_model_dir; confirm against ModelResolver.
            logging.info("Saving model in saved model dir")
            resolver = self.model_resolver
            registry_transformer_path = resolver.get_latest_save_transformer_path()
            registry_model_path = resolver.get_latest_save_model_path()
            registry_encoder_path = resolver.get_latest_save_target_encoder_path()
            registry_writes = (
                (registry_transformer_path, fitted_transformer),
                (registry_model_path, trained_model),
                (registry_encoder_path, fitted_target_encoder),
            )
            for destination, payload in registry_writes:
                save_object(file_path=destination, obj=payload)

            # Step 4: report the two directories that were populated.
            model_pusher_artifact = ModelPusherArtifact(
                pusher_model_dir=config.pusher_model_dir,
                saved_model_dir=config.saved_model_dir,
            )
            logging.info(f"Model Pusher artifact: {model_pusher_artifact}")
            return model_pusher_artifact
        except Exception as e:
            raise CropException(e, sys)