mohdelgaar commited on
Commit
20b7679
·
0 Parent(s):

Initial commit

Browse files
Files changed (9) hide show
  1. .gitattributes +36 -0
  2. README.md +10 -0
  3. app.py +42 -0
  4. compute_lng.py +57 -0
  5. const.py +1053 -0
  6. demo.py +371 -0
  7. model.py +696 -0
  8. options.py +158 -0
  9. requirements.txt +10 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.npy filter=lfs diff=lfs merge=lfs -text
2
+ *.jar filter=lfs diff=lfs merge=lfs -text
3
+ *.7z filter=lfs diff=lfs merge=lfs -text
4
+ *.arrow filter=lfs diff=lfs merge=lfs -text
5
+ *.bin filter=lfs diff=lfs merge=lfs -text
6
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
7
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
8
+ *.ftz filter=lfs diff=lfs merge=lfs -text
9
+ *.gz filter=lfs diff=lfs merge=lfs -text
10
+ *.h5 filter=lfs diff=lfs merge=lfs -text
11
+ *.joblib filter=lfs diff=lfs merge=lfs -text
12
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
13
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
14
+ *.model filter=lfs diff=lfs merge=lfs -text
15
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
16
+ *.npz filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.state filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LingConv
3
+ emoji: 🔁
4
+ colorFrom: indigo
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 4.40.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ nltk.download('wordnet')
3
+ import spacy
4
+ spacy.cli.download('en_core_web_sm')
5
+ from const import name_map
6
+ from demo import run_gradio
7
+ from model import EncoderDecoderVAE
8
+ from options import parse_args
9
+ import numpy as np
10
+ from transformers import T5Tokenizer
11
+ import torch
12
+ import joblib
13
+ import pandas as pd
14
+
15
+
16
+ def process_examples(samples, full_names):
17
+ for i in range(len(samples)):
18
+ sample = samples[i]
19
+ input_text = tokenizer.decode(sample['sentence1_input_ids'], skip_special_tokens=True)
20
+ ling1 = scaler.inverse_transform([sample['sentence1_ling']])[0]
21
+ ling2 = scaler.inverse_transform([sample['sentence2_ling']])[0]
22
+ ling = pd.DataFrame({'Index': full_names, 'Source': ling1, 'Target': ling2})
23
+ samples[i] = [input_text, ling]
24
+ return list(samples)
25
+
26
+ args, args_list, lng_names = parse_args(ckpt='./ckpt/model.pt')
27
+
28
+ tokenizer = T5Tokenizer.from_pretrained(args.model_name)
29
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+
31
+ scaler = joblib.load('assets/scaler.bin')
32
+ full_names = [name_map[x] for x in lng_names]
33
+ samples = joblib.load('assets/samples.bin')
34
+ examples = process_examples(samples, full_names)
35
+ ling_collection = np.load('assets/ling_collection.npy')
36
+
37
+ model = EncoderDecoderVAE(args, tokenizer.pad_token_id, tokenizer.get_vocab()['</s>']).to(device)
38
+ state = torch.load(args.ckpt, map_location=torch.device('cpu'))
39
+ model.load_state_dict(state['model'], strict=False)
40
+ model.eval()
41
+
42
+ run_gradio(model, tokenizer, scaler, ling_collection, examples, full_names)
compute_lng.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from lng.lca.lc_anc import lca
2
+ from lng.L2SCA.analyzeText import sca
3
+ import lftk
4
+ import spacy
5
+ nlp = spacy.load("en_core_web_sm")
6
+
7
+ def extract_lingfeat(text):
8
+ from lingfeat import extractor
9
+ LingFeat = extractor.pass_text(text)
10
+ LingFeat.preprocess()
11
+
12
+ d = {}
13
+ d.update(LingFeat.WoKF_()) # Wikipedia Knowledge Features
14
+ d.update(LingFeat.WBKF_()) # WeeBit Corpus Knowledge Features
15
+ d.update(LingFeat.OSKF_()) # OneStopEng Corpus Knowledge Features
16
+
17
+ # Discourse (Disco) Features
18
+ d.update(LingFeat.EnDF_()) # Entity Density Features
19
+ d.update(LingFeat.EnGF_()) # Entity Grid Features
20
+
21
+ # Syntactic (Synta) Features
22
+ # d.update(LingFeat.PhrF_()) # Noun/Verb/Adj/Adv/... Phrasal Features (logging stanza)
23
+ # d.update(LingFeat.TrSF_()) # (Parse) Tree Structural Features (logging stanza)
24
+ d.update(LingFeat.POSF_()) # Noun/Verb/Adj/Adv/... Part-of-Speech Features
25
+
26
+ # Lexico Semantic (LxSem) Features
27
+ d.update(LingFeat.TTRF_()) # Type Token Ratio Features
28
+ d.update(LingFeat.VarF_()) # Noun/Verb/Adj/Adv Variation Features
29
+ d.update(LingFeat.PsyF_()) # Psycholinguistic Difficulty of Words (AoA Kuperman)
30
+ d.update(LingFeat.WorF_()) # Word Familiarity from Frequency Count (SubtlexUS)
31
+
32
+ # Shallow Traditional (ShTra) Features
33
+ d.update(LingFeat.ShaF_()) # Shallow Features (e.g. avg number of tokens)
34
+ d.update(LingFeat.TraF_()) # Traditional Formulas
35
+
36
+ return list(d.values())
37
+
38
+
39
+ def extract_lftk(text):
40
+ if text == '':
41
+ return [0.] * 220
42
+ doc = nlp(text)
43
+ LFTK = lftk.Extractor(doc)
44
+
45
+ feats = LFTK.extract()
46
+ return list(feats.values())
47
+
48
+ def compute_lng(text, shortcut = False):
49
+ lca_feats = lca(text)
50
+ if shortcut:
51
+ sca_feats = [0] * 23
52
+ else:
53
+ sca_feats = sca(text)
54
+ lftk = extract_lftk(text)
55
+ all_feats = lca_feats + sca_feats + lftk
56
+
57
+ return all_feats
const.py ADDED
@@ -0,0 +1,1053 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+
3
+ sca_names = "W,S,VP,C,T,DC,CT,CP,CN,MLS,MLT,MLC,C-S,VP-T,C-T,DC-C,DC-T,T-S,\
4
+ CT-T,CP-T,CP-C,CN-T,CN-C".split(',')
5
+ lca_names = "wordtypes,swordtypes,lextypes,slextypes,wordtokens,swordtokens,\
6
+ lextokens,slextokens,ld,ls1,ls2,vs1,vs2,cvs1,ndw,ndwz,ndwerz,ndwesz,ttr,\
7
+ msttr,cttr,rttr,logttr,uber,lv,vv1,svv1,cvv1,vv2,nv,adjv,advv,modv".split(',')
8
+
9
+ lftk_names = [
10
+ 't_word', 't_stopword', 't_punct', 't_syll', 't_syll2', 't_syll3', 't_uword', 't_sent', 't_char', 'a_word_ps', 'a_char_ps',
11
+ 'a_char_pw', 'a_syll_ps', 'a_syll_pw', 'a_stopword_ps', 'a_stopword_pw', 't_kup', 't_bry', 't_subtlex_us_zipf', 'a_kup_pw',
12
+ 'a_bry_pw', 'a_kup_ps', 'a_bry_ps', 'a_subtlex_us_zipf_pw', 'a_subtlex_us_zipf_ps', 't_n_ent', 't_n_ent_person', 't_n_ent_norp',
13
+ 't_n_ent_fac', 't_n_ent_org', 't_n_ent_gpe', 't_n_ent_loc', 't_n_ent_product', 't_n_ent_event', 't_n_ent_art', 't_n_ent_law',
14
+ 't_n_ent_language', 't_n_ent_date', 't_n_ent_time', 't_n_ent_percent', 't_n_ent_money', 't_n_ent_quantity', 't_n_ent_ordinal',
15
+ 't_n_ent_cardinal', 'a_n_ent_pw', 'a_n_ent_person_pw', 'a_n_ent_norp_pw', 'a_n_ent_fac_pw', 'a_n_ent_org_pw', 'a_n_ent_gpe_pw',
16
+ 'a_n_ent_loc_pw', 'a_n_ent_product_pw', 'a_n_ent_event_pw', 'a_n_ent_art_pw', 'a_n_ent_law_pw', 'a_n_ent_language_pw',
17
+ 'a_n_ent_date_pw', 'a_n_ent_time_pw', 'a_n_ent_percent_pw', 'a_n_ent_money_pw', 'a_n_ent_quantity_pw', 'a_n_ent_ordinal_pw',
18
+ 'a_n_ent_cardinal_pw', 'a_n_ent_ps', 'a_n_ent_person_ps', 'a_n_ent_norp_ps', 'a_n_ent_fac_ps', 'a_n_ent_org_ps', 'a_n_ent_gpe_ps',
19
+ 'a_n_ent_loc_ps', 'a_n_ent_product_ps', 'a_n_ent_event_ps', 'a_n_ent_art_ps', 'a_n_ent_law_ps', 'a_n_ent_language_ps',
20
+ 'a_n_ent_date_ps', 'a_n_ent_time_ps', 'a_n_ent_percent_ps', 'a_n_ent_money_ps', 'a_n_ent_quantity_ps', 'a_n_ent_ordinal_ps',
21
+ 'a_n_ent_cardinal_ps', 'simp_adj_var', 'simp_adp_var', 'simp_adv_var', 'simp_aux_var', 'simp_cconj_var', 'simp_det_var',
22
+ 'simp_intj_var', 'simp_noun_var', 'simp_num_var', 'simp_part_var', 'simp_pron_var', 'simp_propn_var', 'simp_punct_var',
23
+ 'simp_sconj_var', 'simp_sym_var', 'simp_verb_var', 'simp_space_var', 'root_adj_var', 'root_adp_var', 'root_adv_var', 'root_aux_var',
24
+ 'root_cconj_var', 'root_det_var', 'root_intj_var', 'root_noun_var', 'root_num_var', 'root_part_var', 'root_pron_var', 'root_propn_var',
25
+ 'root_punct_var', 'root_sconj_var', 'root_sym_var', 'root_verb_var', 'root_space_var', 'corr_adj_var', 'corr_adp_var', 'corr_adv_var',
26
+ 'corr_aux_var', 'corr_cconj_var', 'corr_det_var', 'corr_intj_var', 'corr_noun_var', 'corr_num_var', 'corr_part_var', 'corr_pron_var',
27
+ 'corr_propn_var', 'corr_punct_var', 'corr_sconj_var', 'corr_sym_var', 'corr_verb_var', 'corr_space_var', 'simp_ttr', 'root_ttr',
28
+ 'corr_ttr', 'bilog_ttr', 'uber_ttr', 'simp_ttr_no_lem', 'root_ttr_no_lem', 'corr_ttr_no_lem', 'bilog_ttr_no_lem', 'uber_ttr_no_lem',
29
+ 'n_adj', 'n_adp', 'n_adv', 'n_aux', 'n_cconj', 'n_det', 'n_intj', 'n_noun', 'n_num', 'n_part', 'n_pron', 'n_propn', 'n_punct',
30
+ 'n_sconj', 'n_sym', 'n_verb', 'n_space', 'n_uadj', 'n_uadp', 'n_uadv', 'n_uaux', 'n_ucconj', 'n_udet', 'n_uintj', 'n_unoun',
31
+ 'n_unum', 'n_upart', 'n_upron', 'n_upropn', 'n_upunct', 'n_usconj', 'n_usym', 'n_uverb', 'n_uspace', 'a_adj_pw', 'a_adp_pw',
32
+ 'a_adv_pw', 'a_aux_pw', 'a_cconj_pw', 'a_det_pw', 'a_intj_pw', 'a_noun_pw', 'a_num_pw', 'a_part_pw', 'a_pron_pw', 'a_propn_pw',
33
+ 'a_punct_pw', 'a_sconj_pw', 'a_sym_pw', 'a_verb_pw', 'a_space_pw', 'a_adj_ps', 'a_adp_ps', 'a_adv_ps', 'a_aux_ps', 'a_cconj_ps',
34
+ 'a_det_ps', 'a_intj_ps', 'a_noun_ps', 'a_num_ps', 'a_part_ps', 'a_pron_ps', 'a_propn_ps', 'a_punct_ps', 'a_sconj_ps', 'a_sym_ps',
35
+ 'a_verb_ps', 'a_space_ps', 'fkre', 'fkgl', 'fogi', 'smog', 'cole', 'auto', 'rt_fast', 'rt_average', 'rt_slow']
36
+
37
+ lftk_full_names = ['Total Number Of Words', 'Total Number Of Stop Words',
38
+ 'Total Number Of Puntuations', 'Total Number Of Syllables',
39
+ 'Total Number Of Words More Than Two Syllables', 'Total Number Of Words More Than Three Syllables',
40
+ 'Total Number Of Unique Words', 'Total Number Of Sentences',
41
+ 'Total Number Of Characters', 'Average Number Of Words Per Sentence',
42
+ 'Average Number Of Characters Per Sentence', 'Average Number Of Characters Per Word',
43
+ 'Average Number Of Syllables Per Sentence', 'Average Number Of Syllables Per Word',
44
+ 'Average Number Of Stop Words Per Sentence', 'Average Number Of Stop Words Per Word',
45
+ 'Total Kuperman Age Of Acquistion Of Words', 'Total Brysbaert Age Of Acquistion Of Words',
46
+ 'Total Subtlex Us Zipf Of Words', 'Average Kuperman Age Of Acquistion Of Words Per Word',
47
+ 'Average Brysbaert Age Of Acquistion Of Words Per Word', 'Average Kuperman Age Of Acquistion Of Words Per Sentence',
48
+ 'Average Brysbaert Age Of Acquistion Of Words Per Sentence', 'Average Subtlex Us Zipf Of Words Per Word',
49
+ 'Average Subtlex Us Zipf Of Words Per Sentence', 'Total Number Of Named Entities',
50
+ 'Total Number Of Named Entities Person', 'Total Number Of Named Entities Norp',
51
+ 'Total Number Of Named Entities Fac', 'Total Number Of Named Entities Org',
52
+ 'Total Number Of Named Entities Gpe', 'Total Number Of Named Entities Loc',
53
+ 'Total Number Of Named Entities Product', 'Total Number Of Named Entities Event',
54
+ 'Total Number Of Named Entities Art', 'Total Number Of Named Entities Law',
55
+ 'Total Number Of Named Entities Language', 'Total Number Of Named Entities Date',
56
+ 'Total Number Of Named Entities Time', 'Total Number Of Named Entities Percent',
57
+ 'Total Number Of Named Entities Money', 'Total Number Of Named Entities Quantity',
58
+ 'Total Number Of Named Entities Ordinal', 'Total Number Of Named Entities Cardinal',
59
+ 'Average Number Of Named Entities Per Word', 'Average Number Of Named Entities Person Per Word',
60
+ 'Average Number Of Named Entities Norp Per Word', 'Average Number Of Named Entities Fac Per Word',
61
+ 'Average Number Of Named Entities Org Per Word', 'Average Number Of Named Entities Gpe Per Word',
62
+ 'Average Number Of Named Entities Loc Per Word', 'Average Number Of Named Entities Product Per Word',
63
+ 'Average Number Of Named Entities Event Per Word', 'Average Number Of Named Entities Art Per Word',
64
+ 'Average Number Of Named Entities Law Per Word', 'Average Number Of Named Entities Language Per Word',
65
+ 'Average Number Of Named Entities Date Per Word', 'Average Number Of Named Entities Time Per Word',
66
+ 'Average Number Of Named Entities Percent Per Word', 'Average Number Of Named Entities Money Per Word',
67
+ 'Average Number Of Named Entities Quantity Per Word', 'Average Number Of Named Entities Ordinal Per Word',
68
+ 'Average Number Of Named Entities Cardinal Per Word', 'Average Number Of Named Entities Per Sentence',
69
+ 'Average Number Of Named Entities Person Per Sentence', 'Average Number Of Named Entities Norp Per Sentence',
70
+ 'Average Number Of Named Entities Fac Per Sentence', 'Average Number Of Named Entities Org Per Sentence',
71
+ 'Average Number Of Named Entities Gpe Per Sentence', 'Average Number Of Named Entities Loc Per Sentence',
72
+ 'Average Number Of Named Entities Product Per Sentence', 'Average Number Of Named Entities Event Per Sentence',
73
+ 'Average Number Of Named Entities Art Per Sentence', 'Average Number Of Named Entities Law Per Sentence',
74
+ 'Average Number Of Named Entities Language Per Sentence', 'Average Number Of Named Entities Date Per Sentence',
75
+ 'Average Number Of Named Entities Time Per Sentence', 'Average Number Of Named Entities Percent Per Sentence',
76
+ 'Average Number Of Named Entities Money Per Sentence', 'Average Number Of Named Entities Quantity Per Sentence',
77
+ 'Average Number Of Named Entities Ordinal Per Sentence', 'Average Number Of Named Entities Cardinal Per Sentence',
78
+ 'Simple Adjectives Variation', 'Simple Adpositions Variation',
79
+ 'Simple Adverbs Variation', 'Simple Auxiliaries Variation',
80
+ 'Simple Coordinating Conjunctions Variation', 'Simple Determiners Variation',
81
+ 'Simple Interjections Variation', 'Simple Nouns Variation',
82
+ 'Simple Numerals Variation', 'Simple Particles Variation',
83
+ 'Simple Pronouns Variation', 'Simple Proper Nouns Variation',
84
+ 'Simple Punctuations Variation', 'Simple Subordinating Conjunctions Variation',
85
+ 'Simple Symbols Variation', 'Simple Verbs Variation',
86
+ 'Simple Spaces Variation', 'Root Adjectives Variation',
87
+ 'Root Adpositions Variation', 'Root Adverbs Variation',
88
+ 'Root Auxiliaries Variation', 'Root Coordinating Conjunctions Variation',
89
+ 'Root Determiners Variation', 'Root Interjections Variation',
90
+ 'Root Nouns Variation', 'Root Numerals Variation',
91
+ 'Root Particles Variation', 'Root Pronouns Variation',
92
+ 'Root Proper Nouns Variation', 'Root Punctuations Variation',
93
+ 'Root Subordinating Conjunctions Variation', 'Root Symbols Variation',
94
+ 'Root Verbs Variation', 'Root Spaces Variation',
95
+ 'Corrected Adjectives Variation', 'Corrected Adpositions Variation',
96
+ 'Corrected Adverbs Variation', 'Corrected Auxiliaries Variation',
97
+ 'Corrected Coordinating Conjunctions Variation', 'Corrected Determiners Variation',
98
+ 'Corrected Interjections Variation', 'Corrected Nouns Variation',
99
+ 'Corrected Numerals Variation', 'Corrected Particles Variation',
100
+ 'Corrected Pronouns Variation', 'Corrected Proper Nouns Variation',
101
+ 'Corrected Punctuations Variation', 'Corrected Subordinating Conjunctions Variation',
102
+ 'Corrected Symbols Variation', 'Corrected Verbs Variation',
103
+ 'Corrected Spaces Variation', 'Simple Type Token Ratio',
104
+ 'Root Type Token Ratio', 'Corrected Type Token Ratio',
105
+ 'Bilogarithmic Type Token Ratio', 'Uber Type Token Ratio',
106
+ 'Simple Type Token Ratio No Lemma', 'Root Type Token Ratio No Lemma',
107
+ 'Corrected Type Token Ratio No Lemma', 'Bilogarithmic Type Token Ratio No Lemma',
108
+ 'Uber Type Token Ratio No Lemma', 'Total Number Of Adjectives',
109
+ 'Total Number Of Adpositions', 'Total Number Of Adverbs',
110
+ 'Total Number Of Auxiliaries', 'Total Number Of Coordinating Conjunctions',
111
+ 'Total Number Of Determiners', 'Total Number Of Interjections',
112
+ 'Total Number Of Nouns', 'Total Number Of Numerals',
113
+ 'Total Number Of Particles', 'Total Number Of Pronouns',
114
+ 'Total Number Of Proper Nouns', 'Total Number Of Punctuations',
115
+ 'Total Number Of Subordinating Conjunctions', 'Total Number Of Symbols',
116
+ 'Total Number Of Verbs', 'Total Number Of Spaces',
117
+ 'Total Number Of Unique Adjectives', 'Total Number Of Unique Adpositions',
118
+ 'Total Number Of Unique Adverbs', 'Total Number Of Unique Auxiliaries',
119
+ 'Total Number Of Unique Coordinating Conjunctions', 'Total Number Of Unique Determiners',
120
+ 'Total Number Of Unique Interjections', 'Total Number Of Unique Nouns',
121
+ 'Total Number Of Unique Numerals', 'Total Number Of Unique Particles',
122
+ 'Total Number Of Unique Pronouns', 'Total Number Of Unique Proper Nouns',
123
+ 'Total Number Of Unique Punctuations', 'Total Number Of Unique Subordinating Conjunctions',
124
+ 'Total Number Of Unique Symbols', 'Total Number Of Unique Verbs',
125
+ 'Total Number Of Unique Spaces', 'Average Number Of Adjectives Per Word',
126
+ 'Average Number Of Adpositions Per Word', 'Average Number Of Adverbs Per Word',
127
+ 'Average Number Of Auxiliaries Per Word', 'Average Number Of Coordinating Conjunctions Per Word',
128
+ 'Average Number Of Determiners Per Word', 'Average Number Of Interjections Per Word',
129
+ 'Average Number Of Nouns Per Word', 'Average Number Of Numerals Per Word',
130
+ 'Average Number Of Particles Per Word', 'Average Number Of Pronouns Per Word',
131
+ 'Average Number Of Proper Nouns Per Word', 'Average Number Of Punctuations Per Word',
132
+ 'Average Number Of Subordinating Conjunctions Per Word', 'Average Number Of Symbols Per Word',
133
+ 'Average Number Of Verbs Per Word', 'Average Number Of Spaces Per Word',
134
+ 'Average Number Of Adjectives Per Sentence', 'Average Number Of Adpositions Per Sentence',
135
+ 'Average Number Of Adverbs Per Sentence', 'Average Number Of Auxiliaries Per Sentence',
136
+ 'Average Number Of Coordinating Conjunctions Per Sentence', 'Average Number Of Determiners Per Sentence',
137
+ 'Average Number Of Interjections Per Sentence', 'Average Number Of Nouns Per Sentence',
138
+ 'Average Number Of Numerals Per Sentence', 'Average Number Of Particles Per Sentence',
139
+ 'Average Number Of Pronouns Per Sentence', 'Average Number Of Proper Nouns Per Sentence',
140
+ 'Average Number Of Punctuations Per Sentence', 'Average Number Of Subordinating Conjunctions Per Sentence',
141
+ 'Average Number Of Symbols Per Sentence', 'Average Number Of Verbs Per Sentence',
142
+ 'Average Number Of Spaces Per Sentence', 'Flesch Kincaid Reading Ease',
143
+ 'Flesch Kincaid Grade Level', 'Gunning Fog Index',
144
+ 'Smog Index', 'Coleman Liau Index',
145
+ 'Automated Readability Index', 'Reading Time For Fast Readers',
146
+ 'Reading Time For Average Readers', 'Reading Time For Slow Readers']
147
+
148
+ full_names = [
149
+ 'Unique words',
150
+ 'Unique sophisticated words',
151
+ 'Unique lexical words',
152
+ 'Unique sophisticated lexical words',
153
+ 'Total words',
154
+ 'Total sophisticated words',
155
+ 'Total lexical words',
156
+ 'Total sophisticated lexical words',
157
+ 'Lexical density',
158
+ 'Lexical sophistication (total)',
159
+ 'Lexical sophistication (unique)',
160
+ 'Verb sophistication',
161
+ 'Verb sophistication (squared numerator)',
162
+ 'Verb sophistication (sqrt denominator)',
163
+ 'Unique words',
164
+ 'Unique words in first k tokens',
165
+ 'Unique words in random k tokens (average of 10 samples)',
166
+ 'Unique words in random sequence of k words (average of 10 samples)',
167
+ 'Ratio of unique words',
168
+ 'Mean TTR of all k word segments',
169
+ 'Corrected TTR (sqrt(2N) denominator)',
170
+ 'Root TTR (sqrt(N) denominator)',
171
+ 'Log TTR',
172
+ 'Uber',
173
+ 'D Measure',
174
+ 'Ratio of unique verbs',
175
+ 'Verb variation with squared numerator',
176
+ 'Verb variation with (sqrt(2N)) denominator',
177
+ 'Verb variation over all lexical words',
178
+ 'Noun variation',
179
+ 'Adjective variation',
180
+ 'Adverb variation',
181
+ '(Ajd + Adv) variation',
182
+ '# words',
183
+ '# sentences',
184
+ '# verb phrases',
185
+ '# clauses',
186
+ '# T-units',
187
+ '# dependent clauses',
188
+ '# complex T-units',
189
+ '# coordinate phrases',
190
+ '# complex nominals',
191
+ 'Mean length of sentence',
192
+ 'Mean length of T-unit',
193
+ 'Mean unit of clause',
194
+ 'Clauses per sentence',
195
+ 'Verb phrases per T-unit',
196
+ 'Clauses per T-unit',
197
+ 'Dependent clause ratio',
198
+ 'Dependent clause per T-unit',
199
+ 'T-units per sentence',
200
+ 'Complex T-unit ratio',
201
+ 'Coordinate phrases per T-unit',
202
+ 'Coordinate phrases per clause',
203
+ 'Complex nominals per T-unit',
204
+ 'Complex nominals per clause',
205
+ ]
206
+
207
+ lingfeat_names = [
208
+ 'WRich05_S', 'WRich10_S', 'WRich15_S', 'WRich20_S', 'WClar05_S', 'WClar10_S',
209
+ 'WClar15_S', 'WClar20_S', 'WNois05_S', 'WNois10_S', 'WNois15_S', 'WNois20_S',
210
+ 'WTopc05_S', 'WTopc10_S', 'WTopc15_S', 'WTopc20_S', 'BRich05_S', 'BRich10_S',
211
+ 'BRich15_S', 'BRich20_S', 'BClar05_S', 'BClar10_S', 'BClar15_S', 'BClar20_S',
212
+ 'BNois05_S', 'BNois10_S', 'BNois15_S', 'BNois20_S', 'BTopc05_S', 'BTopc10_S',
213
+ 'BTopc15_S', 'BTopc20_S', 'to_EntiM_C', 'as_EntiM_C', 'at_EntiM_C', 'to_UEnti_C',
214
+ 'as_UEnti_C', 'at_UEnti_C', 'ra_SSTo_C', 'ra_SOTo_C', 'ra_SXTo_C', 'ra_SNTo_C',
215
+ 'ra_OSTo_C', 'ra_OOTo_C', 'ra_OXTo_C', 'ra_ONTo_C', 'ra_XSTo_C', 'ra_XOTo_C',
216
+ 'ra_XXTo_C', 'ra_XNTo_C', 'ra_NSTo_C', 'ra_NOTo_C', 'ra_NXTo_C', 'ra_NNTo_C',
217
+ 'LoCohPA_S', 'LoCohPW_S', 'LoCohPU_S', 'LoCoDPA_S', 'LoCoDPW_S', 'LoCoDPU_S',
218
+ 'to_NoTag_C', 'as_NoTag_C', 'at_NoTag_C', 'ra_NoAjT_C', 'ra_NoVeT_C', 'ra_NoAvT_C',
219
+ 'ra_NoSuT_C', 'ra_NoCoT_C', 'to_VeTag_C', 'as_VeTag_C', 'at_VeTag_C', 'ra_VeAjT_C',
220
+ 'ra_VeNoT_C', 'ra_VeAvT_C', 'ra_VeSuT_C', 'ra_VeCoT_C', 'to_AjTag_C', 'as_AjTag_C',
221
+ 'at_AjTag_C', 'ra_AjNoT_C', 'ra_AjVeT_C', 'ra_AjAvT_C', 'ra_AjSuT_C', 'ra_AjCoT_C',
222
+ 'to_AvTag_C', 'as_AvTag_C', 'at_AvTag_C', 'ra_AvAjT_C', 'ra_AvNoT_C', 'ra_AvVeT_C',
223
+ 'ra_AvSuT_C', 'ra_AvCoT_C', 'to_SuTag_C', 'as_SuTag_C', 'at_SuTag_C', 'ra_SuAjT_C',
224
+ 'ra_SuNoT_C', 'ra_SuVeT_C', 'ra_SuAvT_C', 'ra_SuCoT_C', 'to_CoTag_C', 'as_CoTag_C',
225
+ 'at_CoTag_C', 'ra_CoAjT_C', 'ra_CoNoT_C', 'ra_CoVeT_C', 'ra_CoAvT_C', 'ra_CoSuT_C',
226
+ 'to_ContW_C', 'as_ContW_C', 'at_ContW_C', 'to_FuncW_C', 'as_FuncW_C', 'at_FuncW_C',
227
+ 'ra_CoFuW_C', 'SimpTTR_S', 'CorrTTR_S', 'BiLoTTR_S', 'UberTTR_S', 'MTLDTTR_S',
228
+ 'SimpNoV_S', 'SquaNoV_S', 'CorrNoV_S', 'SimpVeV_S', 'SquaVeV_S', 'CorrVeV_S',
229
+ 'SimpAjV_S', 'SquaAjV_S', 'CorrAjV_S', 'SimpAvV_S', 'SquaAvV_S', 'CorrAvV_S',
230
+ 'to_AAKuW_C', 'as_AAKuW_C', 'at_AAKuW_C', 'to_AAKuL_C', 'as_AAKuL_C', 'at_AAKuL_C',
231
+ 'to_AABiL_C', 'as_AABiL_C', 'at_AABiL_C', 'to_AABrL_C', 'as_AABrL_C', 'at_AABrL_C',
232
+ 'to_AACoL_C', 'as_AACoL_C', 'at_AACoL_C', 'to_SbFrQ_C', 'as_SbFrQ_C', 'at_SbFrQ_C',
233
+ 'to_SbCDC_C', 'as_SbCDC_C', 'at_SbCDC_C', 'to_SbFrL_C', 'as_SbFrL_C', 'at_SbFrL_C',
234
+ 'to_SbCDL_C', 'as_SbCDL_C', 'at_SbCDL_C', 'to_SbSBW_C', 'as_SbSBW_C', 'at_SbSBW_C',
235
+ 'to_SbL1W_C', 'as_SbL1W_C', 'at_SbL1W_C', 'to_SbSBC_C', 'as_SbSBC_C', 'at_SbSBC_C',
236
+ 'to_SbL1C_C', 'as_SbL1C_C', 'at_SbL1C_C', 'TokSenM_S', 'TokSenS_S', 'TokSenL_S',
237
+ 'as_Token_C', 'as_Sylla_C', 'at_Sylla_C', 'as_Chara_C', 'at_Chara_C', 'FleschG_S',
238
+ 'AutoRea_S', 'ColeLia_S', 'SmogInd_S', 'Gunning_S', 'LinseaW_S'
239
+ ]
240
+
241
+ lingfeat_subtypes = [
242
+ "Knowledge Feats",
243
+ "Knowledge Feats",
244
+ "Knowledge Feats",
245
+ "Knowledge Feats",
246
+ "Knowledge Feats",
247
+ "Knowledge Feats",
248
+ "Knowledge Feats",
249
+ "Knowledge Feats",
250
+ "Knowledge Feats",
251
+ "Knowledge Feats",
252
+ "Knowledge Feats",
253
+ "Knowledge Feats",
254
+ "Knowledge Feats",
255
+ "Knowledge Feats",
256
+ "Knowledge Feats",
257
+ "Knowledge Feats",
258
+ "Knowledge Feats",
259
+ "Knowledge Feats",
260
+ "Knowledge Feats",
261
+ "Knowledge Feats",
262
+ "Knowledge Feats",
263
+ "Knowledge Feats",
264
+ "Knowledge Feats",
265
+ "Knowledge Feats",
266
+ "Knowledge Feats",
267
+ "Knowledge Feats",
268
+ "Knowledge Feats",
269
+ "Knowledge Feats",
270
+ "Knowledge Feats",
271
+ "Knowledge Feats",
272
+ "Knowledge Feats",
273
+ "Knowledge Feats",
274
+ "Knowledge Feats",
275
+ "Knowledge Feats",
276
+ "Knowledge Feats",
277
+ "Knowledge Feats",
278
+ "Knowledge Feats",
279
+ "Knowledge Feats",
280
+ "Knowledge Feats",
281
+ "Knowledge Feats",
282
+ "Knowledge Feats",
283
+ "Knowledge Feats",
284
+ "Knowledge Feats",
285
+ "Knowledge Feats",
286
+ "Knowledge Feats",
287
+ "Knowledge Feats",
288
+ "Knowledge Feats",
289
+ "Knowledge Feats",
290
+ "Entity Density Feats",
291
+ "Entity Density Feats",
292
+ "Entity Density Feats",
293
+ "Entity Density Feats",
294
+ "Entity Density Feats",
295
+ "Entity Density Feats",
296
+ "Entity Grid Feats",
297
+ "Entity Grid Feats",
298
+ "Entity Grid Feats",
299
+ "Entity Grid Feats",
300
+ "Entity Grid Feats",
301
+ "Entity Grid Feats",
302
+ "Entity Grid Feats",
303
+ "Entity Grid Feats",
304
+ "Entity Grid Feats",
305
+ "Entity Grid Feats",
306
+ "Entity Grid Feats",
307
+ "Entity Grid Feats",
308
+ "Entity Grid Feats",
309
+ "Entity Grid Feats",
310
+ "Entity Grid Feats",
311
+ "Entity Grid Feats",
312
+ "Entity Grid Feats",
313
+ "Entity Grid Feats",
314
+ "Entity Grid Feats",
315
+ "Entity Grid Feats",
316
+ "Entity Grid Feats",
317
+ "Entity Grid Feats",
318
+ "Phrasal Feats",
319
+ "Phrasal Feats",
320
+ "Phrasal Feats",
321
+ "Phrasal Feats",
322
+ "Phrasal Feats",
323
+ "Phrasal Feats",
324
+ "Phrasal Feats",
325
+ "Phrasal Feats",
326
+ "Phrasal Feats",
327
+ "Phrasal Feats",
328
+ "Phrasal Feats",
329
+ "Phrasal Feats",
330
+ "Phrasal Feats",
331
+ "Phrasal Feats",
332
+ "Phrasal Feats",
333
+ "Phrasal Feats",
334
+ "Phrasal Feats",
335
+ "Phrasal Feats",
336
+ "Phrasal Feats",
337
+ "Phrasal Feats",
338
+ "Phrasal Feats",
339
+ "Phrasal Feats",
340
+ "Phrasal Feats",
341
+ "Phrasal Feats",
342
+ "Phrasal Feats",
343
+ "Phrasal Feats",
344
+ "Phrasal Feats",
345
+ "Phrasal Feats",
346
+ "Phrasal Feats",
347
+ "Phrasal Feats",
348
+ "Phrasal Feats",
349
+ "Phrasal Feats",
350
+ "Phrasal Feats",
351
+ "Phrasal Feats",
352
+ "Phrasal Feats",
353
+ "Phrasal Feats",
354
+ "Phrasal Feats",
355
+ "Phrasal Feats",
356
+ "Phrasal Feats",
357
+ "Phrasal Feats",
358
+ "Phrasal Feats",
359
+ "Phrasal Feats",
360
+ "Phrasal Feats",
361
+ "Phrasal Feats",
362
+ "Phrasal Feats",
363
+ "Phrasal Feats",
364
+ "Phrasal Feats",
365
+ "Phrasal Feats",
366
+ "Tree Structure Feats",
367
+ "Tree Structure Feats",
368
+ "Tree Structure Feats",
369
+ "Tree Structure Feats",
370
+ "Tree Structure Feats",
371
+ "Tree Structure Feats",
372
+ "POS Feats",
373
+ "POS Feats",
374
+ "POS Feats",
375
+ "POS Feats",
376
+ "POS Feats",
377
+ "POS Feats",
378
+ "POS Feats",
379
+ "POS Feats",
380
+ "POS Feats",
381
+ "POS Feats",
382
+ "POS Feats",
383
+ "POS Feats",
384
+ "POS Feats",
385
+ "POS Feats",
386
+ "POS Feats",
387
+ "POS Feats",
388
+ "POS Feats",
389
+ "POS Feats",
390
+ "POS Feats",
391
+ "POS Feats",
392
+ "POS Feats",
393
+ "POS Feats",
394
+ "POS Feats",
395
+ "POS Feats",
396
+ "POS Feats",
397
+ "POS Feats",
398
+ "POS Feats",
399
+ "POS Feats",
400
+ "POS Feats",
401
+ "POS Feats",
402
+ "POS Feats",
403
+ "POS Feats",
404
+ "POS Feats",
405
+ "POS Feats",
406
+ "POS Feats",
407
+ "POS Feats",
408
+ "POS Feats",
409
+ "POS Feats",
410
+ "POS Feats",
411
+ "POS Feats",
412
+ "POS Feats",
413
+ "POS Feats",
414
+ "POS Feats",
415
+ "POS Feats",
416
+ "POS Feats",
417
+ "POS Feats",
418
+ "POS Feats",
419
+ "POS Feats",
420
+ "POS Feats",
421
+ "POS Feats",
422
+ "POS Feats",
423
+ "POS Feats",
424
+ "POS Feats",
425
+ "POS Feats",
426
+ "POS Feats",
427
+ "Variation Ratio Feats",
428
+ "Variation Ratio Feats",
429
+ "Variation Ratio Feats",
430
+ "Variation Ratio Feats",
431
+ "Variation Ratio Feats",
432
+ "Variation Ratio Feats",
433
+ "Variation Ratio Feats",
434
+ "Variation Ratio Feats",
435
+ "Variation Ratio Feats",
436
+ "Variation Ratio Feats",
437
+ "Variation Ratio Feats",
438
+ "Variation Ratio Feats",
439
+ "TTR Feats",
440
+ "TTR Feats",
441
+ "TTR Feats",
442
+ "TTR Feats",
443
+ "TTR Feats",
444
+ "Psycholinguistic Feats",
445
+ "Psycholinguistic Feats",
446
+ "Psycholinguistic Feats",
447
+ "Psycholinguistic Feats",
448
+ "Psycholinguistic Feats",
449
+ "Psycholinguistic Feats",
450
+ "Psycholinguistic Feats",
451
+ "Psycholinguistic Feats",
452
+ "Psycholinguistic Feats",
453
+ "Psycholinguistic Feats",
454
+ "Psycholinguistic Feats",
455
+ "Psycholinguistic Feats",
456
+ "Psycholinguistic Feats",
457
+ "Psycholinguistic Feats",
458
+ "Psycholinguistic Feats",
459
+ "Word Familiarity",
460
+ "Word Familiarity",
461
+ "Word Familiarity",
462
+ "Word Familiarity",
463
+ "Word Familiarity",
464
+ "Word Familiarity",
465
+ "Word Familiarity",
466
+ "Word Familiarity",
467
+ "Word Familiarity",
468
+ "Word Familiarity",
469
+ "Word Familiarity",
470
+ "Word Familiarity",
471
+ "Word Familiarity",
472
+ "Word Familiarity",
473
+ "Word Familiarity",
474
+ "Word Familiarity",
475
+ "Word Familiarity",
476
+ "Word Familiarity",
477
+ "Word Familiarity",
478
+ "Word Familiarity",
479
+ "Word Familiarity",
480
+ "Word Familiarity",
481
+ "Word Familiarity",
482
+ "Word Familiarity",
483
+ "Shallow Feats",
484
+ "Shallow Feats",
485
+ "Shallow Feats",
486
+ "Shallow Feats",
487
+ "Shallow Feats",
488
+ "Shallow Feats",
489
+ "Shallow Feats",
490
+ "Shallow Feats",
491
+ "Traditional Formulas",
492
+ "Traditional Formulas",
493
+ "Traditional Formulas",
494
+ "Traditional Formulas",
495
+ "Traditional Formulas",
496
+ "Traditional Formulas",
497
+ ]
498
+
499
+ lingfeat_types = [
500
+ "AdSem",
501
+ "AdSem",
502
+ "AdSem",
503
+ "AdSem",
504
+ "AdSem",
505
+ "AdSem",
506
+ "AdSem",
507
+ "AdSem",
508
+ "AdSem",
509
+ "AdSem",
510
+ "AdSem",
511
+ "AdSem",
512
+ "AdSem",
513
+ "AdSem",
514
+ "AdSem",
515
+ "AdSem",
516
+ "AdSem",
517
+ "AdSem",
518
+ "AdSem",
519
+ "AdSem",
520
+ "AdSem",
521
+ "AdSem",
522
+ "AdSem",
523
+ "AdSem",
524
+ "AdSem",
525
+ "AdSem",
526
+ "AdSem",
527
+ "AdSem",
528
+ "AdSem",
529
+ "AdSem",
530
+ "AdSem",
531
+ "AdSem",
532
+ "AdSem",
533
+ "AdSem",
534
+ "AdSem",
535
+ "AdSem",
536
+ "AdSem",
537
+ "AdSem",
538
+ "AdSem",
539
+ "AdSem",
540
+ "AdSem",
541
+ "AdSem",
542
+ "AdSem",
543
+ "AdSem",
544
+ "AdSem",
545
+ "AdSem",
546
+ "AdSem",
547
+ "AdSem",
548
+ "Disco",
549
+ "Disco",
550
+ "Disco",
551
+ "Disco",
552
+ "Disco",
553
+ "Disco",
554
+ "Disco",
555
+ "Disco",
556
+ "Disco",
557
+ "Disco",
558
+ "Disco",
559
+ "Disco",
560
+ "Disco",
561
+ "Disco",
562
+ "Disco",
563
+ "Disco",
564
+ "Disco",
565
+ "Disco",
566
+ "Disco",
567
+ "Disco",
568
+ "Disco",
569
+ "Disco",
570
+ "Disco",
571
+ "Disco",
572
+ "Disco",
573
+ "Disco",
574
+ "Disco",
575
+ "Disco",
576
+ "Synta",
577
+ "Synta",
578
+ "Synta",
579
+ "Synta",
580
+ "Synta",
581
+ "Synta",
582
+ "Synta",
583
+ "Synta",
584
+ "Synta",
585
+ "Synta",
586
+ "Synta",
587
+ "Synta",
588
+ "Synta",
589
+ "Synta",
590
+ "Synta",
591
+ "Synta",
592
+ "Synta",
593
+ "Synta",
594
+ "Synta",
595
+ "Synta",
596
+ "Synta",
597
+ "Synta",
598
+ "Synta",
599
+ "Synta",
600
+ "Synta",
601
+ "Synta",
602
+ "Synta",
603
+ "Synta",
604
+ "Synta",
605
+ "Synta",
606
+ "Synta",
607
+ "Synta",
608
+ "Synta",
609
+ "Synta",
610
+ "Synta",
611
+ "Synta",
612
+ "Synta",
613
+ "Synta",
614
+ "Synta",
615
+ "Synta",
616
+ "Synta",
617
+ "Synta",
618
+ "Synta",
619
+ "Synta",
620
+ "Synta",
621
+ "Synta",
622
+ "Synta",
623
+ "Synta",
624
+ "Synta",
625
+ "Synta",
626
+ "Synta",
627
+ "Synta",
628
+ "Synta",
629
+ "Synta",
630
+ "Synta",
631
+ "Synta",
632
+ "Synta",
633
+ "Synta",
634
+ "Synta",
635
+ "Synta",
636
+ "Synta",
637
+ "Synta",
638
+ "Synta",
639
+ "Synta",
640
+ "Synta",
641
+ "Synta",
642
+ "Synta",
643
+ "Synta",
644
+ "Synta",
645
+ "Synta",
646
+ "Synta",
647
+ "Synta",
648
+ "Synta",
649
+ "Synta",
650
+ "Synta",
651
+ "Synta",
652
+ "Synta",
653
+ "Synta",
654
+ "Synta",
655
+ "Synta",
656
+ "Synta",
657
+ "Synta",
658
+ "Synta",
659
+ "Synta",
660
+ "Synta",
661
+ "Synta",
662
+ "Synta",
663
+ "Synta",
664
+ "Synta",
665
+ "Synta",
666
+ "Synta",
667
+ "Synta",
668
+ "Synta",
669
+ "Synta",
670
+ "Synta",
671
+ "Synta",
672
+ "Synta",
673
+ "Synta",
674
+ "Synta",
675
+ "Synta",
676
+ "Synta",
677
+ "Synta",
678
+ "Synta",
679
+ "Synta",
680
+ "Synta",
681
+ "Synta",
682
+ "Synta",
683
+ "Synta",
684
+ "Synta",
685
+ "LxSem",
686
+ "LxSem",
687
+ "LxSem",
688
+ "LxSem",
689
+ "LxSem",
690
+ "LxSem",
691
+ "LxSem",
692
+ "LxSem",
693
+ "LxSem",
694
+ "LxSem",
695
+ "LxSem",
696
+ "LxSem",
697
+ "LxSem",
698
+ "LxSem",
699
+ "LxSem",
700
+ "LxSem",
701
+ "LxSem",
702
+ "LxSem",
703
+ "LxSem",
704
+ "LxSem",
705
+ "LxSem",
706
+ "LxSem",
707
+ "LxSem",
708
+ "LxSem",
709
+ "LxSem",
710
+ "LxSem",
711
+ "LxSem",
712
+ "LxSem",
713
+ "LxSem",
714
+ "LxSem",
715
+ "LxSem",
716
+ "LxSem",
717
+ "LxSem",
718
+ "LxSem",
719
+ "LxSem",
720
+ "LxSem",
721
+ "LxSem",
722
+ "LxSem",
723
+ "LxSem",
724
+ "LxSem",
725
+ "LxSem",
726
+ "LxSem",
727
+ "LxSem",
728
+ "LxSem",
729
+ "LxSem",
730
+ "LxSem",
731
+ "LxSem",
732
+ "LxSem",
733
+ "LxSem",
734
+ "LxSem",
735
+ "LxSem",
736
+ "LxSem",
737
+ "LxSem",
738
+ "LxSem",
739
+ "LxSem",
740
+ "LxSem",
741
+ "ShaTr",
742
+ "ShaTr",
743
+ "ShaTr",
744
+ "ShaTr",
745
+ "ShaTr",
746
+ "ShaTr",
747
+ "ShaTr",
748
+ "ShaTr",
749
+ "ShaTr",
750
+ "ShaTr",
751
+ "ShaTr",
752
+ "ShaTr",
753
+ "ShaTr",
754
+ "ShaTr",
755
+ ]
756
+
757
+ lf_names = """| 1 | AdSem | WoKF_ | Wiki Knowledge Features | WRich05_S | Semantic Richness, 50 topics extracted from Wikipedia |
758
+ | 2 | AdSem | WoKF_ | Wiki Knowledge Features | WClar05_S | Semantic Clarity, 50 topics extracted from Wikipedia |
759
+ | 3 | AdSem | WoKF_ | Wiki Knowledge Features | WNois05_S | Semantic Noise, 50 topics extracted from Wikipedia |
760
+ | 4 | AdSem | WoKF_ | Wiki Knowledge Features | WTopc05_S | Number of topics, 50 topics extracted from Wikipedia |
761
+ | 5 | AdSem | WoKF_ | Wiki Knowledge Features | WRich10_S | Semantic Richness, 100 topics extracted from Wikipedia |
762
+ | 6 | AdSem | WoKF_ | Wiki Knowledge Features | WClar10_S | Semantic Clarity, 100 topics extracted from Wikipedia |
763
+ | 7 | AdSem | WoKF_ | Wiki Knowledge Features | WNois10_S | Semantic Noise, 100 topics extracted from Wikipedia |
764
+ | 8 | AdSem | WoKF_ | Wiki Knowledge Features | WTopc10_S | Number of topics, 100 topics extracted from Wikipedia |
765
+ | 9 | AdSem | WoKF_ | Wiki Knowledge Features | WRich15_S | Semantic Richness, 150 topics extracted from Wikipedia |
766
+ | 10 | AdSem | WoKF_ | Wiki Knowledge Features | WClar15_S | Semantic Clarity, 150 topics extracted from Wikipedia |
767
+ | 11 | AdSem | WoKF_ | Wiki Knowledge Features | WNois15_S | Semantic Noise, 150 topics extracted from Wikipedia |
768
+ | 12 | AdSem | WoKF_ | Wiki Knowledge Features | WTopc15_S | Number of topics, 150 topics extracted from Wikipedia |
769
+ | 13 | AdSem | WoKF_ | Wiki Knowledge Features | WRich20_S | Semantic Richness, 200 topics extracted from Wikipedia |
770
+ | 14 | AdSem | WoKF_ | Wiki Knowledge Features | WClar20_S | Semantic Clarity, 200 topics extracted from Wikipedia |
771
+ | 15 | AdSem | WoKF_ | Wiki Knowledge Features | WNois20_S | Semantic Noise, 200 topics extracted from Wikipedia |
772
+ | 16 | AdSem | WoKF_ | Wiki Knowledge Features | WTopc20_S | Number of topics, 200 topics extracted from Wikipedia |
773
+ | 17 | AdSem | WBKF_ | WB Knowledge Features | BRich05_S | Semantic Richness, 50 topics extracted from WeeBit Corpus |
774
+ | 18 | AdSem | WBKF_ | WB Knowledge Features | BClar05_S | Semantic Clarity, 50 topics extracted from WeeBit Corpus |
775
+ | 19 | AdSem | WBKF_ | WB Knowledge Features | BNois05_S | Semantic Noise, 50 topics extracted from WeeBit Corpus |
776
+ | 20 | AdSem | WBKF_ | WB Knowledge Features | BTopc05_S | Number of topics, 50 topics extracted from WeeBit Corpus |
777
+ | 21 | AdSem | WBKF_ | WB Knowledge Features | BRich10_S | Semantic Richness, 100 topics extracted from WeeBit Corpus |
778
+ | 22 | AdSem | WBKF_ | WB Knowledge Features | BClar10_S | Semantic Clarity, 100 topics extracted from WeeBit Corpus |
779
+ | 23 | AdSem | WBKF_ | WB Knowledge Features | BNois10_S | Semantic Noise, 100 topics extracted from WeeBit Corpus |
780
+ | 24 | AdSem | WBKF_ | WB Knowledge Features | BTopc10_S | Number of topics, 100 topics extracted from WeeBit Corpus |
781
+ | 25 | AdSem | WBKF_ | WB Knowledge Features | BRich15_S | Semantic Richness, 150 topics extracted from WeeBit Corpus |
782
+ | 26 | AdSem | WBKF_ | WB Knowledge Features | BClar15_S | Semantic Clarity, 150 topics extracted from WeeBit Corpus |
783
+ | 27 | AdSem | WBKF_ | WB Knowledge Features | BNois15_S | Semantic Noise, 150 topics extracted from WeeBit Corpus |
784
+ | 28 | AdSem | WBKF_ | WB Knowledge Features | BTopc15_S | Number of topics, 150 topics extracted from WeeBit Corpus |
785
+ | 29 | AdSem | WBKF_ | WB Knowledge Features | BRich20_S | Semantic Richness, 200 topics extracted from WeeBit Corpus |
786
+ | 30 | AdSem | WBKF_ | WB Knowledge Features | BClar20_S | Semantic Clarity, 200 topics extracted from WeeBit Corpus |
787
+ | 31 | AdSem | WBKF_ | WB Knowledge Features | BNois20_S | Semantic Noise, 200 topics extracted from WeeBit Corpus |
788
+ | 32 | AdSem | WBKF_ | WB Knowledge Features | BTopc20_S | Number of topics, 200 topics extracted from WeeBit Corpus |
789
+ | 33 | AdSem | OSKF_ | OSE Knowledge Features | ORich05_S | Semantic Richness, 50 topics extracted from OneStopEng Corpus |
790
+ | 34 | AdSem | OSKF_ | OSE Knowledge Features | OClar05_S | Semantic Clarity, 50 topics extracted from OneStopEng Corpus |
791
+ | 35 | AdSem | OSKF_ | OSE Knowledge Features | ONois05_S | Semantic Noise, 50 topics extracted from OneStopEng Corpus |
792
+ | 36 | AdSem | OSKF_ | OSE Knowledge Features | OTopc05_S | Number of topics, 50 topics extracted from OneStopEng Corpus |
793
+ | 37 | AdSem | OSKF_ | OSE Knowledge Features | ORich10_S | Semantic Richness, 100 topics extracted from OneStopEng Corpus |
794
+ | 38 | AdSem | OSKF_ | OSE Knowledge Features | OClar10_S | Semantic Clarity, 100 topics extracted from OneStopEng Corpus |
795
+ | 39 | AdSem | OSKF_ | OSE Knowledge Features | ONois10_S | Semantic Noise, 100 topics extracted from OneStopEng Corpus |
796
+ | 40 | AdSem | OSKF_ | OSE Knowledge Features | OTopc10_S | Number of topics, 100 topics extracted from OneStopEng Corpus |
797
+ | 41 | AdSem | OSKF_ | OSE Knowledge Features | ORich15_S | Semantic Richness, 150 topics extracted from OneStopEng Corpus |
798
+ | 42 | AdSem | OSKF_ | OSE Knowledge Features | OClar15_S | Semantic Clarity, 150 topics extracted from OneStopEng Corpus |
799
+ | 43 | AdSem | OSKF_ | OSE Knowledge Features | ONois15_S | Semantic Noise, 150 topics extracted from OneStopEng Corpus |
800
+ | 44 | AdSem | OSKF_ | OSE Knowledge Features | OTopc15_S | Number of topics, 150 topics extracted from OneStopEng Corpus |
801
+ | 45 | AdSem | OSKF_ | OSE Knowledge Features | ORich20_S | Semantic Richness, 200 topics extracted from OneStopEng Corpus |
802
+ | 46 | AdSem | OSKF_ | OSE Knowledge Features | OClar20_S | Semantic Clarity, 200 topics extracted from OneStopEng Corpus |
803
+ | 47 | AdSem | OSKF_ | OSE Knowledge Features | ONois20_S | Semantic Noise, 200 topics extracted from OneStopEng Corpus |
804
+ | 48 | AdSem | OSKF_ | OSE Knowledge Features | OTopc20_S | Number of topics, 200 topics extracted from OneStopEng Corpus |
805
+ | 49 | Disco | EnDF_ | Entity Density Features | to_EntiM_C | total number of Entities Mentions counts |
806
+ | 50 | Disco | EnDF_ | Entity Density Features | as_EntiM_C | average number of Entities Mentions counts per sentence |
807
+ | 51 | Disco | EnDF_ | Entity Density Features | at_EntiM_C | average number of Entities Mentions counts per token (word) |
808
+ | 52 | Disco | EnDF_ | Entity Density Features | to_UEnti_C | total number of unique Entities |
809
+ | 53 | Disco | EnDF_ | Entity Density Features | as_UEnti_C | average number of unique Entities per sentence |
810
+ | 54 | Disco | EnDF_ | Entity Density Features | at_UEnti_C | average number of unique Entities per token (word) |
811
+ | 55 | Disco | EnGF_ | Entity Grid Features | ra_SSTo_C | ratio of ss transitions to total |
812
+ | 56 | Disco | EnGF_ | Entity Grid Features | ra_SOTo_C | ratio of so transitions to total |
813
+ | 57 | Disco | EnGF_ | Entity Grid Features | ra_SXTo_C | ratio of sx transitions to total |
814
+ | 58 | Disco | EnGF_ | Entity Grid Features | ra_SNTo_C | ratio of sn transitions to total |
815
+ | 59 | Disco | EnGF_ | Entity Grid Features | ra_OSTo_C | ratio of os transitions to total |
816
+ | 60 | Disco | EnGF_ | Entity Grid Features | ra_OOTo_C | ratio of oo transitions to total |
817
+ | 61 | Disco | EnGF_ | Entity Grid Features | ra_OXTo_C | ratio of ox transitions to total |
818
+ | 62 | Disco | EnGF_ | Entity Grid Features | ra_ONTo_C | ratio of on transitions to total |
819
+ | 63 | Disco | EnGF_ | Entity Grid Features | ra_XSTo_C | ratio of xs transitions to total |
820
+ | 64 | Disco | EnGF_ | Entity Grid Features | ra_XOTo_C | ratio of xo transitions to total |
821
+ | 65 | Disco | EnGF_ | Entity Grid Features | ra_XXTo_C | ratio of xx transitions to total |
822
+ | 66 | Disco | EnGF_ | Entity Grid Features | ra_XNTo_C | ratio of xn transitions to total |
823
+ | 67 | Disco | EnGF_ | Entity Grid Features | ra_NSTo_C | ratio of ns transitions to total |
824
+ | 68 | Disco | EnGF_ | Entity Grid Features | ra_NOTo_C | ratio of no transitions to total |
825
+ | 69 | Disco | EnGF_ | Entity Grid Features | ra_NXTo_C | ratio of nx transitions to total |
826
+ | 70 | Disco | EnGF_ | Entity Grid Features | ra_NNTo_C | ratio of nn transitions to total |
827
+ | 71 | Disco | EnGF_ | Entity Grid Features | LoCohPA_S | Local Coherence for PA score |
828
+ | 72 | Disco | EnGF_ | Entity Grid Features | LoCohPW_S | Local Coherence for PW score |
829
+ | 73 | Disco | EnGF_ | Entity Grid Features | LoCohPU_S | Local Coherence for PU score |
830
+ | 74 | Disco | EnGF_ | Entity Grid Features | LoCoDPA_S | Local Coherence distance for PA score |
831
+ | 75 | Disco | EnGF_ | Entity Grid Features | LoCoDPW_S | Local Coherence distance for PW score |
832
+ | 76 | Disco | EnGF_ | Entity Grid Features | LoCoDPU_S | Local Coherence distance for PU score |
833
+ | 77 | Synta | PhrF_ | Phrasal Features | to_NoPhr_C | total count of Noun phrases |
834
+ | 78 | Synta | PhrF_ | Phrasal Features | as_NoPhr_C | average count of Noun phrases per sentence |
835
+ | 79 | Synta | PhrF_ | Phrasal Features | at_NoPhr_C | average count of Noun phrases per token |
836
+ | 80 | Synta | PhrF_ | Phrasal Features | ra_NoVeP_C | ratio of Noun phrases count to Verb phrases count |
837
+ | 81 | Synta | PhrF_ | Phrasal Features | ra_NoSuP_C | ratio of Noun phrases count to Subordinate Clauses count |
838
+ | 82 | Synta | PhrF_ | Phrasal Features | ra_NoPrP_C | ratio of Noun phrases count to Prep phrases count |
839
+ | 83 | Synta | PhrF_ | Phrasal Features | ra_NoAjP_C | ratio of Noun phrases count to Adj phrases count |
840
+ | 84 | Synta | PhrF_ | Phrasal Features | ra_NoAvP_C | ratio of Noun phrases count to Adv phrases count |
841
+ | 85 | Synta | PhrF_ | Phrasal Features | to_VePhr_C | total count of Verb phrases |
842
+ | 86 | Synta | PhrF_ | Phrasal Features | as_VePhr_C | average count of Verb phrases per sentence |
843
+ | 87 | Synta | PhrF_ | Phrasal Features | at_VePhr_C | average count of Verb phrases per token |
844
+ | 88 | Synta | PhrF_ | Phrasal Features | ra_VeNoP_C | ratio of Verb phrases count to Noun phrases count |
845
+ | 89 | Synta | PhrF_ | Phrasal Features | ra_VeSuP_C | ratio of Verb phrases count to Subordinate Clauses count |
846
+ | 90 | Synta | PhrF_ | Phrasal Features | ra_VePrP_C | ratio of Verb phrases count to Prep phrases count |
847
+ | 91 | Synta | PhrF_ | Phrasal Features | ra_VeAjP_C | ratio of Verb phrases count to Adj phrases count |
848
+ | 92 | Synta | PhrF_ | Phrasal Features | ra_VeAvP_C | ratio of Verb phrases count to Adv phrases count |
849
+ | 93 | Synta | PhrF_ | Phrasal Features | to_SuPhr_C | total count of Subordinate Clauses |
850
+ | 94 | Synta | PhrF_ | Phrasal Features | as_SuPhr_C | average count of Subordinate Clauses per sentence |
851
+ | 95 | Synta | PhrF_ | Phrasal Features | at_SuPhr_C | average count of Subordinate Clauses per token |
852
+ | 96 | Synta | PhrF_ | Phrasal Features | ra_SuNoP_C | ratio of Subordinate Clauses count to Noun phrases count |
853
+ | 97 | Synta | PhrF_ | Phrasal Features | ra_SuVeP_C | ratio of Subordinate Clauses count to Verb phrases count |
854
+ | 98 | Synta | PhrF_ | Phrasal Features | ra_SuPrP_C | ratio of Subordinate Clauses count to Prep phrases count |
855
+ | 99 | Synta | PhrF_ | Phrasal Features | ra_SuAjP_C | ratio of Subordinate Clauses count to Adj phrases count |
856
+ | 100 | Synta | PhrF_ | Phrasal Features | ra_SuAvP_C | ratio of Subordinate Clauses count to Adv phrases count |
857
+ | 101 | Synta | PhrF_ | Phrasal Features | to_PrPhr_C | total count of prepositional phrases |
858
+ | 102 | Synta | PhrF_ | Phrasal Features | as_PrPhr_C | average count of prepositional phrases per sentence |
859
+ | 103 | Synta | PhrF_ | Phrasal Features | at_PrPhr_C | average count of prepositional phrases per token |
860
+ | 104 | Synta | PhrF_ | Phrasal Features | ra_PrNoP_C | ratio of Prep phrases count to Noun phrases count |
861
+ | 105 | Synta | PhrF_ | Phrasal Features | ra_PrVeP_C | ratio of Prep phrases count to Verb phrases count |
862
+ | 106 | Synta | PhrF_ | Phrasal Features | ra_PrSuP_C | ratio of Prep phrases count to Subordinate Clauses count |
863
+ | 107 | Synta | PhrF_ | Phrasal Features | ra_PrAjP_C | ratio of Prep phrases count to Adj phrases count |
864
+ | 108 | Synta | PhrF_ | Phrasal Features | ra_PrAvP_C | ratio of Prep phrases count to Adv phrases count |
865
+ | 109 | Synta | PhrF_ | Phrasal Features | to_AjPhr_C | total count of Adjective phrases |
866
+ | 110 | Synta | PhrF_ | Phrasal Features | as_AjPhr_C | average count of Adjective phrases per sentence |
867
+ | 111 | Synta | PhrF_ | Phrasal Features | at_AjPhr_C | average count of Adjective phrases per token |
868
+ | 112 | Synta | PhrF_ | Phrasal Features | ra_AjNoP_C | ratio of Adj phrases count to Noun phrases count |
869
+ | 113 | Synta | PhrF_ | Phrasal Features | ra_AjVeP_C | ratio of Adj phrases count to Verb phrases count |
870
+ | 114 | Synta | PhrF_ | Phrasal Features | ra_AjSuP_C | ratio of Adj phrases count to Subordinate Clauses count |
871
+ | 115 | Synta | PhrF_ | Phrasal Features | ra_AjPrP_C | ratio of Adj phrases count to Prep phrases count |
872
+ | 116 | Synta | PhrF_ | Phrasal Features | ra_AjAvP_C | ratio of Adj phrases count to Adv phrases count |
873
+ | 117 | Synta | PhrF_ | Phrasal Features | to_AvPhr_C | total count of Adverb phrases |
874
+ | 118 | Synta | PhrF_ | Phrasal Features | as_AvPhr_C | average count of Adverb phrases per sentence |
875
+ | 119 | Synta | PhrF_ | Phrasal Features | at_AvPhr_C | average count of Adverb phrases per token |
876
+ | 120 | Synta | PhrF_ | Phrasal Features | ra_AvNoP_C | ratio of Adv phrases count to Noun phrases count |
877
+ | 121 | Synta | PhrF_ | Phrasal Features | ra_AvVeP_C | ratio of Adv phrases count to Verb phrases count |
878
+ | 122 | Synta | PhrF_ | Phrasal Features | ra_AvSuP_C | ratio of Adv phrases count to Subordinate Clauses count |
879
+ | 123 | Synta | PhrF_ | Phrasal Features | ra_AvPrP_C | ratio of Adv phrases count to Prep phrases count |
880
+ | 124 | Synta | PhrF_ | Phrasal Features | ra_AvAjP_C | ratio of Adv phrases count to Adj phrases count |
881
+ | 125 | Synta | TrSF_ | Tree Structure Features | to_TreeH_C | total Tree height of all sentences |
882
+ | 126 | Synta | TrSF_ | Tree Structure Features | as_TreeH_C | average Tree height per sentence |
883
+ | 127 | Synta | TrSF_ | Tree Structure Features | at_TreeH_C | average Tree height per token (word) |
884
+ | 128 | Synta | TrSF_ | Tree Structure Features | to_FTree_C | total length of flattened Trees |
885
+ | 129 | Synta | TrSF_ | Tree Structure Features | as_FTree_C | average length of flattened Trees per sentence |
886
+ | 130 | Synta | TrSF_ | Tree Structure Features | at_FTree_C | average length of flattened Trees per token (word) |
887
+ | 131 | Synta | POSF_ | Part-of-Speech Features | to_NoTag_C | total count of Noun POS tags |
888
+ | 132 | Synta | POSF_ | Part-of-Speech Features | as_NoTag_C | average count of Noun POS tags per sentence |
889
+ | 133 | Synta | POSF_ | Part-of-Speech Features | at_NoTag_C | average count of Noun POS tags per token |
890
+ | 134 | Synta | POSF_ | Part-of-Speech Features | ra_NoAjT_C | ratio of Noun POS count to Adjective POS count |
891
+ | 135 | Synta | POSF_ | Part-of-Speech Features | ra_NoVeT_C | ratio of Noun POS count to Verb POS count |
892
+ | 136 | Synta | POSF_ | Part-of-Speech Features | ra_NoAvT_C | ratio of Noun POS count to Adverb POS count |
893
+ | 137 | Synta | POSF_ | Part-of-Speech Features | ra_NoSuT_C | ratio of Noun POS count to Subordinating Conjunction count |
894
+ | 138 | Synta | POSF_ | Part-of-Speech Features | ra_NoCoT_C | ratio of Noun POS count to Coordinating Conjunction count |
895
+ | 139 | Synta | POSF_ | Part-of-Speech Features | to_VeTag_C | total count of Verb POS tags |
896
+ | 140 | Synta | POSF_ | Part-of-Speech Features | as_VeTag_C | average count of Verb POS tags per sentence |
897
+ | 141 | Synta | POSF_ | Part-of-Speech Features | at_VeTag_C | average count of Verb POS tags per token |
898
+ | 142 | Synta | POSF_ | Part-of-Speech Features | ra_VeAjT_C | ratio of Verb POS count to Adjective POS count |
899
+ | 143 | Synta | POSF_ | Part-of-Speech Features | ra_VeNoT_C | ratio of Verb POS count to Noun POS count |
900
+ | 144 | Synta | POSF_ | Part-of-Speech Features | ra_VeAvT_C | ratio of Verb POS count to Adverb POS count |
901
+ | 145 | Synta | POSF_ | Part-of-Speech Features | ra_VeSuT_C | ratio of Verb POS count to Subordinating Conjunction count |
902
+ | 146 | Synta | POSF_ | Part-of-Speech Features | ra_VeCoT_C | ratio of Verb POS count to Coordinating Conjunction count |
903
+ | 147 | Synta | POSF_ | Part-of-Speech Features | to_AjTag_C | total count of Adjective POS tags |
904
+ | 148 | Synta | POSF_ | Part-of-Speech Features | as_AjTag_C | average count of Adjective POS tags per sentence |
905
+ | 149 | Synta | POSF_ | Part-of-Speech Features | at_AjTag_C | average count of Adjective POS tags per token |
906
+ | 150 | Synta | POSF_ | Part-of-Speech Features | ra_AjNoT_C | ratio of Adjective POS count to Noun POS count |
907
+ | 151 | Synta | POSF_ | Part-of-Speech Features | ra_AjVeT_C | ratio of Adjective POS count to Verb POS count |
908
+ | 152 | Synta | POSF_ | Part-of-Speech Features | ra_AjAvT_C | ratio of Adjective POS count to Adverb POS count |
909
+ | 153 | Synta | POSF_ | Part-of-Speech Features | ra_AjSuT_C | ratio of Adjective POS count to Subordinating Conjunction count |
910
+ | 154 | Synta | POSF_ | Part-of-Speech Features | ra_AjCoT_C | ratio of Adjective POS count to Coordinating Conjunction count |
911
+ | 155 | Synta | POSF_ | Part-of-Speech Features | to_AvTag_C | total count of Adverb POS tags |
912
+ | 156 | Synta | POSF_ | Part-of-Speech Features | as_AvTag_C | average count of Adverb POS tags per sentence |
913
+ | 157 | Synta | POSF_ | Part-of-Speech Features | at_AvTag_C | average count of Adverb POS tags per token |
914
+ | 158 | Synta | POSF_ | Part-of-Speech Features | ra_AvAjT_C | ratio of Adverb POS count to Adjective POS count |
915
+ | 159 | Synta | POSF_ | Part-of-Speech Features | ra_AvNoT_C | ratio of Adverb POS count to Noun POS count |
916
+ | 160 | Synta | POSF_ | Part-of-Speech Features | ra_AvVeT_C | ratio of Adverb POS count to Verb POS count |
917
+ | 161 | Synta | POSF_ | Part-of-Speech Features | ra_AvSuT_C | ratio of Adverb POS count to Subordinating Conjunction count |
918
+ | 162 | Synta | POSF_ | Part-of-Speech Features | ra_AvCoT_C | ratio of Adverb POS count to Coordinating Conjunction count |
919
+ | 163 | Synta | POSF_ | Part-of-Speech Features | to_SuTag_C | total count of Subordinating Conjunction POS tags |
920
+ | 164 | Synta | POSF_ | Part-of-Speech Features | as_SuTag_C | average count of Subordinating Conjunction POS tags per sentence |
921
+ | 165 | Synta | POSF_ | Part-of-Speech Features | at_SuTag_C | average count of Subordinating Conjunction POS tags per token |
922
+ | 166 | Synta | POSF_ | Part-of-Speech Features | ra_SuAjT_C | ratio of Subordinating Conjunction POS count to Adjective POS count |
923
+ | 167 | Synta | POSF_ | Part-of-Speech Features | ra_SuNoT_C | ratio of Subordinating Conjunction POS count to Noun POS count |
924
+ | 168 | Synta | POSF_ | Part-of-Speech Features | ra_SuVeT_C | ratio of Subordinating Conjunction POS count to Verb POS count |
925
+ | 169 | Synta | POSF_ | Part-of-Speech Features | ra_SuAvT_C | ratio of Subordinating Conjunction POS count to Adverb POS count |
926
+ | 170 | Synta | POSF_ | Part-of-Speech Features | ra_SuCoT_C | ratio of Subordinating Conjunction POS count to Coordinating Conjunction count |
927
+ | 171 | Synta | POSF_ | Part-of-Speech Features | to_CoTag_C | total count of Coordinating Conjunction POS tags |
928
+ | 172 | Synta | POSF_ | Part-of-Speech Features | as_CoTag_C | average count of Coordinating Conjunction POS tags per sentence |
929
+ | 173 | Synta | POSF_ | Part-of-Speech Features | at_CoTag_C | average count of Coordinating Conjunction POS tags per token |
930
+ | 174 | Synta | POSF_ | Part-of-Speech Features | ra_CoAjT_C | ratio of Coordinating Conjunction POS count to Adjective POS count |
931
+ | 175 | Synta | POSF_ | Part-of-Speech Features | ra_CoNoT_C | ratio of Coordinating Conjunction POS count to Noun POS count |
932
+ | 176 | Synta | POSF_ | Part-of-Speech Features | ra_CoVeT_C | ratio of Coordinating Conjunction POS count to Verb POS count |
933
+ | 177 | Synta | POSF_ | Part-of-Speech Features | ra_CoAvT_C | ratio of Coordinating Conjunction POS count to Adverb POS count |
934
+ | 178 | Synta | POSF_ | Part-of-Speech Features | ra_CoSuT_C | ratio of Coordinating Conjunction POS count to Subordinating Conjunction count |
935
+ | 179 | Synta | POSF_ | Part-of-Speech Features | to_ContW_C | total count of Content words |
936
+ | 180 | Synta | POSF_ | Part-of-Speech Features | as_ContW_C | average count of Content words per sentence |
937
+ | 181 | Synta | POSF_ | Part-of-Speech Features | at_ContW_C | average count of Content words per token |
938
+ | 182 | Synta | POSF_ | Part-of-Speech Features | to_FuncW_C | total count of Function words |
939
+ | 183 | Synta | POSF_ | Part-of-Speech Features | as_FuncW_C | average count of Function words per sentence |
940
+ | 184 | Synta | POSF_ | Part-of-Speech Features | at_FuncW_C | average count of Function words per token |
941
+ | 185 | Synta | POSF_ | Part-of-Speech Features | ra_CoFuW_C | ratio of Content words to Function words |
942
+ | 186 | LxSem | VarF_ | Variation Ratio Features | SimpNoV_S | unique Nouns/total Nouns (Noun Variation-1) |
943
+ | 187 | LxSem | VarF_ | Variation Ratio Features | SquaNoV_S | (unique Nouns**2)/total Nouns (Squared Noun Variation-1) |
944
+ | 188 | LxSem | VarF_ | Variation Ratio Features | CorrNoV_S | unique Nouns/sqrt(2*total Nouns) (Corrected Noun Variation-1) |
945
+ | 189 | LxSem | VarF_ | Variation Ratio Features | SimpVeV_S | unique Verbs/total Verbs (Verb Variation-1) |
946
+ | 190 | LxSem | VarF_ | Variation Ratio Features | SquaVeV_S | (unique Verbs**2)/total Verbs (Squared Verb Variation-1) |
947
+ | 191 | LxSem | VarF_ | Variation Ratio Features | CorrVeV_S | unique Verbs/sqrt(2*total Verbs) (Corrected Verb Variation-1) |
948
+ | 192 | LxSem | VarF_ | Variation Ratio Features | SimpAjV_S | unique Adjectives/total Adjectives (Adjective Variation-1) |
949
+ | 193 | LxSem | VarF_ | Variation Ratio Features | SquaAjV_S | (unique Adjectives**2)/total Adjectives (Squared Adjective Variation-1) |
950
+ | 194 | LxSem | VarF_ | Variation Ratio Features | CorrAjV_S | unique Adjectives/sqrt(2*total Adjectives) (Corrected Adjective Variation-1) |
951
+ | 195 | LxSem | VarF_ | Variation Ratio Features | SimpAvV_S | unique Adverbs/total Adverbs (AdVerb Variation-1) |
952
+ | 196 | LxSem | VarF_ | Variation Ratio Features | SquaAvV_S | (unique Adverbs**2)/total Adverbs (Squared AdVerb Variation-1) |
953
+ | 197 | LxSem | VarF_ | Variation Ratio Features | CorrAvV_S | unique Adverbs/sqrt(2*total Adverbs) (Corrected AdVerb Variation-1) |
954
+ | 198 | LxSem | TTRF_ | Type Token Ratio Features | SimpTTR_S | unique tokens/total tokens (TTR) |
955
+ | 199 | LxSem | TTRF_ | Type Token Ratio Features | CorrTTR_S | unique tokens/sqrt(2*total tokens) (Corrected TTR) |
956
+ | 200 | LxSem | TTRF_ | Type Token Ratio Features | BiLoTTR_S | log(unique tokens)/log(total tokens) (Bi-Logarithmic TTR) |
957
+ | 201 | LxSem | TTRF_ | Type Token Ratio Features | UberTTR_S | (log(unique tokens))^2/log(total tokens/unique tokens) (Uber Index) |
958
+ | 202 | LxSem | TTRF_ | Type Token Ratio Features | MTLDTTR_S | Measure of Textual Lexical Diversity (default TTR = 0.72) |
959
+ | 203 | LxSem | PsyF_ | Psycholinguistic Features | to_AAKuW_C | total AoA (Age of Acquisition) of words |
960
+ | 204 | LxSem | PsyF_ | Psycholinguistic Features | as_AAKuW_C | average AoA of words per sentence |
961
+ | 205 | LxSem | PsyF_ | Psycholinguistic Features | at_AAKuW_C | average AoA of words per token |
962
+ | 206 | LxSem | PsyF_ | Psycholinguistic Features | to_AAKuL_C | total lemmas AoA of lemmas |
963
+ | 207 | LxSem | PsyF_ | Psycholinguistic Features | as_AAKuL_C | average lemmas AoA of lemmas per sentence |
964
+ | 208 | LxSem | PsyF_ | Psycholinguistic Features | at_AAKuL_C | average lemmas AoA of lemmas per token |
965
+ | 209 | LxSem | PsyF_ | Psycholinguistic Features | to_AABiL_C | total lemmas AoA of lemmas, Bird norm |
966
+ | 210 | LxSem | PsyF_ | Psycholinguistic Features | as_AABiL_C | average lemmas AoA of lemmas, Bird norm per sentence |
967
+ | 211 | LxSem | PsyF_ | Psycholinguistic Features | at_AABiL_C | average lemmas AoA of lemmas, Bird norm per token |
968
+ | 212 | LxSem | PsyF_ | Psycholinguistic Features | to_AABrL_C | total lemmas AoA of lemmas, Bristol norm |
969
+ | 213 | LxSem | PsyF_ | Psycholinguistic Features | as_AABrL_C | average lemmas AoA of lemmas, Bristol norm per sentence |
970
+ | 214 | LxSem | PsyF_ | Psycholinguistic Features | at_AABrL_C | average lemmas AoA of lemmas, Bristol norm per token |
971
+ | 215 | LxSem | PsyF_ | Psycholinguistic Features | to_AACoL_C | total AoA of lemmas, Cortese and Khanna norm |
972
+ | 216 | LxSem | PsyF_ | Psycholinguistic Features | as_AACoL_C | average AoA of lemmas, Cortese and Khanna norm per sentence |
973
+ | 217 | LxSem | PsyF_ | Psycholinguistic Features | at_AACoL_C | average AoA of lemmas, Cortese and Khanna norm per token |
974
+ | 218 | LxSem | WorF_ | Word Familiarity | to_SbFrQ_C | total SubtlexUS FREQcount value |
975
+ | 219 | LxSem | WorF_ | Word Familiarity | as_SbFrQ_C | average SubtlexUS FREQcount value per sentenc |
976
+ | 220 | LxSem | WorF_ | Word Familiarity | at_SbFrQ_C | average SubtlexUS FREQcount value per token |
977
+ | 221 | LxSem | WorF_ | Word Familiarity | to_SbCDC_C | total SubtlexUS CDcount value |
978
+ | 222 | LxSem | WorF_ | Word Familiarity | as_SbCDC_C | average SubtlexUS CDcount value per sentence |
979
+ | 223 | LxSem | WorF_ | Word Familiarity | at_SbCDC_C | average SubtlexUS CDcount value per token |
980
+ | 224 | LxSem | WorF_ | Word Familiarity | to_SbFrL_C | total SubtlexUS FREQlow value |
981
+ | 225 | LxSem | WorF_ | Word Familiarity | as_SbFrL_C | average SubtlexUS FREQlow value per sentence |
982
+ | 226 | LxSem | WorF_ | Word Familiarity | at_SbFrL_C | average SubtlexUS FREQlow value per token |
983
+ | 227 | LxSem | WorF_ | Word Familiarity | to_SbCDL_C | total SubtlexUS CDlow value |
984
+ | 228 | LxSem | WorF_ | Word Familiarity | as_SbCDL_C | average SubtlexUS CDlow value per sentence |
985
+ | 229 | LxSem | WorF_ | Word Familiarity | at_SbCDL_C | average SubtlexUS CDlow value per token |
986
+ | 230 | LxSem | WorF_ | Word Familiarity | to_SbSBW_C | total SubtlexUS SUBTLWF value |
987
+ | 231 | LxSem | WorF_ | Word Familiarity | as_SbSBW_C | average SubtlexUS SUBTLWF value per sentence |
988
+ | 232 | LxSem | WorF_ | Word Familiarity | at_SbSBW_C | average SubtlexUS SUBTLWF value per token |
989
+ | 233 | LxSem | WorF_ | Word Familiarity | to_SbL1W_C | total SubtlexUS Lg10WF value |
990
+ | 234 | LxSem | WorF_ | Word Familiarity | as_SbL1W_C | average SubtlexUS Lg10WF value per sentence |
991
+ | 235 | LxSem | WorF_ | Word Familiarity | at_SbL1W_C | average SubtlexUS Lg10WF value per token |
992
+ | 236 | LxSem | WorF_ | Word Familiarity | to_SbSBC_C | total SubtlexUS SUBTLCD value |
993
+ | 237 | LxSem | WorF_ | Word Familiarity | as_SbSBC_C | average SubtlexUS SUBTLCD value per sentence |
994
+ | 238 | LxSem | WorF_ | Word Familiarity | at_SbSBC_C | average SubtlexUS SUBTLCD value per token |
995
+ | 239 | LxSem | WorF_ | Word Familiarity | to_SbL1C_C | total SubtlexUS Lg10CD value |
996
+ | 240 | LxSem | WorF_ | Word Familiarity | as_SbL1C_C | average SubtlexUS Lg10CD value per sentence |
997
+ | 241 | LxSem | WorF_ | Word Familiarity | at_SbL1C_C | average SubtlexUS Lg10CD value per token |
998
+ | 242 | ShaTr | ShaF_ | Shallow Features | TokSenM_S | total count of tokens x total count of sentence |
999
+ | 243 | ShaTr | ShaF_ | Shallow Features | TokSenS_S | sqrt(total count of tokens x total count of sentence) |
1000
+ | 244 | ShaTr | ShaF_ | Shallow Features | TokSenL_S | log(total count of tokens)/log(total count of sentence) |
1001
+ | 245 | ShaTr | ShaF_ | Shallow Features | as_Token_C | average count of tokens per sentence |
1002
+ | 246 | ShaTr | ShaF_ | Shallow Features | as_Sylla_C | average count of syllables per sentence |
1003
+ | 247 | ShaTr | ShaF_ | Shallow Features | at_Sylla_C | average count of syllables per token |
1004
+ | 248 | ShaTr | ShaF_ | Shallow Features | as_Chara_C | average count of characters per sentence |
1005
+ | 249 | ShaTr | ShaF_ | Shallow Features | at_Chara_C | average count of characters per token |
1006
+ | 250 | ShaTr | TraF_ | Traditional Formulas | SmogInd_S | Smog Index |
1007
+ | 251 | ShaTr | TraF_ | Traditional Formulas | ColeLia_S | Coleman Liau Readability Score |
1008
+ | 252 | ShaTr | TraF_ | Traditional Formulas | Gunning_S | Gunning Fog Count Score |
1009
+ | 253 | ShaTr | TraF_ | Traditional Formulas | AutoRea_S | New Automated Readability Index |
1010
+ | 254 | ShaTr | TraF_ | Traditional Formulas | FleschG_S | Flesch Kincaid Grade Level |
1011
+ | 255 | ShaTr | TraF_ | Traditional Formulas | LinseaW_S | Linsear Write Formula Score"""
1012
+
1013
+ lsca_names = lca_names + sca_names
1014
+ name_map = {lsca_names[i]: full_names[i] for i in range(len(lsca_names))}
1015
+
1016
+ type_map = {lingfeat_names[i]: lingfeat_subtypes[i] for i in range(len(lingfeat_names))}
1017
+ type_map.update({n: 'lexical' for n in lca_names})
1018
+ type_map.update({n: 'syntax' for n in sca_names})
1019
+
1020
+
1021
+ # from lingfeat_full_names import lf_names
1022
+ lf_names = lf_names.split('\n')
1023
+
1024
+ lf_names = [tuple(x.split('|')[5:7]) for x in lf_names]
1025
+ lf_map = {k.strip(): v.strip() for k,v in lf_names}
1026
+ name_map.update(lf_map)
1027
+
1028
+ used_indices = [
1029
+ 1, 2, 3, 4, 5, 6, 7, 10, 11, 18, 25, 30, 31, 34, 35, 36, 37, 39, 40, 41, 57,
1030
+ 63, 64, 65, 66, 67, 68, 73, 121, 124, 129, 134, 136, 254,
1031
+ 257, 258, 261, 263, 272, 274
1032
+ ]
1033
+
1034
+ eval_indices = [4,5,6,18,257,272]
1035
+ eval_indices = [used_indices.index(idx) for idx in eval_indices]
1036
+
1037
+ lftk_df = pd.read_csv('lftk_ids.csv')
1038
+
1039
+ lftk_types = {row['key']: row['domain'] for i,row in lftk_df.iterrows()}
1040
+ type_map.update(lftk_types)
1041
+
1042
+ type_map = {k:\
1043
+ 'syntax' if v == 'surface'\
1044
+ else 'lexical' if v == 'lexico-semantics'\
1045
+ else v\
1046
+ for k,v in type_map.items()}
1047
+
1048
+ lftkplus_names = lca_names + sca_names + lftk_names
1049
+ lftkplus_names = [lftkplus_names[i] for i in used_indices]
1050
+
1051
+ lftk_map = {k: v for k,v in zip(lftk_names, lftk_full_names)}
1052
+ name_map.update(lftk_map)
1053
+ rev_name_map = {v: k for k,v in name_map.items()}
demo.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def run_gradio(model, tokenizer, scaler, ling_collection, examples=None, lng_names=None, M=None):
2
+ import numpy as np
3
+ import torch
4
+ from datetime import datetime
5
+ from compute_lng import compute_lng
6
+ import gradio as gr
7
+ m = np.load('assets/m.npy')
8
+ m = -1/m
9
+ m[m == -np.inf] = 0
10
+ m /= 100
11
+ device = model.backbone.device
12
+
13
+ def visibility(mode):
14
+ if mode == 0:
15
+ vis_group = group1
16
+ elif mode == 1:
17
+ vis_group = group2
18
+ elif mode == 2:
19
+ vis_group = group3
20
+
21
+ output = [gr.update(value=''), gr.update(value='')]
22
+ for component in components:
23
+ if component in vis_group:
24
+ output.append(gr.update(visible=True))
25
+ else:
26
+ output.append(gr.update(visible=False))
27
+ return output
28
+
29
+ def generate(sent1, ling):
30
+ input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
31
+ ling1 = scaler.transform([ling['Source']])
32
+ ling2 = scaler.transform([ling['Target']])
33
+ inputs = {'sentence1_input_ids': input_ids,
34
+ 'sentence1_ling': torch.tensor(ling1).float().to(device),
35
+ 'sentence2_ling': torch.tensor(ling2).float().to(device),
36
+ 'sentence1_attention_mask': torch.ones_like(input_ids)}
37
+ preds = []
38
+ with torch.no_grad():
39
+ pred = model.infer(inputs).cpu().numpy()
40
+ pred = tokenizer.batch_decode(pred,
41
+ skip_special_tokens=True)[0]
42
+
43
+ return pred
44
+
45
+ def generate_with_feedbacks(sent1, ling):
46
+ preds = []
47
+ eta = 0.1
48
+ input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
49
+ ling1 = torch.tensor(scaler.transform([ling['Source']])).float().to(device)
50
+ ling2 = torch.tensor(scaler.transform([ling['Target']])).float().to(device)
51
+ ling1_embed = model.ling_embed(ling1)
52
+ ling2_embed = model.ling_embed(ling2)
53
+ cur_ling = ling1_embed + eta * (ling2_embed - ling1_embed)
54
+ inputs = {'sentence1_input_ids': input_ids,
55
+ 'sent1_ling_embed': ling1_embed,
56
+ 'sent2_ling_embed': ling2_embed,
57
+ 'sentence1_attention_mask': torch.ones_like(input_ids)}
58
+ converged = False
59
+ c = 0
60
+ while not converged:
61
+ with torch.no_grad():
62
+ pred = model.infer(inputs)
63
+ inputs_pred = inputs.copy()
64
+ inputs_pred.update({'input_ids': pred,
65
+ 'attention_mask': torch.ones_like(pred)})
66
+ ling_pred = model.ling_disc(**inputs_pred)
67
+ ling_pred_embed = model.ling_embed(ling_pred)
68
+
69
+ if len(interpolations) == 0 or pred != interpolations[-1]:
70
+ interpolations.append(pred)
71
+
72
+ diff = torch.mean((ling2_embed - ling_pred_embed)**2)
73
+ scale = torch.norm(cur_ling)/torch.norm(ling2)
74
+
75
+ # print(f'Diff: {diff.item():.3f} / Scale: ({scale.item():.3f})>> {tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]}')
76
+ if diff < 1e-5 or c >= 50:
77
+ converged = True
78
+ else:
79
+ # cur_ling = cur_ling + eta * (ling2_embed - ling_pred_embed)
80
+ inputs.update({
81
+ 'sentence1_input_ids': pred,
82
+ # 'sent2_ling_embed': ling2_embed,
83
+ 'sentence1_attention_mask': torch.ones_like(pred)
84
+ })
85
+ c += 1
86
+
87
+ pred = tokenizer.batch_decode(pred.cpu().numpy(),
88
+ skip_special_tokens=True)[0]
89
+
90
+ return pred
91
+ def generate_with_feedback(sent1, ling, approx):
92
+ if sent1 == '':
93
+ return ['Please input a source text.', '']
94
+ preds = []
95
+ interpolations = []
96
+ input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
97
+ ling1 = torch.tensor(scaler.transform([ling['Source']])).float().to(device)
98
+ ling2 = torch.tensor(scaler.transform([ling['Target']])).float().to(device)
99
+ ling1_embed = model.ling_embed(ling1)
100
+ ling2_embed = model.ling_embed(ling2)
101
+ inputs = {'sentence1_input_ids': input_ids,
102
+ 'sent1_ling_embed': ling1_embed,
103
+ 'sent2_ling_embed': ling2_embed,
104
+ 'sentence1_attention_mask': torch.ones_like(input_ids)}
105
+ converged = False
106
+ c = 0
107
+ eta = 0.3
108
+ while not converged:
109
+ with torch.no_grad():
110
+ pred = model.infer(inputs)
111
+ inputs_pred = inputs.copy()
112
+ inputs_pred.update({'input_ids': pred,
113
+ 'attention_mask': torch.ones_like(pred)})
114
+ pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
115
+ skip_special_tokens=True)[0]
116
+ if 'approximate' in approx:
117
+ ling_pred = model.ling_disc(**inputs_pred)
118
+ elif 'exact' in approx:
119
+ ling_pred = compute_lng(pred_text)
120
+ ling_pred = scaler.transform([ling_pred])[0]
121
+ ling_pred = torch.tensor(ling_pred).to(pred.device).float()
122
+ else:
123
+ raise ValueError()
124
+ ling_pred_embed = model.ling_embed(ling_pred)
125
+
126
+ if len(interpolations) == 0 or pred_text != interpolations[-1]:
127
+ interpolations.append(pred_text)
128
+
129
+ diff = torch.mean((ling2_embed - ling_pred_embed)**2)
130
+
131
+ # print(f'Diff {diff.item():.3f}>> {tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]}')
132
+ if diff < 10 or c >= 50:
133
+ converged = True
134
+ else:
135
+ ling2_embed = ling2_embed + eta * (ling_pred_embed - ling2_embed)
136
+ inputs.update({'sent2_ling_embed': ling2_embed})
137
+ c += 1
138
+
139
+
140
+ interpolation = '-- ' + '\n-- '.join(interpolations)
141
+ return [pred_text, interpolation]
142
+
143
+ def generate_random(sent1, ling, count, approx):
144
+ preds, interpolations = [], []
145
+ for c in range(count):
146
+ idx = np.random.randint(0, len(ling_collection))
147
+ ling_ex = ling_collection[idx]
148
+ ling['Target'] = ling_ex
149
+ pred, interpolation = generate_with_feedback(sent1, ling, approx)
150
+ preds.append(pred)
151
+ interpolations.append(interpolation)
152
+ return '\n***\n'.join(preds), '\n***\n'.join(interpolations), ling
153
+
154
+ def estimate_gen(sent1, sent2, ling, approx):
155
+ if 'approximate' in approx:
156
+ input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
157
+ with torch.no_grad():
158
+ ling_pred = model.ling_disc(input_ids=input_ids).cpu().numpy()
159
+ ling_pred = scaler.inverse_transform(ling_pred)[0]
160
+ elif 'exact' in approx:
161
+ ling_pred = compute_lng(sent2)
162
+ else:
163
+ raise ValueError()
164
+
165
+ ling['Target'] = ling_pred
166
+ gen = generate_with_feedback(sent1, ling, approx)
167
+ results = gen + [ling]
168
+
169
+ return results
170
+
171
+ def estimate_tgt(sent2, ling, approx):
172
+ if 'approximate' in approx:
173
+ input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
174
+ with torch.no_grad():
175
+ ling_pred = model.ling_disc(input_ids=input_ids).cpu().numpy()
176
+ ling_pred = scaler.inverse_transform(ling_pred)[0]
177
+ elif 'exact' in approx:
178
+ ling_pred = compute_lng(sent2)
179
+ else:
180
+ raise ValueError()
181
+
182
+ ling['Target'] = ling_pred
183
+ return ling
184
+
185
+ def estimate_src(sent1, ling, approx):
186
+ if 'approximate' in approx:
187
+ input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
188
+ with torch.no_grad():
189
+ ling_pred = model.ling_disc(input_ids=input_ids).cpu().numpy()
190
+ ling_pred = scaler.inverse_transform(ling_pred)[0]
191
+ elif 'exact' in approx:
192
+ ling_pred = compute_lng(sent1)
193
+ else:
194
+ raise ValueError()
195
+
196
+ ling['Source'] = ling_pred
197
+ return ling
198
+
199
+ def rand_target(ling):
200
+ ling['Target'] = scaler.inverse_transform([np.random.randn(*ling['Target'].shape)])[0]
201
+ return ling
202
+
203
+ def rand_ex_target(ling):
204
+ idx = np.random.randint(0, len(examples))
205
+ ling_ex = examples[idx][1]
206
+ ling['Target'] = ling_ex['Target']
207
+ return ling
208
+
209
+ def copy(ling):
210
+ ling['Target'] = ling['Source']
211
+ return ling
212
+
213
+ def add_noise(ling):
214
+ x = scaler.transform([ling['Target']])
215
+ x += np.random.randn(*ling['Target'].shape)
216
+ x = scaler.inverse_transform(x)[0]
217
+ ling['Target'] = x
218
+ return ling
219
+
220
+ def add(ling):
221
+ x = scaler.transform([ling['Target']])
222
+ x += m
223
+ x = scaler.inverse_transform(x)[0]
224
+ ling['Target'] = x
225
+ return ling
226
+
227
+ def sub(ling):
228
+ x = scaler.transform([ling['Target']])
229
+ x -= m
230
+ x = scaler.inverse_transform(x)[0]
231
+ ling['Target'] = x
232
+ return ling
233
+
234
+ # title = ''
235
+ # for i, model in enumerate(models):
236
+ # if i > 0:
237
+ # title += '\n'
238
+ # title += f"model ({i})\n\tUsing VAE = {model.args.ling_vae}\n\tUsing ICA = {model.args.use_ica}\n\tNumber of features = {model.args.lng_dim if not model.args.use_ica else model.args.n_ica}"
239
+ title = """
240
+ # LingConv: A System for Controlled Linguistic Conversion
241
+
242
+ ## Description
243
+
244
+ This system is an encoder-decoder model for complexity controlled text generation, guided by 241
245
+ linguistic complexity indices as key attributes. Given a sentence and a desired level of linguistic
246
+ complexity, the model can generate diverse paraphrases that maintain consistent meaning, adjusted for
247
+ different linguistic complexity levels. However, it's important to note that not all index combinations are
248
+ feasible (such as requesting a sentence of "length" 5 with 10 "unique words"). To ensure high quality
249
+ outputs, our approach interpolates the embedding of linguistic indices to locate the most closely matched,
250
+ achievable set of indices for the given target.
251
+ """
252
+
253
+ guide = """
254
+ You may use the system in on of the following ways:
255
+
256
+ **Randomized Paraphrase Generation**: Select this option to produce multiple paraphrases with a range
257
+ of linguistic complexity. You need to provide a source text, specify the number of paraphrases you want,
258
+ and click "Generate." The linguistic complexity of the paraphrases will be determined randomly.
259
+
260
+ **Complexity-Matched Paraphrasing**: Select this option to generate a paraphrase of the given source
261
+ sentence that closely mirrors the linguistic complexity of another given sentence. Input your source
262
+ sentence along with another sentence (which will serve only to extract linguistic indices for the
263
+ paraphrase generation). Then, click "Generate."
264
+
265
+ **Manual Linguistic Control**: Select this option to manually control the linguistic complexity of the
266
+ generated text. We provided a set of tools for manual adjustments of the desired linguistic complexity of
267
+ the target sentence. These tools enable the user to extract linguistic indices from a given sentence,
268
+ generate a random (yet coherent) set of linguistic indices, and add or remove noise from the indices.
269
+ These tools are designed for experimental use and require the user to possess linguistic expertise for
270
+ effective input of linguistic indices. To use these tools, select "Tools to assist in setting linguistic
271
+ indices." Once indices are entered, click "Generate."
272
+
273
+
274
+ Second, you may select to use exact or approximate computation of linguistic indices (used in mode (2) and
275
+ in quality control of the genration). Approximate computation is significantly faster.
276
+
277
+ Third, you may view the intermediate sentences of the quality control process by selecting the checkbox.
278
+
279
+ Fourth, you may try out some examples by clicking on "Examples...". Examples consist of a source sentences,
280
+ the indices of the source sentences, and a sample set of target linguistic indices.
281
+
282
+ Please make your choice below.
283
+
284
+ """
285
+
286
+ sent1 = gr.Textbox(label='Source text')
287
+ ling = gr.Dataframe(value = [[x, 0, 0] for x in lng_names],
288
+ headers=['Index', 'Source', 'Target'],
289
+ datatype=['str', 'number', 'number'], visible=False)
290
+ css = """
291
+ #guide span.svelte-s1r2yt {font-size: 22px !important;
292
+ font-weight: 600 !important}
293
+ """
294
+ with gr.Blocks(css=css) as demo:
295
+ gr.Markdown(title)
296
+ with gr.Accordion("Quick Start Guide", open=False, elem_id='guide'):
297
+ gr.Markdown(guide)
298
+
299
+ mode = gr.Radio(value='Randomized Paraphrase Generation',
300
+ label='How would you like to use this system?',
301
+ type="index",
302
+ choices=['Randomized Paraphrase Generation',
303
+ 'Complexity-Matched Paraphrasing', 'Manual Linguistic Control'])
304
+ approx = gr.Radio(value='Use approximate computation of linguistic indices (faster)',
305
+ choices=['Use approximate computation of linguistic indices (faster)',
306
+ 'Use exact computation of linguistic indices'], container=False, show_label=False)
307
+ control_interpolation = gr.Checkbox(label='View the intermediate sentences in the interpolation of linguistic indices')
308
+
309
+ with gr.Accordion("Examples...", open=False):
310
+ gr.Examples(examples, [sent1, ling], examples_per_page=4, label=None)
311
+
312
+ with gr.Row():
313
+ sent1.render()
314
+ with gr.Column():
315
+ sent2 = gr.Textbox(label='Generated text')
316
+ interpolation = gr.Textbox(label='Quality control interpolation', visible=False, lines=5)
317
+ #####################
318
+ with gr.Row():
319
+ generate_random_btn = gr.Button("Generate",
320
+ variant='primary', scale=1, visible=True)
321
+ count = gr.Number(label='Number of generated sentences', value=3, precision=0, scale=1, visible=True)
322
+ # generate_fb_btn = gr.Button("Generate with auto-adjust (towards pred)")
323
+ # generate_fb_s_btn = gr.Button("Generate with auto-adjust (moving s)")
324
+ # add_noise_btn = gr.Button('Add noise to target linguistic indices')
325
+ #####################
326
+ with gr.Row():
327
+ estimate_gen_btn = gr.Button("Generate",
328
+ variant='primary',
329
+ scale=1, visible=False)
330
+ sent_ling_gen = gr.Textbox(label='Text to estimate linguistic indices', scale=1, visible=False)
331
+ #####################
332
+ generate_btn = gr.Button("Generate", variant='primary', visible=False)
333
+ with gr.Accordion("Tools to assist in the setting of linguistic indices...", open=False, visible=False) as ling_tools:
334
+ with gr.Row():
335
+ estimate_tgt_btn = gr.Button("Estimate linguistic indices of this sentence", visible=False)
336
+ sent_ling_est = gr.Textbox(label='Text to estimate linguistic indices', scale=2, visible=False)
337
+ estimate_src_btn = gr.Button("Estimate linguistic indices of source sentence", visible=False)
338
+ # rand_btn = gr.Button("Random target")
339
+ rand_ex_btn = gr.Button("Random target", size='lg', visible=False)
340
+ copy_btn = gr.Button("Copy linguistic indices of source to target", size='sm', visible=False)
341
+ with gr.Row():
342
+ add_btn = gr.Button('Add \u03B5 to target linguistic indices', visible=False)
343
+ sub_btn = gr.Button('Subtract \u03B5 from target linguistic indices', visible=False)
344
+ ling.render()
345
+ #####################
346
+
347
+ estimate_src_btn.click(estimate_src, inputs=[sent1, ling, approx], outputs=[ling])
348
+ estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling_est, ling, approx], outputs=[ling])
349
+ # estimate_tgt_btn.click(estimate_tgt, inputs=[sent_ling, ling], outputs=[ling])
350
+ estimate_gen_btn.click(estimate_gen, inputs=[sent1, sent_ling_gen, ling, approx], outputs=[sent2, interpolation, ling])
351
+ # rand_btn.click(rand_target, inputs=[ling], outputs=[ling])
352
+ rand_ex_btn.click(rand_ex_target, inputs=[ling], outputs=[ling])
353
+ copy_btn.click(copy, inputs=[ling], outputs=[ling])
354
+ generate_btn.click(generate_with_feedback, inputs=[sent1, ling, approx], outputs=[sent2, interpolation])
355
+ generate_random_btn.click(generate_random, inputs=[sent1, ling, count, approx],
356
+ outputs=[sent2, interpolation, ling])
357
+ # generate_fb_btn.click(generate_with_feedback, inputs=[sent1, ling], outputs=sent2s)
358
+ # generate_fb_s_btn.click(generate_with_feedbacks, inputs=[sent1, ling], outputs=sent2s)
359
+ add_btn.click(add, inputs=[ling], outputs=[ling])
360
+ sub_btn.click(sub, inputs=[ling], outputs=[ling])
361
+ # add_noise_btn.click(add_noise, inputs=[ling], outputs=[ling])
362
+
363
+ group1 = [generate_random_btn, count]
364
+ group2 = [estimate_gen_btn, sent_ling_gen]
365
+ group3 = [generate_btn, estimate_src_btn, estimate_tgt_btn, sent_ling_est, rand_ex_btn, copy_btn, add_btn, sub_btn, ling, ling_tools]
366
+ components = group1 + group2 + group3
367
+ mode.change(visibility, inputs=[mode], outputs=[sent2, interpolation] + components)
368
+ control_interpolation.change(lambda v: gr.update(visible=v), inputs=[control_interpolation],
369
+ outputs=[interpolation])
370
+
371
+ demo.launch(share=True)
model.py ADDED
@@ -0,0 +1,696 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from torch import nn
6
+ from transformers import T5ForConditionalGeneration, T5EncoderModel, AutoModel, LogitsProcessor, LogitsProcessorList
7
+ from functools import partial
8
+ from compute_lng import compute_lng
9
+ from undecorate import unwrap
10
+ from types import MethodType
11
+ from utils import *
12
+ from ling_disc import DebertaReplacedTokenizer
13
+ from const import *
14
+
15
+
16
+
17
+ def vae_sample(mu, logvar):
18
+ std = torch.exp(0.5 * logvar)
19
+ eps = torch.randn_like(std)
20
+ return eps * std + mu
21
+
22
+ class VAE(nn.Module):
23
+ def __init__(self, args):
24
+ super().__init__()
25
+ self.encoder = nn.Sequential(
26
+ nn.Linear(args.input_dim, args.hidden_dim),
27
+ nn.ReLU(),
28
+ nn.Linear(args.hidden_dim, args.hidden_dim),
29
+ nn.ReLU(),
30
+ )
31
+ self.decoder = nn.Sequential(
32
+ nn.Linear(args.latent_dim, args.hidden_dim),
33
+ nn.ReLU(),
34
+ nn.Linear(args.hidden_dim, args.hidden_dim),
35
+ nn.ReLU(),
36
+ nn.Linear(args.hidden_dim, args.input_dim),
37
+ )
38
+ self.fc_mu = nn.Linear(args.hidden_dim, args.latent_dim)
39
+ self.fc_var = nn.Linear(args.hidden_dim, args.latent_dim)
40
+
41
+ def forward(self, x):
42
+ h = self.encoder(x)
43
+ mu = self.fc_mu(h)
44
+ logvar = self.fc_var(h)
45
+ x = vae_sample(mu, logvar)
46
+ o = self.decoder(x)
47
+ return o, (mu, logvar)
48
+
49
+ class LingGenerator(nn.Module):
50
+ def __init__(self, args, hidden_dim=1000):
51
+ super().__init__()
52
+
53
+ self.gen = T5EncoderModel.from_pretrained('google/flan-t5-small')
54
+ self.hidden_size = self.gen.config.d_model
55
+ self.ling_embed = nn.Linear(args.lng_dim, self.hidden_size)
56
+ # self.gen = nn.Sequential(
57
+ # nn.Linear(args.lng_dim, 2*hidden_dim),
58
+ # nn.ReLU(),
59
+ # nn.BatchNorm1d(2*hidden_dim),
60
+ # nn.Linear(2*hidden_dim, 2*hidden_dim),
61
+ # nn.ReLU(),
62
+ # nn.BatchNorm1d(2*hidden_dim),
63
+ # nn.Linear(2*hidden_dim, hidden_dim),
64
+ # nn.ReLU(),
65
+ # )
66
+
67
+ self.gen_type = args.linggen_type
68
+ self.gen_input = args.linggen_input
69
+ if self.gen_type == 'vae':
70
+ self.gen_mu = nn.Linear(hidden_dim, args.lng_dim)
71
+ self.gen_logvar = nn.Linear(hidden_dim, args.lng_dim)
72
+ elif self.gen_type == 'det':
73
+ self.projection = nn.Linear(self.hidden_size, args.lng_dim)
74
+
75
+ def forward(self, batch):
76
+ inputs_embeds = self.gen.shared(batch['sentence1_input_ids'])
77
+ inputs_att_mask = batch['sentence1_attention_mask']
78
+ bs = inputs_embeds.shape[0]
79
+
80
+ if self.gen_input == 's+l':
81
+ sent1_ling = self.ling_embed(batch['sentence1_ling'])
82
+ sent1_ling = sent1_ling.view(bs, 1, -1)
83
+ inputs_embeds = inputs_embeds + sent1_ling
84
+
85
+ gen = self.gen(inputs_embeds=inputs_embeds,
86
+ attention_mask=inputs_att_mask).last_hidden_state.mean(1)
87
+ # gen = self.gen(batch['sentence1_ling'])
88
+
89
+ cache = {}
90
+ if self.gen_type == 'vae':
91
+ mu = self.gen_mu(gen)
92
+ logvar = self.gen_logvar(gen)
93
+ output = vae_sample(mu, logvar)
94
+ cache['linggen_mu'] = mu
95
+ cache['linggen_logvar'] = logvar
96
+ elif self.gen_type == 'det':
97
+ output = self.projection(gen)
98
+
99
+ return output, cache
100
+
101
+
102
+ class LingDisc(nn.Module):
103
+ def __init__(self,
104
+ model_name,
105
+ disc_type,
106
+ disc_ckpt,
107
+ lng_dim=40,
108
+ quant_nbins=1,
109
+ disc_lng_dim=None,
110
+ lng_ids=None,
111
+ **kwargs):
112
+ super().__init__()
113
+ if disc_type == 't5':
114
+ self.encoder = T5EncoderModel.from_pretrained(model_name)
115
+ hidden_dim = self.encoder.config.d_model
116
+ self.dropout = nn.Dropout(0.2)
117
+ self.lng_dim = disc_lng_dim if disc_lng_dim else lng_dim
118
+ self.quant = quant_nbins > 1
119
+ self.quant = False
120
+ if self.quant:
121
+ self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim * quant_nbins)
122
+ else:
123
+ self.ling_classifier = nn.Linear(hidden_dim, self.lng_dim)
124
+ lng_ids = torch.tensor(lng_ids) if lng_ids is not None else None
125
+ # from const import used_indices
126
+ # lng_ids = torch.tensor(used_indices)
127
+ self.register_buffer('lng_ids', lng_ids)
128
+ elif disc_type == 'deberta':
129
+ self.encoder= DebertaReplacedTokenizer.from_pretrained(
130
+ pretrained_model_name_or_path=disc_ckpt,
131
+ tok_model_name = model_name,
132
+ problem_type='regression', num_labels=40)
133
+ self.quant = False
134
+
135
+ self.disc_type = disc_type
136
+
137
+ def forward(self, **batch):
138
+ if not 'attention_mask' in batch:
139
+ if 'input_ids' in batch:
140
+ att_mask = torch.ones_like(batch['input_ids'])
141
+ else:
142
+ att_mask = torch.ones_like(batch['logits'])[:,:,0]
143
+ else:
144
+ att_mask = batch['attention_mask']
145
+ if 'input_ids' in batch:
146
+ enc_output = self.encoder(input_ids=batch['input_ids'],
147
+ attention_mask=att_mask)
148
+ elif 'logits' in batch:
149
+ logits = batch['logits']
150
+ scores = F.softmax(logits, dim = -1)
151
+ onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
152
+ onehot_ = scores - scores.detach() + onehot
153
+
154
+ embed_layer = self.encoder.get_input_embeddings()
155
+ if isinstance(embed_layer, nn.Sequential):
156
+ for i, module in enumerate(embed_layer):
157
+ if i == 0:
158
+ embeds = torch.matmul(onehot_, module.weight)
159
+ else:
160
+ embeds = module(embeds)
161
+ else:
162
+ embeds = onehot_ @ embed_layer.weight
163
+ embeds = torch.matmul(onehot_, embed_layer.weight)
164
+
165
+ enc_output = self.encoder(inputs_embeds=embeds,
166
+ attention_mask=att_mask)
167
+ if self.disc_type == 't5':
168
+ sent_emb = self.dropout(enc_output.last_hidden_state.mean(1))
169
+ bs = sent_emb.shape[0]
170
+ output = self.ling_classifier(sent_emb)
171
+ if self.quant:
172
+ output = output.reshape(bs, -1, self.lng_dim)
173
+ if self.lng_ids is not None:
174
+ output = torch.index_select(output, 1, self.lng_ids)
175
+ elif self.disc_type == 'deberta':
176
+ output = enc_output.logits
177
+ return output
178
+
179
+ class SemEmb(nn.Module):
180
+ def __init__(self, backbone, sep_token_id):
181
+ super().__init__()
182
+ self.backbone = backbone
183
+ self.sep_token_id = sep_token_id
184
+ hidden_dim = self.backbone.config.d_model
185
+ self.projection = nn.Sequential(nn.ReLU(),
186
+ nn.Dropout(0.2),
187
+ nn.Linear(hidden_dim, 1))
188
+
189
+ def forward(self, **batch):
190
+ bs = batch['sentence1_attention_mask'].shape[0]
191
+ ones = torch.ones((bs, 1), device=batch['sentence1_attention_mask'].device)
192
+ sep = torch.ones((bs, 1), dtype=torch.long,
193
+ device=batch['sentence1_attention_mask'].device) * self.sep_token_id
194
+ att_mask = torch.cat([batch['sentence1_attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
195
+ if 'logits' in batch:
196
+ input_ids = torch.cat([batch['sentence1_input_ids'], sep], dim=1)
197
+ embeds1 = self.backbone.shared(input_ids)
198
+
199
+ logits = batch['logits']
200
+ scores = F.softmax(logits, dim = -1)
201
+ onehot = F.one_hot(logits.argmax(-1), num_classes=logits.shape[2]).float().to(logits.device)
202
+ onehot_ = scores - scores.detach() + onehot
203
+
204
+ embeds2 = onehot_ @ self.backbone.shared.weight
205
+ embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
206
+ hidden_units = self.backbone(inputs_embeds=embeds1_2,
207
+ attention_mask=att_mask).last_hidden_state.mean(1)
208
+ elif 'sentence2_input_ids' in batch:
209
+ input_ids = torch.cat([batch['sentence1_input_ids'], sep, batch['sentence2_input_ids']], dim=1)
210
+ hidden_units = self.backbone(input_ids=input_ids,
211
+ attention_mask=att_mask).last_hidden_state.mean(1)
212
+ probs = self.projection(hidden_units)
213
+ return probs
214
+
215
+ def prepare_inputs_for_generation(
216
+ combine_method,
217
+ ling2_only,
218
+ self,
219
+ input_ids,
220
+ past_key_values=None,
221
+ attention_mask=None,
222
+ head_mask=None,
223
+ decoder_head_mask=None,
224
+ cross_attn_head_mask=None,
225
+ use_cache=None,
226
+ encoder_outputs=None,
227
+ sent1_ling=None,
228
+ sent2_ling=None,
229
+ **kwargs
230
+ ):
231
+
232
+ # cut decoder_input_ids if past is used
233
+ if past_key_values is not None:
234
+ input_ids = input_ids[:, -1:]
235
+
236
+ input_ids = input_ids.clone()
237
+ decoder_inputs_embeds = self.shared(input_ids)
238
+
239
+ if combine_method == 'decoder_add_first':
240
+ sent2_ling = torch.cat([sent2_ling,
241
+ torch.repeat_interleave(torch.zeros_like(sent2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
242
+ if combine_method == 'decoder_concat':
243
+ if ling2_only:
244
+ decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
245
+ else:
246
+ decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
247
+ elif combine_method == 'decoder_add'or (past_key_values is None and combine_method == 'decoder_add_first'):
248
+ if ling2_only:
249
+ decoder_inputs_embeds = decoder_inputs_embeds + sent2_ling
250
+ else:
251
+ decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
252
+
253
+ return {
254
+ "decoder_inputs_embeds": decoder_inputs_embeds,
255
+ "past_key_values": past_key_values,
256
+ "encoder_outputs": encoder_outputs,
257
+ "attention_mask": attention_mask,
258
+ "head_mask": head_mask,
259
+ "decoder_head_mask": decoder_head_mask,
260
+ "cross_attn_head_mask": cross_attn_head_mask,
261
+ "use_cache": use_cache,
262
+ }
263
+
264
+ class LogitsAdd(LogitsProcessor):
265
+ def __init__(self, sent2_ling):
266
+ super().__init__()
267
+ self.sent2_ling = sent2_ling
268
+
269
+ def __call__(self, input_ids, scores):
270
+ return scores + self.sent2_ling
271
+
272
+ class EncoderDecoderVAE(nn.Module):
273
+ def __init__(self, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
274
+ super().__init__()
275
+ self.backbone = T5ForConditionalGeneration.from_pretrained(args.model_name)
276
+ self.backbone.prepare_inputs_for_generation = types.MethodType(
277
+ partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
278
+ self.backbone)
279
+ self.args = args
280
+ self.pad_token_id = pad_token_id
281
+ self.eos_token_id = sepeos_token_id
282
+ hidden_dim = self.backbone.config.d_model if not 'logits' in args.combine_method else vocab_size
283
+ if args.combine_method == 'fusion1':
284
+ self.fusion = nn.Sequential(
285
+ nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
286
+ )
287
+ elif args.combine_method == 'fusion2':
288
+ self.fusion = nn.Sequential(
289
+ nn.Linear(hidden_dim + 2 * args.lng_dim, hidden_dim),
290
+ nn.ReLU(),
291
+ nn.Linear(hidden_dim, hidden_dim),
292
+ )
293
+ elif 'concat' in args.combine_method or 'add' in args.combine_method:
294
+ if args.ling_embed_type == 'two-layer':
295
+ self.ling_embed = nn.Sequential(
296
+ nn.Linear(args.lng_dim, args.lng_dim),
297
+ nn.ReLU(),
298
+ nn.Linear(args.lng_dim, hidden_dim),
299
+ )
300
+ else:
301
+ self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
302
+ self.ling_dropout = nn.Dropout(args.ling_dropout)
303
+
304
+ if args.ling_vae:
305
+ self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
306
+ self.ling_logvar = nn.Linear(hidden_dim, hidden_dim)
307
+ nn.init.xavier_uniform_(self.ling_embed.weight)
308
+ nn.init.xavier_uniform_(self.ling_mu.weight)
309
+ nn.init.xavier_uniform_(self.ling_logvar.weight)
310
+
311
+
312
+ generate_with_grad = unwrap(self.backbone.generate)
313
+ self.backbone.generate_with_grad = MethodType(generate_with_grad, self.backbone)
314
+
315
+ def get_fusion_layer(self):
316
+ if 'fusion' in self.args.combine_method:
317
+ return self.fusion
318
+ elif 'concat' in self.args.combine_method or 'add' in self.args.combine_method:
319
+ return self.ling_embed
320
+ else:
321
+ return None
322
+
323
+ def sample(self, mu, logvar):
324
+ std = torch.exp(0.5 * logvar)
325
+ return mu + std * torch.randn_like(std)
326
+
327
+ def encode(self, batch):
328
+ if 'inputs_embeds' in batch:
329
+ inputs_embeds = batch['inputs_embeds']
330
+ else:
331
+ inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
332
+ inputs_att_mask = batch['sentence1_attention_mask']
333
+ bs = inputs_embeds.shape[0]
334
+ cache = {}
335
+ if self.args.combine_method in ('input_concat', 'input_add'):
336
+ if 'sent1_ling_embed' in batch:
337
+ sent1_ling = batch['sent1_ling_embed']
338
+ else:
339
+ sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
340
+ if 'sent2_ling_embed' in batch:
341
+ sent2_ling = batch['sent2_ling_embed']
342
+ else:
343
+ sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
344
+ if self.args.ling_vae:
345
+ sent1_ling = F.leaky_relu(sent1_ling)
346
+ sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
347
+ sent1_ling = self.sample(sent1_mu, sent1_logvar)
348
+
349
+ sent2_ling = F.leaky_relu(sent2_ling)
350
+ sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
351
+ sent2_ling = self.sample(sent2_mu, sent2_logvar)
352
+ cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
353
+ 'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
354
+ 'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
355
+ else:
356
+ cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
357
+ sent1_ling = sent1_ling.view(bs, 1, -1)
358
+ sent2_ling = sent2_ling.view(bs, 1, -1)
359
+ if self.args.combine_method == 'input_concat':
360
+ if self.args.ling2_only:
361
+ inputs_embeds = torch.cat([inputs_embeds, sent2_ling], dim=1)
362
+ inputs_att_mask = torch.cat([inputs_att_mask,
363
+ torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
364
+ else:
365
+ inputs_embeds = torch.cat([inputs_embeds, sent1_ling, sent2_ling], dim=1)
366
+ inputs_att_mask = torch.cat([inputs_att_mask,
367
+ torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
368
+ elif self.args.combine_method == 'input_add':
369
+ if self.args.ling2_only:
370
+ inputs_embeds = inputs_embeds + sent2_ling
371
+ else:
372
+ inputs_embeds = inputs_embeds + sent1_ling + sent2_ling
373
+ return self.backbone.encoder(inputs_embeds=inputs_embeds,
374
+ attention_mask=inputs_att_mask), inputs_att_mask, cache
375
+
376
+ def decode(self, batch, enc_output, inputs_att_mask, generate):
377
+ bs = inputs_att_mask.shape[0]
378
+ cache = {}
379
+ if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add', 'logits_add', 'decoder_add_first'):
380
+ if 'sent1_ling_embed' in batch:
381
+ sent1_ling = batch['sent1_ling_embed']
382
+ elif 'sentence1_ling' in batch:
383
+ sent1_ling = self.ling_embed(self.ling_dropout(batch['sentence1_ling']))
384
+ else:
385
+ sent1_ling = None
386
+ if 'sent2_ling_embed' in batch:
387
+ sent2_ling = batch['sent2_ling_embed']
388
+ else:
389
+ sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
390
+ if self.args.ling_vae:
391
+ sent1_ling = F.leaky_relu(sent1_ling)
392
+ sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
393
+ sent1_ling = self.sample(sent1_mu, sent1_logvar)
394
+
395
+ sent2_ling = F.leaky_relu(sent2_ling)
396
+ sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
397
+ sent2_ling = self.sample(sent2_mu, sent2_logvar)
398
+ cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
399
+ 'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
400
+ 'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
401
+ else:
402
+ cache.update({'sent2_ling': sent2_ling})
403
+ if sent1_ling is not None:
404
+ cache.update({'sent1_ling': sent1_ling})
405
+ if sent1_ling is not None:
406
+ sent1_ling = sent1_ling.view(bs, 1, -1)
407
+ sent2_ling = sent2_ling.view(bs, 1, -1)
408
+ if self.args.combine_method == 'decoder_add_first' and not generate:
409
+ sent2_ling = torch.cat([sent2_ling,
410
+ torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
411
+ else:
412
+ sent1_ling, sent2_ling = None, None
413
+
414
+ if self.args.combine_method == 'embed_concat':
415
+ enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
416
+ sent1_ling, sent2_ling], dim=1)
417
+ inputs_att_mask = torch.cat([inputs_att_mask,
418
+ torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
419
+ elif 'fusion' in self.args.combine_method:
420
+ sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
421
+ .expand(-1, enc_output.last_hidden_state.shape[1], -1)
422
+ sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
423
+ .expand(-1, enc_output.last_hidden_state.shape[1], -1)
424
+ if self.args.ling2_only:
425
+ combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
426
+ else:
427
+ combined_embedding = torch.cat([enc_output.last_hidden_state, sent1_ling, sent2_ling], dim=2)
428
+ enc_output.last_hidden_state = self.fusion(combined_embedding)
429
+
430
+ if generate:
431
+ if self.args.combine_method == 'logits_add':
432
+ logits_processor = LogitsProcessorList([LogitsAdd(sent2_ling.view(bs, -1))])
433
+ else:
434
+ logits_processor = LogitsProcessorList()
435
+
436
+ dec_output = self.backbone.generate_with_grad(
437
+ attention_mask=inputs_att_mask,
438
+ encoder_outputs=enc_output,
439
+ sent1_ling=sent1_ling,
440
+ sent2_ling=sent2_ling,
441
+ return_dict_in_generate=True,
442
+ output_scores=True,
443
+ logits_processor = logits_processor,
444
+ # renormalize_logits=True,
445
+ # do_sample=True,
446
+ # top_p=0.8,
447
+ eos_token_id=self.eos_token_id,
448
+ # min_new_tokens=3,
449
+ # repetition_penalty=1.2,
450
+ max_length=self.args.max_length,
451
+ )
452
+ scores = torch.stack(dec_output.scores, 1)
453
+ cache.update({'scores': scores})
454
+ return dec_output.sequences, cache
455
+
456
+ decoder_input_ids = self.backbone._shift_right(batch['sentence2_input_ids'])
457
+ decoder_inputs_embeds = self.backbone.shared(decoder_input_ids)
458
+ decoder_att_mask = batch['sentence2_attention_mask']
459
+ labels = batch['sentence2_input_ids'].clone()
460
+ labels[labels == self.pad_token_id] = -100
461
+
462
+ if self.args.combine_method == 'decoder_concat':
463
+ if self.args.ling2_only:
464
+ decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
465
+ decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
466
+ labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
467
+ labels], dim=1)
468
+ else:
469
+ decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
470
+ decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
471
+ labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
472
+ labels], dim=1)
473
+ elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first' :
474
+ if self.args.ling2_only:
475
+ decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
476
+ else:
477
+ decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
478
+
479
+ dec_output = self.backbone(
480
+ decoder_inputs_embeds=decoder_inputs_embeds,
481
+ decoder_attention_mask=decoder_att_mask,
482
+ encoder_outputs=enc_output,
483
+ attention_mask=inputs_att_mask,
484
+ labels=labels,
485
+ )
486
+ if self.args.combine_method == 'logits_add':
487
+ dec_output.logits = dec_output.logits + self.args.combine_weight * sent2_ling
488
+ vocab_size = dec_output.logits.size(-1)
489
+ dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
490
+ return dec_output, cache
491
+
492
+
493
+ def forward(self, batch, generate=False):
494
+ enc_output, enc_att_mask, cache = self.encode(batch)
495
+ dec_output, cache2 = self.decode(batch, enc_output, enc_att_mask, generate)
496
+ cache.update(cache2)
497
+ return dec_output, enc_output, cache
498
+
499
+ def infer_with_cache(self, batch):
500
+ dec_output, _, cache = self(batch, generate = True)
501
+ return dec_output, cache
502
+
503
+ def infer(self, batch):
504
+ dec_output, _ = self.infer_with_cache(batch)
505
+ return dec_output
506
+
507
+ def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, scaler):
508
+ from torch.autograd import grad
509
+ interpolations = []
510
+ def line_search():
511
+ best_val = None
512
+ best_loss = None
513
+ eta = 1e3
514
+ sem_prob = 1
515
+ patience = 4
516
+ while patience > 0:
517
+ param_ = param - eta * grads
518
+ with torch.no_grad():
519
+ new_loss, pred = get_loss(param_)
520
+ max_len = pred.shape[1]
521
+ lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
522
+ # if lens.item() == 1:
523
+ # patience -= 1
524
+ batch.update({
525
+ 'sentence2_input_ids': pred,
526
+ 'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
527
+ })
528
+ sem_prob = torch.sigmoid(sem_emb(**batch)).item()
529
+ # if sem_prob <= 0.1:
530
+ # patience -= 1
531
+ # f.write(f'[{eta}], [{new_loss.item():.2f}], [{sem_prob:.2f}], {tokenizer.decode(pred[0])}\n')
532
+ # print(f'[{eta}], [{new_loss.item():.2f}], [{sem_prob:.2f}], {tokenizer.decode(pred[0])}\n')
533
+ if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
534
+ return param_
535
+ eta *= 2.25
536
+ patience -= 1
537
+ return False
538
+
539
+ def get_loss(param):
540
+ if self.args.feedback_param == 'l':
541
+ batch.update({'sent2_ling_embed': param})
542
+ elif self.args.feedback_param == 's':
543
+ batch.update({'inputs_embeds': param})
544
+
545
+ if self.args.feedback_param == 'logits':
546
+ logits = param
547
+ pred = param.argmax(-1)
548
+ else:
549
+ pred, cache = self.infer_with_cache(batch)
550
+ logits = cache['scores']
551
+ out = ling_disc(logits = logits)
552
+ probs = F.softmax(out, 1)
553
+ if ling_disc.quant:
554
+ loss = F.cross_entropy(out, batch['sentence2_discr'])
555
+ else:
556
+ loss = F.mse_loss(out, batch['sentence2_ling'])
557
+ return loss, pred
558
+
559
+ if self.args.feedback_param == 'l':
560
+ ling2_embed = self.ling_embed(batch['sentence2_ling'])
561
+ param = torch.nn.Parameter(ling2_embed, requires_grad = True)
562
+ elif self.args.feedback_param == 's':
563
+ inputs_embeds = self.backbone.shared(batch['sentence1_input_ids'])
564
+ param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
565
+ elif self.args.feedback_param == 'logits':
566
+ logits = self.infer_with_cache(batch)[1]['scores']
567
+ param = torch.nn.Parameter(logits, requires_grad = True)
568
+ f = open(self.args.fb_log, 'a') if self.args.fb_log else None
569
+ target_np = batch['sentence2_ling'][0].cpu().numpy()
570
+ while True:
571
+ loss, pred = get_loss(param)
572
+ pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
573
+ skip_special_tokens=True)[0]
574
+ if f:
575
+ # from compute_lng import compute_lng
576
+ # lng_pred = scaler.transform(np.array([compute_lng(pred_text)])[:,used_indices])[0]
577
+ # real_loss = np.mean((lng_pred - target_np)**2)
578
+ # f.write(f'Loss: {loss.item():.2f}\tReal loss:{real_loss:.2f}\t{pred_text}\n')
579
+ f.write(f'*** [{loss.item():.2f}], {pred_text}\n')
580
+ interpolations.append(pred_text)
581
+ if loss < 1:
582
+ break
583
+ self.zero_grad()
584
+ grads = grad(loss, param)[0]
585
+ param = line_search()
586
+ if param is False:
587
+ break
588
+ if f:
589
+ f.write(f'[return] {pred_text}\n\n')
590
+ f.close()
591
+ return pred, [pred_text, interpolations]
592
+
593
+ def infer_with_feedback(self, ling_disc, batch, tokenizer, scaler, approx=False):
594
+ interpolations = []
595
+ converged = False
596
+ c = 0
597
+ eta = 0.3
598
+ use_embed = True
599
+ if use_embed:
600
+ ling1_embed = self.ling_embed(batch['sentence1_ling'])
601
+ ling2_embed = self.ling_embed(batch['sentence2_ling'])
602
+ batch.update({
603
+ 'sent1_ling_embed': ling1_embed,
604
+ 'sent2_ling_embed': ling2_embed,
605
+ })
606
+ else:
607
+ ling2 = batch['sentence2_ling']
608
+ ling2_orig = batch['sentence2_ling'].clone()
609
+ while not converged:
610
+ with torch.no_grad():
611
+ pred = self.infer(batch)
612
+ inputs_pred = batch.copy()
613
+ inputs_pred.update({'input_ids': pred,
614
+ 'attention_mask': torch.ones_like(pred)})
615
+ pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
616
+ skip_special_tokens=True)[0]
617
+ if approx:
618
+ ling_pred = ling_disc(**inputs_pred)
619
+ else:
620
+ ling_pred = compute_lng(pred_text)
621
+ ling_pred = scaler.transform([ling_pred])[0]
622
+ ling_pred = torch.tensor(ling_pred).to(pred.device).float()
623
+ if use_embed:
624
+ ling_pred_embed = self.ling_embed(ling_pred)
625
+ # diff = torch.mean((ling2_embed - ling_pred_embed)**2)
626
+ # else:
627
+ diff = torch.mean((ling2_orig - ling_pred)**2)
628
+
629
+
630
+ # print(f'Diff {diff.item():.3f}>> {tokenizer.batch_decode(pred.cpu().numpy(), skip_special_tokens=True)[0]}')
631
+ if diff < 1e-1 or c == 6:
632
+ converged = True
633
+ elif use_embed:
634
+ ling2_embed = ling2_embed + eta * (ling_pred_embed - ling2_embed)
635
+ batch.update({'sent2_ling_embed': ling2_embed})
636
+ else:
637
+ ling2 = ling2 + eta * (ling_pred - ling2)
638
+ batch.update({'sentence2_ling': ling2})
639
+
640
+ c += 1
641
+
642
+ if len(interpolations) == 0 or pred_text != interpolations[-1]:
643
+ interpolations.append(pred_text)
644
+
645
+ return [pred_text, interpolations]
646
+
647
+ def set_grad(module, state):
648
+ if module is not None:
649
+ for p in module.parameters():
650
+ p.requires_grad = state
651
+
652
+ def set_grad_except(model, name, state):
653
+ for n, p in model.named_parameters():
654
+ if not name in n:
655
+ p.requires_grad = state
656
+
657
+ class SemEmbPipeline():
658
+ def __init__(self,
659
+ ckpt = "/data/mohamed/checkpoints/ling_conversion_sem_emb_best.pt"):
660
+ self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")
661
+ self.model = SemEmb(T5EncoderModel.from_pretrained('google/flan-t5-base'), self.tokenizer.get_vocab()['</s>'])
662
+ state = torch.load(ckpt)
663
+ self.model.load_state_dict(state['model'], strict=False)
664
+ self.model.eval()
665
+ self.model.cuda()
666
+
667
+ def __call__(self, sentence1, sentence2):
668
+ sentence1 = self.tokenizer(sentence1, return_attention_mask = True, return_tensors = 'pt')
669
+ sentence2 = self.tokenizer(sentence2, return_attention_mask = True, return_tensors = 'pt')
670
+ sem_logit = self.model(
671
+ sentence1_input_ids = sentence1.input_ids.cuda(),
672
+ sentence1_attention_mask = sentence1.attention_mask.cuda(),
673
+ sentence2_input_ids = sentence2.input_ids.cuda(),
674
+ sentence2_attention_mask = sentence2.attention_mask.cuda(),
675
+ )
676
+ sem_prob = torch.sigmoid(sem_logit).item()
677
+ return sem_prob
678
+
679
+ class LingDiscPipeline():
680
+ def __init__(self,
681
+ model_name="google/flan-t5-base",
682
+ disc_type='deberta',
683
+ disc_ckpt='/data/mohamed/checkpoints/ling_disc/deberta-v3-small_flan-t5-base_40',
684
+ # disc_type='t5',
685
+ # disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
686
+ ):
687
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name)
688
+ self.model = LingDisc(model_name, disc_type, disc_ckpt)
689
+ self.model.eval()
690
+ self.model.cuda()
691
+
692
+ def __call__(self, sentence):
693
+ inputs = self.tokenizer(sentence, return_tensors = 'pt')
694
+ with torch.no_grad():
695
+ ling_pred = self.model(input_ids=inputs.input_ids.cuda())
696
+ return ling_pred
options.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datetime import datetime
3
+ from const import lca_names, sca_names, lingfeat_names
4
+ import os, json
5
+ from copy import deepcopy
6
+ import numpy as np
7
+
8
+ def parse_args(ckpt=None):
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument('--data_dir', default='/data/mohamed/data')
11
+ parser.add_argument('--data', default='ling_conversion')
12
+ parser.add_argument('--data_sources')
13
+ parser.add_argument('--data_type', default='text')
14
+ parser.add_argument('--aim_repo', default='/data/mohamed/')
15
+ parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
16
+ parser.add_argument('--kld_annealing', default='cyclic')
17
+ parser.add_argument('--lingpred_annealing', default='mono')
18
+ parser.add_argument('--ling_embed_type', default = 'one-layer')
19
+ parser.add_argument('--combine_weight', default=1, type=float)
20
+ parser.add_argument('--alpha_kld', default=1, type=float)
21
+ parser.add_argument('--alpha_lingpred', default=1, type=float)
22
+ parser.add_argument('--alpha_sem', default=1, type=float)
23
+ parser.add_argument('--max_grad_norm', default=10, type=float)
24
+ parser.add_argument('--sem_loss_tao', default=0.5, type=float)
25
+ parser.add_argument('--sem_loss_eps', default=1, type=float)
26
+ parser.add_argument('--ckpt')
27
+ parser.add_argument('--disc_ckpt')
28
+ parser.add_argument('--sem_ckpt')
29
+ parser.add_argument('--lng_ids')
30
+ parser.add_argument('--lng_ids_idx', type=int)
31
+ parser.add_argument('--lng_ids_path', default='/data/mohamed/indices')
32
+ parser.add_argument('--preds_dir', default='/data/mohamed/preds')
33
+ parser.add_argument('--model_name', default="google/flan-t5-base")
34
+ parser.add_argument('--disc_type', default="t5")
35
+ parser.add_argument('--aim_exp', default='ling-conversion')
36
+ parser.add_argument('--sem_loss_type', default='dedicated')
37
+ parser.add_argument('--combine_method', default='none')
38
+ parser.add_argument('--train_log', type=int, default=200)
39
+ parser.add_argument('--val_log', type=int, default=2000)
40
+ parser.add_argument('--batch_size', type=int, default=64)
41
+ parser.add_argument('--eval_batch_size', type=int, default=32)
42
+ parser.add_argument('--max_eval_samples', type=int, default=1000)
43
+ parser.add_argument('--test_batch_size', type=int, default=1)
44
+ parser.add_argument('--hidden_dim', type=int, default=500)
45
+ parser.add_argument('--latent_dim', type=int, default=150)
46
+ parser.add_argument('--lng_dim', type=int, default=40)
47
+ parser.add_argument('--disc_lng_dim', type=int)
48
+ parser.add_argument('--use_lora', action='store_true')
49
+ parser.add_argument('--lora_r', type=int, default=64)
50
+ parser.add_argument('--gpu', type=str, default='0')
51
+ parser.add_argument('--epochs', type=int, default=10)
52
+ parser.add_argument('--grad_accumulation', type=int, default=1)
53
+ parser.add_argument('--n_ica', type=int, default=10)
54
+ parser.add_argument('--max_length', type=int, default=200)
55
+ parser.add_argument('--total_steps', type=int)
56
+ parser.add_argument('--kld_const', type=float, default=1)
57
+ parser.add_argument('--lr', type=float, default=1e-4)
58
+ parser.add_argument('--kl_weight', type=float, default=1e-1)
59
+ parser.add_argument('--weight_decay', type=float, default=1e-2)
60
+ parser.add_argument('--ling_dropout', type=float, default=0.1)
61
+ parser.add_argument('--predict_fn', default = 'logs/test.txt')
62
+ parser.add_argument('--save_predict', action='store_true')
63
+ parser.add_argument('--use_ica', action='store_true')
64
+ parser.add_argument('--pretrain_gen', action='store_true')
65
+ parser.add_argument('--pretrain_sem', action='store_true')
66
+ parser.add_argument('--pretrain_disc', action='store_true')
67
+ parser.add_argument('--linggen_type', default='none')
68
+ parser.add_argument('--linggen_input', default='s+l')
69
+ parser.add_argument('--aug_same', action='store_true')
70
+ parser.add_argument('--ling_vae', action='store_true')
71
+ parser.add_argument('--process_lingpred', action='store_true')
72
+ parser.add_argument('--fudge_lambda', type=float, default=1.0)
73
+ parser.add_argument('--use_lingpred', action='store_true')
74
+ parser.add_argument('--ling2_only', action='store_true')
75
+ parser.add_argument('--cycle_loss', action='store_true')
76
+ parser.add_argument('--disc_loss', action='store_true')
77
+ parser.add_argument('--sem_loss', action='store_true')
78
+ parser.add_argument('--sim_loss', action='store_true')
79
+ parser.add_argument('--optuna', action='store_true')
80
+ parser.add_argument('--debug', action='store_true')
81
+ parser.add_argument('--demo', action='store_true')
82
+ parser.add_argument('--fudge', action='store_true')
83
+ parser.add_argument('--fb_log', default='feedback_logs/default.txt')
84
+ parser.add_argument('--eval_only', action='store_true')
85
+ parser.add_argument('--predict_with_feedback', action='store_true')
86
+ parser.add_argument('--feedback_param', default = 's')
87
+ parser.add_argument('--eval_ling', action='store_true')
88
+ parser.add_argument('--seed', type=int, default=0)
89
+ parser.add_argument('--major_arg', default = 0, type=int)
90
+ parser.add_argument('--quantize_lng', action='store_true')
91
+ parser.add_argument('--quant_nbins', type=int, default=20)
92
+ parser.add_argument('--src_lng', default = 'ling')
93
+ parser.add_argument('--to_restore', nargs='+', default=[])
94
+ # args = parser.parse_args()
95
+ args, unknown = parser.parse_known_args()
96
+ args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
97
+
98
+ major_arg = args.major_arg
99
+ to_restore = [
100
+ 'total_steps','major_arg','gpu','demo', 'eval_only', 'save_predict', 'predict_fn', 'fudge', 'predict_with_feedback',
101
+ 'feedback_param', 'fb_log', 'data_dir', 'data', 'disc_ckpt', 'disc_type', 'sem_ckpt', 'fudge_lambda', 'test_batch_size', 'src_lng'
102
+ ] + args.to_restore
103
+ to_restore = {k: args.__dict__[k] for k in to_restore}
104
+
105
+ if not args.disc_loss or args.disc_ckpt:
106
+ args.disc_steps = 0
107
+
108
+ if args.data_sources is not None:
109
+ args.data_sources = args.data_sources.split(',')
110
+
111
+ if ckpt is not None:
112
+ args.ckpt = ckpt
113
+
114
+ args_list = [args]
115
+ if args.ckpt:
116
+ if ',' in args.ckpt:
117
+ ckpts = args.ckpt.split(',')
118
+ args_list = [deepcopy(args) for _ in range(len(ckpts))]
119
+ for i in range(len(ckpts)):
120
+ args_path = ckpts[i].replace('_best', '').replace('.pt', '.json')
121
+ with open(args_path) as f:
122
+ args_list[i].__dict__.update(json.load(f))
123
+ args_list[i].__dict__.update(to_restore)
124
+ args_list[i].ckpt = ckpts[i]
125
+ else:
126
+ args_path = args.ckpt.replace('_best', '').replace('.pt', '.json')
127
+ ckpt = args.ckpt
128
+ with open(args_path) as f:
129
+ args.__dict__.update(json.load(f))
130
+ args.__dict__.update(to_restore)
131
+ args.ckpt = ckpt
132
+
133
+ lng_names = lca_names + sca_names + lingfeat_names
134
+ for i in range(len(args_list)):
135
+ if args_list[i].lng_ids or args_list[i].lng_ids_idx:
136
+ if args_list[i].lng_ids_idx:
137
+ lng_ids = np.load(os.path.join(args_list[i].lng_ids_path, f'{args_list[i].lng_ids_idx}.npy'))
138
+ elif args_list[i].lng_ids[0].isnumeric():
139
+ lng_ids = [int(x) for x in args_list[i].lng_ids.split(',')]
140
+ elif ',' in args_list[i].lng_ids:
141
+ lng_ids = [lng_names.index(x) for x in args_list[i].lng_ids.split(',')]
142
+ else:
143
+ lng_ids = np.load(args_list[i].lng_ids)
144
+ args_list[i].lng_dim = len(lng_ids)
145
+ args_list[i].lng_ids = lng_ids.tolist()
146
+ # lng_names = [lng_names[i] for i in lng_ids]
147
+ elif args_list[i].use_ica:
148
+ args_list[i].lng_dim = args_list[i].n_ica
149
+ if args_list[i].disc_lng_dim is None:
150
+ args_list[i].disc_lng_dim = args_list[i].lng_dim
151
+
152
+ if not args.ckpt and not args.eval_only:
153
+ args_path = os.path.join(args.ckpt_dir, '%s.json'%args.name)
154
+ with open(args_path, 'w') as f:
155
+ s = json.dumps(args.__dict__)
156
+ f.write(s)
157
+
158
+ return args_list[major_arg], args_list, lng_names
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ joblib
4
+ gensim
5
+ supar
6
+ transformers
7
+ scikit-learn
8
+ tqdm
9
+ spacy
10
+ sentencepiece