sooks committed on
Commit
b87b286
1 Parent(s): fba8f24

Create baseline.py

Browse files
Files changed (1) hide show
  1. baseline.py +57 -0
baseline.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
+ import fire
5
+ import numpy as np
6
+ from scipy import sparse
7
+
8
+ from sklearn.model_selection import PredefinedSplit, GridSearchCV
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.feature_extraction.text import TfidfVectorizer
11
+
12
+ def _load_split(data_dir, source, split, n=np.inf):
13
+ path = os.path.join(data_dir, f'{source}.{split}.jsonl')
14
+ texts = []
15
+ for i, line in enumerate(open(path)):
16
+ if i >= n:
17
+ break
18
+ texts.append(json.loads(line)['text'])
19
+ return texts
20
+
21
def load_split(data_dir, source, split, n=np.inf):
    """Assemble a balanced binary dataset for one split.

    Loads at most ``n // 2`` human-written samples (``webtext``, label 0)
    and at most ``n // 2`` machine-generated samples (*source*, label 1).

    Returns ``(texts, labels)`` with the human texts first.
    """
    half = n // 2
    human = _load_split(data_dir, 'webtext', split, n=half)
    machine = _load_split(data_dir, source, split, n=half)
    return human + machine, [0] * len(human) + [1] * len(machine)
27
+
28
def main(data_dir, log_dir, source='xl-1542M-k40', n_train=500000, n_valid=10000, n_jobs=None, verbose=False):
    """Train a TF-IDF + logistic-regression detector baseline.

    Loads train/valid/test splits from *data_dir*, grid-searches the
    regularization strength C on the predefined validation fold, refits on
    the training split, and writes accuracies to ``{log_dir}/{source}.json``.

    Parameters
    ----------
    data_dir : directory holding the ``*.jsonl`` split files.
    log_dir : directory that receives the result JSON.
    source : name of the generated-text dataset (label-1 class).
    n_train, n_valid : requested sample caps for the train/valid splits.
    n_jobs, verbose : passed through to GridSearchCV.
    """
    train_texts, train_labels = load_split(data_dir, source, 'train', n=n_train)
    valid_texts, valid_labels = load_split(data_dir, source, 'valid', n=n_valid)
    test_texts, test_labels = load_split(data_dir, source, 'test')

    # Fit the vocabulary on the training texts only; unigrams + bigrams,
    # terms in fewer than 5 documents dropped.
    vect = TfidfVectorizer(ngram_range=(1, 2), min_df=5, max_features=2**21)
    train_features = vect.fit_transform(train_texts)
    valid_features = vect.transform(valid_texts)
    test_features = vect.transform(test_texts)

    model = LogisticRegression(solver='liblinear')
    params = {'C': [1/64, 1/32, 1/16, 1/8, 1/4, 1/2, 1, 2, 4, 8, 16, 32, 64]}
    # Use the ACTUAL loaded counts, not n_train/n_valid: the jsonl files may
    # hold fewer lines than requested, and the fold mask must match the
    # number of stacked feature rows exactly. -1 = always train, 0 = the
    # single validation fold.
    split = PredefinedSplit([-1] * len(train_texts) + [0] * len(valid_texts))
    search = GridSearchCV(model, params, cv=split, n_jobs=n_jobs, verbose=verbose, refit=False)
    search.fit(sparse.vstack([train_features, valid_features]), train_labels + valid_labels)

    # Refit on the training split alone with the selected C (refit=False
    # above, so GridSearchCV did not keep a fitted estimator).
    model = model.set_params(**search.best_params_)
    model.fit(train_features, train_labels)
    valid_accuracy = model.score(valid_features, valid_labels) * 100.
    test_accuracy = model.score(test_features, test_labels) * 100.

    data = {
        'source': source,
        'n_train': n_train,
        'valid_accuracy': valid_accuracy,
        'test_accuracy': test_accuracy
    }
    print(data)
    # Context manager flushes and closes the log file; the original passed a
    # bare open(...) to json.dump and leaked the handle.
    with open(os.path.join(log_dir, f'{source}.json'), 'w') as f:
        json.dump(data, f)
55
+
56
# CLI entry point: python-fire turns main()'s parameters into command-line flags.
if __name__ == '__main__':
    fire.Fire(main)