# linguask / tests/test_cross_validation.py
# GitHub Action
# refs/heads/ci-cd/hugging-face
# 8b414b0
# raw
# history blame contribute delete
# No virus
# 764 Bytes
import os
import tempfile

from src.cross_validate import CrossValidation
from src.data_reader import load_train_test_df
from src.solutions.constant_predictor import ConstantPredictorSolution
from src.utils import pandas_set_print_options
# Module-level side effect: apply the project's pandas print configuration
# when this test module is imported (presumably display options such as
# width / max columns -- see src.utils.pandas_set_print_options).
pandas_set_print_options()
def test_cross_validation():
    """Smoke-test CrossValidation end to end with a constant predictor.

    Fits ``ConstantPredictorSolution`` with ``n_splits``-fold cross-validation
    on the frame returned by ``load_train_test_df(is_testing=True)``, then
    predicts on the same rows and checks the shapes of the returned score
    table and prediction frame.

    Artifacts go to a fresh temporary directory that is deleted when the test
    finishes, so repeated or concurrent runs never share state and nothing is
    left behind in a hard-coded system path.
    """
    train_df, _ = load_train_test_df(is_testing=True)
    x_columns = ['text_id', 'full_text']
    # Features are the id and raw text; every other column (text_id included,
    # since only 'full_text' is dropped) is passed as the target frame.
    X, y = train_df[x_columns], train_df.drop(columns=['full_text'])
    n_splits = 3

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Use a not-yet-existing child directory, mirroring the original
        # hard-coded '/tmp/...' path, in case CrossValidation creates it
        # itself (NOTE(review): confirm against CrossValidation).
        saving_dir = os.path.join(tmp_dir, 'cv')
        cv = CrossValidation(saving_dir=saving_dir, n_splits=n_splits)
        predictor = ConstantPredictorSolution()
        cv_scores = cv.fit(predictor, X, y)
        # One row per fold plus one extra row (presumably an aggregate), and
        # 6 score columns.
        assert cv_scores.shape == (n_splits + 1, 6), cv_scores.shape
        # predict() stays inside the with-block: it may rely on artifacts
        # that fit() saved under saving_dir.
        prediction_df = cv.predict(train_df[x_columns])
        assert prediction_df.shape == (len(train_df), 7), prediction_df.shape