import logging

import mlflow
import pandas as pd
from sklearn.base import RegressorMixin
from zenml import step
from zenml.client import Client

from src.model_dev import LinearRegressionModel

from .config import ModelNameConfig

# Experiment tracker component of the currently active ZenML stack.
# NOTE(review): `.name` is read at import time in the decorator below, so this
# presumably requires the active stack to have an experiment tracker
# registered -- confirm against the stack setup.
experiment_tracker = Client().active_stack.experiment_tracker


@step(experiment_tracker=experiment_tracker.name)
def train_model(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.DataFrame,
    y_test: pd.DataFrame,
    config: ModelNameConfig,
) -> RegressorMixin:
    """Train the model selected by ``config.model_name`` on the training data.

    Only ``"LinearRegression"`` is supported. For that model, MLflow
    scikit-learn autologging is switched on before training starts.

    Args:
        X_train: Training features, passed to the model's ``train`` method.
        X_test: Not used by this step; kept as part of the step signature.
        y_train: Training targets, passed to the model's ``train`` method.
        y_test: Not used by this step; kept as part of the step signature.
        config: Step configuration; its ``model_name`` picks the model.

    Returns:
        The object returned by ``LinearRegressionModel.train``.

    Raises:
        ValueError: If ``config.model_name`` is not a supported model name.
        Exception: Any error raised during training is logged and re-raised
            unchanged.
    """
    try:
        # Guard clause kept inside the ``try`` so that an unsupported model
        # name is logged by the handler below, exactly like a training error.
        if config.model_name != "LinearRegression":
            raise ValueError(f"Model {config.model_name} not supported")

        # Enable autologging before fitting so the training run is captured.
        mlflow.sklearn.autolog()
        model = LinearRegressionModel()
        return model.train(X_train, y_train)
    except Exception as e:
        logging.error("Error in training model: %s", e)
        # Bare ``raise`` re-raises the active exception with its original
        # traceback intact.
        raise