File size: 764 Bytes
8b414b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import tempfile

from src.cross_validate import CrossValidation
from src.data_reader import load_train_test_df
from src.solutions.constant_predictor import ConstantPredictorSolution
from src.utils import pandas_set_print_options

# Runs once at import time, before any test in this module executes.
# NOTE(review): presumably adjusts pandas display options so DataFrames print
# readably in test output — confirm against src.utils.
pandas_set_print_options()


def test_cross_validation():
    """Smoke-test ``CrossValidation.fit`` and ``CrossValidation.predict``.

    Loads the training frame in testing mode, fits a
    ``ConstantPredictorSolution`` through a 3-split ``CrossValidation`` and
    checks only the *shapes* of the returned score table and prediction frame.

    Artefacts are written to a fresh temporary directory instead of a fixed
    path under ``/tmp`` so that the test is portable, leaves nothing behind
    and cannot be influenced by files from a previous run.
    """
    train_df, _ = load_train_test_df(is_testing=True)

    x_columns = ['text_id', 'full_text']
    # Features are the id + raw text; the target frame is everything except
    # the raw text (so it still carries ``text_id``).
    X, y = train_df[x_columns], train_df.drop(columns=['full_text'])

    n_splits = 3
    # The directory is removed automatically when the ``with`` block exits.
    with tempfile.TemporaryDirectory() as saving_dir:
        cv = CrossValidation(saving_dir=saving_dir, n_splits=n_splits)
        predictor = ConstantPredictorSolution()
        cv_scores = cv.fit(predictor, X, y)
        # NOTE(review): presumably one row per fold plus one summary row, and
        # six score columns — confirm against CrossValidation.fit.
        assert cv_scores.shape == (n_splits + 1, 6)

        # Predict while the directory still exists, in case ``predict``
        # reads artefacts that ``fit`` saved there.
        prediction_df = cv.predict(train_df[x_columns])

    # NOTE(review): presumably ``text_id`` plus six predicted columns, one row
    # per input row — confirm against CrossValidation.predict.
    assert prediction_df.shape == (len(train_df), 7)