Spaces:
Runtime error
Runtime error
File size: 1,111 Bytes
6defa3d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import logging
import pandas as pd
from zenml.client import Client
from zenml import step
from src.model_dev import LinearRegressionModel
from sklearn.base import RegressorMixin
import mlflow
from .config import ModelNameConfig
# Resolved once at import time from the active ZenML stack. The stack's
# experiment tracker is optional, so this may be None when the stack has no
# tracker component registered.
experiment_tracker = Client().active_stack.experiment_tracker


# Only bind the step to a tracker when one exists; dereferencing `.name` on
# None would raise AttributeError while the module is being imported.
@step(
    experiment_tracker=(
        experiment_tracker.name if experiment_tracker is not None else None
    )
)
def train_model(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.DataFrame,
    y_test: pd.DataFrame,
    config: ModelNameConfig,
) -> RegressorMixin:
    """Train the model selected by ``config`` on the training split.

    Args:
        X_train: Training features.
        X_test: Test features. Not used by this step; kept so the step's
            input signature stays unchanged for the pipeline wiring it.
        y_train: Training targets.
        y_test: Test targets. Not used by this step (see ``X_test``).
        config: Step configuration; ``config.model_name`` selects the model.
            Only ``"LinearRegression"`` is currently supported.

    Returns:
        Whatever ``LinearRegressionModel().train(X_train, y_train)`` returns
        (annotated as a scikit-learn ``RegressorMixin``).

    Raises:
        ValueError: If ``config.model_name`` is not a supported model.
        Exception: Any error raised during training is logged and re-raised.
    """
    try:
        if config.model_name != "LinearRegression":
            raise ValueError(f"Model {config.model_name} not supported")

        # Turn on MLflow's scikit-learn autologging before fitting so the
        # training run's parameters/metrics are captured automatically.
        mlflow.sklearn.autolog()
        model = LinearRegressionModel()
        return model.train(X_train, y_train)
    except Exception as e:
        # logging.exception records the traceback; bare `raise` re-raises the
        # original exception unchanged.
        logging.exception("Error in training model: %s", e)
        raise