Spaces:
Runtime error
Runtime error
import logging | |
import pandas as pd | |
from zenml.client import Client | |
from zenml import step | |
from src.model_dev import LinearRegressionModel | |
from sklearn.base import RegressorMixin | |
import mlflow | |
from .config import ModelNameConfig | |
# Look up the experiment tracker component of the currently active ZenML
# stack. This runs at import time, so importing this module constructs a
# ZenML ``Client``.
# NOTE(review): nothing in the visible code reads this name — confirm whether
# it is meant to be passed to a ``@step(...)`` decorator.
experiment_tracker = Client().active_stack.experiment_tracker
def train_model(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.DataFrame,
    y_test: pd.DataFrame,
    config: ModelNameConfig,
) -> RegressorMixin:
    """Train the model selected by ``config.model_name`` on the training split.

    Only ``"LinearRegression"`` is supported. For that model, MLflow's
    scikit-learn autologging is switched on before training.

    Args:
        X_train: Training features, forwarded to the model's ``train`` method.
        X_test: Test features. Accepted but not used by this function.
        y_train: Training targets, forwarded to the model's ``train`` method.
        y_test: Test targets. Accepted but not used by this function.
        config: Configuration object whose ``model_name`` attribute picks the
            model to train.

    Returns:
        Whatever ``LinearRegressionModel().train(X_train, y_train)`` returns
        (annotated as a ``RegressorMixin``).

    Raises:
        ValueError: If ``config.model_name`` is not ``"LinearRegression"``.
        Exception: Any error raised while training is logged and re-raised
            unchanged.
    """
    # NOTE(review): ``step`` is imported and ``experiment_tracker`` is resolved
    # at module level, but no ``@step(...)`` decorator is visible on this
    # function — confirm whether one was lost.
    try:
        # Guard clause: reject unknown model names before touching MLflow.
        # Kept inside the ``try`` so the rejection is logged like any other
        # training failure.
        if config.model_name != "LinearRegression":
            raise ValueError(f"Model {config.model_name} not supported")

        mlflow.sklearn.autolog()
        return LinearRegressionModel().train(X_train, y_train)
    except Exception as e:
        logging.error("Error in training model: %s", e)
        # Bare ``raise`` re-raises the active exception with its traceback.
        raise