film_mlmodule / steps /model_train.py
root
first hf commit
6defa3d
raw
history blame contribute delete
No virus
1.11 kB
import logging
import pandas as pd
from zenml.client import Client
from zenml import step
from src.model_dev import LinearRegressionModel
from sklearn.base import RegressorMixin
import mlflow
from .config import ModelNameConfig
# Experiment tracker of the active ZenML stack, resolved once at import time so
# that its name can be passed to the @step decorator of train_model.
experiment_tracker = Client().active_stack.experiment_tracker
if experiment_tracker is None:
    # Without this guard the decorator below fails with an opaque
    # "'NoneType' object has no attribute 'name'" AttributeError.
    raise RuntimeError(
        "The active ZenML stack has no experiment tracker; register one "
        "(e.g. MLflow) and add it to the active stack before running "
        "the train_model step."
    )


@step(experiment_tracker=experiment_tracker.name)
def train_model(
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.DataFrame,
    y_test: pd.DataFrame,
    config: ModelNameConfig,
) -> RegressorMixin:
    """Train the regression model selected by ``config.model_name``.

    Args:
        X_train: Training features.
        X_test: Test features. Not used by this step; presumably kept so the
            pipeline wiring stays unchanged.
        y_train: Training targets.
        y_test: Test targets. Not used by this step.
        config: Step configuration; ``config.model_name`` selects the model.
            Only ``"LinearRegression"`` is currently supported.

    Returns:
        The value returned by ``LinearRegressionModel().train(X_train, y_train)``,
        declared as a scikit-learn ``RegressorMixin``.

    Raises:
        ValueError: If ``config.model_name`` is not a supported model.
        Any exception raised while training is logged (with traceback) and
        re-raised unchanged.
    """
    try:
        if config.model_name != "LinearRegression":
            raise ValueError(f"Model {config.model_name} not supported")
        # Enable MLflow's scikit-learn autologging before the model is fitted.
        mlflow.sklearn.autolog()
        model = LinearRegressionModel()
        return model.train(X_train, y_train)
    except Exception as e:
        # logging.exception logs at ERROR level and also records the traceback.
        logging.exception("Error in training model: %s", e)
        raise