Create baseline.py
baseline.py +57 -0
baseline.py
ADDED
@@ -0,0 +1,57 @@
+import os
+import json
+
+import fire
+import numpy as np
+from scipy import sparse
+
+from sklearn.model_selection import PredefinedSplit, GridSearchCV
+from sklearn.linear_model import LogisticRegression
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+def _load_split(data_dir, source, split, n=np.inf):
+    path = os.path.join(data_dir, f'{source}.{split}.jsonl')
+    texts = []
+    for i, line in enumerate(open(path)):
+        if i >= n:
+            break
+        texts.append(json.loads(line)['text'])
+    return texts
+
+def load_split(data_dir, source, split, n=np.inf):
+    webtext = _load_split(data_dir, 'webtext', split, n=n//2)
+    gen = _load_split(data_dir, source, split, n=n//2)
+    texts = webtext+gen
+    labels = [0]*len(webtext)+[1]*len(gen)
+    return texts, labels
+
+def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000, n_jobs=None, verbose=False):
+    train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train)
+    valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid)
+    test_texts, test_labels = load_split(data_dir, source, 'test')
+
+    vect = TfidfVectorizer(ngram_range=(1, 2), min_df=5, max_features=2**21)
+    train_features = vect.fit_transform(train_texts)
+    valid_features = vect.transform(valid_texts)
+    test_features = vect.transform(test_texts)
+
+    model = LogisticRegression(solver='liblinear')
+    params = {'C': [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64]}
+    split = PredefinedSplit([-1]*n_train+[0]*n_valid)
+    search = GridSearchCV(model, params, cv=split, n_jobs=n_jobs, verbose=verbose, refit=False)
+    search.fit(sparse.vstack([train_features, valid_features]), train_labels+valid_labels)
+    model = model.set_params(**search.best_params_)
+    model.fit(train_features, train_labels)
+    valid_accuracy = model.score(valid_features, valid_labels)*100.
+    test_accuracy = model.score(test_features, test_labels)*100.
+    data = {
+        'source':source,
+        'n_train':n_train,
+        'valid_accuracy':valid_accuracy,
+        'test_accuracy':test_accuracy
+    }
+    print(data)
+    json.dump(data, open(os.path.join(log_dir, f'{source}.json'), 'w'))
+
+if __name__ == '__main__':
+    fire.Fire(main)
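
A note on the hyperparameter search above: PredefinedSplit turns the held-out validation set into a single cross-validation "fold" for GridSearchCV, with -1 marking rows used only for fitting and 0 marking rows used only for scoring. A minimal, self-contained sketch of that mechanism (the toy data and sizes here are illustrative, not drawn from the dataset):

    import numpy as np
    from scipy import sparse
    from sklearn.model_selection import PredefinedSplit, GridSearchCV
    from sklearn.linear_model import LogisticRegression

    # Toy data: 8 "train" rows followed by 4 "valid" rows (illustrative only).
    X = sparse.csr_matrix(np.random.RandomState(0).rand(12, 5))
    y = [0, 1] * 6

    # -1 = row is never scored (fit only); 0 = row belongs to the one eval fold.
    split = PredefinedSplit([-1] * 8 + [0] * 4)

    search = GridSearchCV(LogisticRegression(solver='liblinear'),
                          {'C': [0.5, 1, 2]}, cv=split, refit=False)
    search.fit(X, y)
    print(search.best_params_)  # the C judged best on the 4 validation rows

This is why baseline.py stacks the train and validation features with sparse.vstack before calling search.fit, and why refit=False is passed: the winning C is then refit on the training rows alone rather than on train plus valid.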
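
Because the entry point is fire.Fire(main), each parameter of main also works as a command-line flag. A plausible invocation, assuming the dataset files have been downloaded into a data/ directory (containing webtext.{train,valid,test}.jsonl alongside the generated xl-1542M-k40.*.jsonl splits, as _load_split expects) and that logs/ exists:

    python baseline.py --data_dir data --log_dir logs --source xl-1542M-k40 --n_jobs 4

The validation and test accuracies are printed and also written to logs/xl-1542M-k40.json.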