James Kelly committed
Commit
8155451
1 Parent(s): 23d06eb

Cloned most of ctmatch into this Spaces repo... it will have to handle the data too, we'll see. Using the ctmatch requirements.txt

app.py ADDED
@@ -0,0 +1,27 @@
+
+from ctmatch.match import CTMatch, PipeConfig
+import gradio as gr
+
+
+pipe_config = PipeConfig(
+    classifier_model_checkpoint='semaj83/scibert_finetuned_pruned_ctmatch',
+    ir_setup=True,
+    filters=["svm", "classifier"],
+)
+
+CTM = CTMatch(pipe_config)
+
+
+def ctmatch_web_api(topic_query: str) -> str:
+    return '\n\n'.join([f"{nid}: {txt}" for nid, txt in CTM.match_pipeline(topic_query, top_k=5)])
+
+
+if __name__ == "__main__":
+
+    with gr.Blocks(css=".gradio-container {background-color: #00CED1}") as demo:
+        name = gr.Textbox(lines=5, label="patient description", placeholder="Patient is a 45-year-old man with a history of anaplastic astrocytoma...")
+        output = gr.Textbox(lines=10, label="matching trials")
+        greet_btn = gr.Button("match")
+        greet_btn.click(fn=ctmatch_web_api, inputs=name, outputs=output, api_name="match")
+
+    demo.queue().launch(share=True, debug=True)
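
Note: once the Space is running, the "match" endpoint wired up above can also be called programmatically. A minimal sketch, assuming the gradio_client package is installed; the Space id used here is a placeholder, not something fixed by this commit:

    # hypothetical client-side call to the /match endpoint defined in app.py
    from gradio_client import Client

    client = Client("semaj83/ctmatch")  # placeholder Space id
    result = client.predict(
        "Patient is a 45-year-old man with a history of anaplastic astrocytoma...",  # patient description
        api_name="/match",
    )
    print(result)  # newline-separated "NCTID: trial text" pairs (top 5 by default)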
ctmatch/__init__.py ADDED
File without changes
ctmatch/ct_data_paths.py ADDED
@@ -0,0 +1,92 @@
+
+from typing import List, Tuple
+
+TREC_REL_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_21_judgments.txt"
+KZ_REL_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/qrels-clinical_trials.txt"
+
+TREC_RELLED_TOPIC_PATH = "/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec21_topics.jsonl"
+KZ_RELLED_TOPIC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_topics.jsonl'
+
+KZ_DOC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/clinicaltrials.gov-16_dec_2015.zip'
+KZ_PROCESSED_DOC_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_docs.jsonl'
+
+TREC_ML_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_data.jsonl'
+KZ_ML_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/kz_data.jsonl'
+
+
+def get_data_tuples(trec_or_kz: str = 'trec') -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
+    if trec_or_kz == 'trec':
+        return get_trec_doc_data_tuples(), get_trec_topic_data_tuples()
+    return get_kz_doc_data_tuples(), get_kz_topic_data_tuples()
+
+
+
+# --------------------------------------------------------------------------------------------------------------- #
+# data from TREC clinical track 2021 & 2022
+# --------------------------------------------------------------------------------------------------------------- #
+
+
+def get_trec_doc_data_tuples() -> List[Tuple[str, str]]:
+    trec22_pt1_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part1.zip'
+    trec_pt1_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part1.jsonl'
+
+    trec22_pt2_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part2.zip'
+    trec_pt2_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part2.jsonl'
+
+    trec22_pt3_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part3.zip'
+    trec_pt3_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part3.jsonl'
+
+    trec22_pt4_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part4.zip'
+    trec_pt4_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part4.jsonl'
+
+    trec22_pt5_docs = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_docs_21/ClinicalTrials.2021-04-27.part5.zip'
+    trec_pt5_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part5.jsonl'
+
+    trec_doc_data_tuples = [
+        (trec22_pt1_docs, trec_pt1_target),
+        (trec22_pt2_docs, trec_pt2_target),
+        (trec22_pt3_docs, trec_pt3_target),
+        (trec22_pt4_docs, trec_pt4_target),
+        (trec22_pt5_docs, trec_pt5_target)
+    ]
+
+    return trec_doc_data_tuples
+
+
+def get_trec_topic_data_tuples() -> List[Tuple[str, str]]:
+    trec21_topic_path = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_21_topics.xml'
+    trec21_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec21_topics.jsonl'
+    trec22_topic_path = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/trec_22_topics.xml'
+    trec22_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_topics.jsonl'
+
+    trec_topic_data_tuples = [
+        (trec21_topic_path, trec21_topic_target),
+        (trec22_topic_path, trec22_topic_target)
+    ]
+    return trec_topic_data_tuples
+
+
+
+
+# --------------------------------------------------------------------------------------------------------------- #
+# data from Koopman & Zuccon (2016)
+# --------------------------------------------------------------------------------------------------------------- #
+def get_kz_doc_data_tuples() -> List[Tuple[str, str]]:
+    # kz_doc_data_tuples = []
+    # for i in range(1, 18):
+    #     kz_doc_path = f'/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/kz_doc_splits/kz_doc_split{i}.zip'
+    #     kz_doc_target = f'/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_doc_split{i}.jsonl'
+    #     kz_doc_data_tuples.append((kz_doc_path, kz_doc_target))
+    kz_docs = KZ_DOC_PATH
+    kz_docs_target = KZ_PROCESSED_DOC_PATH
+    return [(kz_docs, kz_docs_target)]
+
+    # return kz_doc_data_tuples
+
+def get_kz_topic_data_tuples() -> List[Tuple[str, str]]:
+    kz_topic_desc_path = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/topics-2014_2015-description.topics'
+    kz_topic_target = '/Users/jameskelly/Documents/cp/ctmatch/data/kz_data/processed_kz_data/processed_kz_topics.jsonl'
+    kz_topic_data_tuples = [
+        (kz_topic_desc_path, kz_topic_target)
+    ]
+    return kz_topic_data_tuples
ctmatch/ctmatch_prep.py ADDED
@@ -0,0 +1,351 @@
+
+from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
+import ct_data_paths as ctpaths
+import numpy as np
+import random
+import json
+
+from proc import CTConfig, CTProc, CTDocument, CTTopic
+from scripts.vis_scripts import analyze_test_rels
+from ctproc_ctmatch_utils import get_processed_data, truncate
+import ctproc_eda as eda
+
+LLM_END_PROMPT: str = "Relevance score (0, 1, or 2) : [CLS] "
+
+class DataConfig(NamedTuple):
+    save_path: str
+    trec_or_kz: str = 'trec'
+    filtered_topic_keys: Set[str] = {'id', 'text_sents', 'age', 'gender'}
+    filtered_doc_keys: Set[str] = {'id', 'elig_min_age', 'elig_max_age', 'elig_gender', 'condition', 'elig_crit'}
+    max_topic_len: Optional[int] = None
+    max_inc_len: Optional[int] = None
+    max_exc_len: Optional[int] = None
+    prepend_elig_age: bool = True
+    prepend_elig_gender: bool = True
+    include_only: bool = False
+    downsample_zeros_n: Optional[int] = None
+    sep: str = '[SEP]'
+    llm_prep: bool = False
+    first_n_only: Optional[int] = None
+    convert_snli: bool = False
+    infer_category_model: Optional[str] = None
+
+
+
+
+def proc_docs_and_topics(trec_or_kz: str = 'trec') -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:
+
+    doc_tuples, topic_tuples = ctpaths.get_data_tuples(trec_or_kz)
+
+    id2topic = dict()
+    for topic_source, topic_target in topic_tuples:
+        id2topic.update(proc_topics(topic_source, topic_target, trec_or_kz=trec_or_kz))
+        print(f"processed {trec_or_kz} topic source: {topic_source}, and wrote to {topic_target}")
+
+    id2doc = dict()
+    for doc_source, doc_target in doc_tuples:
+        id2doc.update(proc_docs(doc_source, doc_target))
+        print(f"processed {trec_or_kz} doc source: {doc_source}, and wrote to {doc_target}")
+
+
+    return id2topic, id2doc
+
+
+
+
+
+def proc_docs(doc_path: str, output_path: str) -> Dict[str, CTDocument]:
+
+    ct_config = CTConfig(
+        data_path=doc_path,
+        write_file=output_path,
+        nlp=True
+    )
+
+    cp = CTProc(ct_config)
+    id2doc = {res.id : res for res in cp.process_data()}
+    return id2doc
+
+
+
+def proc_topics(topic_path: str, output_path: str, trec_or_kz: str = 'trec') -> Dict[str, CTTopic]:
+
+    ct_config = CTConfig(
+        data_path=topic_path,
+        write_file=output_path,
+        nlp=True,
+        is_topic=True,
+        trec_or_kz=trec_or_kz
+    )
+
+    cp = CTProc(ct_config)
+    id2topic = {res.id : res for res in cp.process_data()}
+    return id2topic
+
+
+
+def filter_doc_for_ir(doc, dconfig) -> Dict[str, str]:
+    new_doc = dict()
+    new_doc['id'] = doc['id']
+    new_doc['text'] = prep_doc_text(doc, dconfig)
+    return new_doc
+
+
+def prep_ir_dataset(dconfig: DataConfig):
+    # need a file of all docs with their
+    # 1. ids,
+    # 2. combined text...
+    # 3.
+
+    # get path to processed docs
+    doc_tuples, _ = ctpaths.get_data_tuples(dconfig.trec_or_kz)
+
+    # get all processed docs
+    id2doc = dict()
+    for _, processed_doc_path in doc_tuples:
+        print(f"getting docs from {processed_doc_path}")
+        for doc in get_processed_data(processed_doc_path):
+            doc = filter_doc_for_ir(doc, dconfig)
+            doc['category'] = np.asarray([doc['category'][k] for k in sorted(doc['category'])])  # makes a consistently ordered category vector
+            id2doc[doc['id']] = doc
+    return id2doc
+
+
+# --------------------------------------------------------------------------------------------------------------- #
+# pre-processing functions to save a form of triples for a particular model spec
+# --------------------------------------------------------------------------------------------------------------- #
+
+def prep_fine_tuning_dataset(
+    dconfig: DataConfig
+) -> None:
+    """
+    trec_or_kz: 'trec' or 'kz'
+    desc:       create dict of triplets of topic, doc, relevancy scores,
+                save into a single jsonl file
+    """
+    print(f"trec_or_kz: {dconfig.trec_or_kz}")
+    topic_path, rel_path = get_topic_and_rel_path(dconfig.trec_or_kz)
+
+
+    # get set of all relevant doc ids
+    rel_type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(rel_path)
+
+    # get path to processed docs (already got topic path)
+    doc_tuples, _ = ctpaths.get_data_tuples(dconfig.trec_or_kz)
+
+    # get mappings of doc ids to doc dicts and topic ids to topic dicts
+    id2doc, id2topic = get_doc_and_topic_mappings(all_qrelled_docs, doc_tuples, topic_path)
+    print(len(id2doc), len(all_qrelled_docs))
+
+    missing_docs = set()
+    skipped = 0
+
+    # save combined triples of doc, topic, relevancy score
+    with open(dconfig.save_path, 'w') as f:
+        print(f"saving to: {dconfig.save_path}")
+
+        for topic_id in rel_dict:
+            for doc_id in rel_dict[topic_id]:
+                label = rel_dict[topic_id][doc_id]
+                if downsample_zero(label, rel_type_dict['0'], dconfig):
+                    skipped += 1
+                    continue
+
+                if doc_id in id2doc:
+                    combined = create_combined_doc(
+                        id2doc[doc_id],
+                        id2topic[topic_id],
+                        label,
+                        dconfig=dconfig,
+                    )
+
+                    # save to file as jsonl
+                    f.write(json.dumps(combined))
+                    f.write('\n')
+                else:
+                    missing_docs.add(doc_id)
+
+
+    print(f"number of docs missing: {len(missing_docs)}, number of zeros skipped: {skipped}")
+    for md in missing_docs:
+        print(md)
+
+
+
+def create_combined_doc(
+    doc, topic,
+    rel_score,
+    dconfig: DataConfig,
+):
+    combined = dict()
+
+    # get filtered and truncated and SEP tokenized topic text
+    combined['topic'] = prep_topic_text(topic, dconfig)
+
+    # get filtered and truncated and SEP tokenized doc text
+    combined['doc'] = prep_doc_text(doc, dconfig)
+
+    # get relevancy score as string
+    if dconfig.convert_snli:
+        rel_score = convert_label_snli(rel_score)
+
+    combined['label'] = str(rel_score)
+
+    return combined
+
+
+def convert_label_snli(label: int) -> int:
+    if label == 2:
+        return 1
+    elif label == 1:
+        return 2
+    return label
+
+
+
+def downsample_zero(label: int, zero_ct: int, dconfig: DataConfig) -> bool:
+    if dconfig.downsample_zeros_n is not None:
+        if (label == 0) and (random.random() > (dconfig.downsample_zeros_n / zero_ct)):
+            return True
+    return False
+
+
+def prep_topic_text(topic: Dict[str, Union[List[str], str, float]], dconfig: DataConfig) -> str:
+    topic_text = ' '.join(topic['text_sents'])
+    topic_text = truncate(topic_text, dconfig.max_topic_len)
+    return topic_text
+
+
+def get_n_crit(crit_list: List[str], dconfig: DataConfig) -> List[str]:
+    if dconfig.first_n_only is not None:
+        crit_list = crit_list[:min(len(crit_list), dconfig.first_n_only)]
+    return crit_list
+
+
+def prep_doc_text(doc: Dict[str, Union[List[str], str, float]], dconfig: DataConfig) -> str:
+
+    # combine lists of strings into single string
+    doc_inc = ' '.join(get_n_crit(doc['elig_crit']['include_criteria'], dconfig))
+    doc_exc = ' '.join(get_n_crit(doc['elig_crit']['exclude_criteria'], dconfig))
+
+
+    if 'condition' in dconfig.filtered_doc_keys:
+        doc_inc = f"{' '.join(doc['condition'])} {doc_inc}"
+        if dconfig.llm_prep:
+            doc_inc = "Condition: " + doc_inc + ", "
+
+    # truncate criteria separately if in config
+    doc_inc = truncate(doc_inc, dconfig.max_inc_len)
+    doc_exc = truncate(doc_exc, dconfig.max_exc_len)
+
+
+    if dconfig.prepend_elig_gender:
+        doc_inc = f"{doc['elig_gender']} {dconfig.sep} {doc_inc}"
+        if dconfig.llm_prep:
+            doc_inc = "Gender: " + doc_inc + ", "
+
+    if dconfig.prepend_elig_age:
+        if dconfig.llm_prep:
+            doc_inc = f"Trial Doc: A person who is between {doc['elig_min_age']}-{doc['elig_max_age']} years old who meets the following Inclusion Criteria: {doc_inc}"
+        else:
+            doc_inc = f"eligible ages (years): {doc['elig_min_age']}-{doc['elig_max_age']}, {dconfig.sep} {doc_inc}"
+
+    # combine criteria into single string
+    if dconfig.include_only:
+        if dconfig.llm_prep:
+            doc_inc += LLM_END_PROMPT
+        return doc_inc
+
+    if dconfig.llm_prep:
+        return f"{doc_inc} and does not meet these Exclusion Criteria: {doc_exc} {LLM_END_PROMPT}"
+
+    return f"{doc_inc} {dconfig.sep} {doc_exc}"
+
+
+
+
+# --------------------------------------------------------------------------------------------------------------- #
+# utility functions
+# --------------------------------------------------------------------------------------------------------------- #
+
+def age_match(min_doc_age: float, max_doc_age: float, topic_age: float) -> bool:
+    if topic_age < min_doc_age:
+        return False
+    if topic_age > max_doc_age:
+        return False
+    return True
+
+def gender_match(doc_gender: str, topic_gender: str) -> bool:
+    if doc_gender == 'All':
+        return True
+    if doc_gender == topic_gender:
+        return True
+    return False
+
+
+def get_topic_and_rel_path(trec_or_kz: str = 'trec') -> Tuple[str, str]:
+    if trec_or_kz == 'trec':
+        rel_path = ctpaths.TREC_REL_PATH
+        topic_path = ctpaths.TREC_RELLED_TOPIC_PATH
+    else:
+        rel_path = ctpaths.KZ_REL_PATH
+        topic_path = ctpaths.KZ_RELLED_TOPIC_PATH
+    return topic_path, rel_path
+
+
+def get_doc_and_topic_mappings(all_qrelled_docs: Set[str], doc_tuples: List[Tuple[str, str]], topic_path: str) -> Tuple[Dict[str, Dict[str, str]], Dict[str, Dict[str, str]]]:
+    """
+    desc: get mappings of doc ids to doc dicts and topic ids to topic dicts
+    """
+
+    # get all processed topics
+    id2topic = {t['id']:t for t in get_processed_data(topic_path)}
+
+    # get all processed docs
+    id2doc = dict()
+    for _, processed_doc_path in doc_tuples:
+        print(f"getting docs from {processed_doc_path}")
+        for doc in get_processed_data(processed_doc_path):
+            if doc['id'] in all_qrelled_docs:
+                id2doc[doc['id']] = doc
+
+    return id2doc, id2topic
+
+
+if __name__ == '__main__':
+    # proc_docs_and_topics('kz')
+    # eda.explore_trec_data(part=2, rand_print=0.001)   # select part 1-5 (~70k docs per part)
+    # eda.explore_kz_data(rand_print=0.00001)           # all in one file (~200k docs)
+
+    # example config:
+    # class DataConfig(NamedTuple):
+    #     save_path: str
+    #     trec_or_kz: str = 'trec'
+    #     filtered_topic_keys: Set[str] = {'id', 'text_sents', 'age', 'gender'}
+    #     filtered_doc_keys: Set[str] = {'id', 'elig_min_age', 'elig_max_age', 'elig_gender', 'condition', 'elig_crit'}
+    #     max_topic_len: Optional[int] = None
+    #     max_inc_len: Optional[int] = None
+    #     max_exc_len: Optional[int] = None
+    #     prepend_elig_age: bool = True
+    #     prepend_elig_gender: bool = True
+    #     include_only: bool = False
+    #     downsample_zeros_n: Optional[int] = None
+    #     sep: str = '[SEP]'
+    #     llm_prep: bool = False
+    #     first_n_only: Optional[int] = None
+    #     convert_snli: bool = False
+    #     infer_category_model: Optional[str] = None
+
+    dconfig = DataConfig(
+        trec_or_kz='trec',
+        save_path=ctpaths.TREC_ML_PATH,  # make sure to change this!
+        sep='',
+        first_n_only=10,
+        max_topic_len=200,
+        llm_prep=False,
+        prepend_elig_age=True,
+        prepend_elig_gender=False
+    )
+    prep_fine_tuning_dataset(dconfig)
+    # eda.explore_prepped(ctpaths.TREC_KZ_PATH)
ctmatch/dataprep.py ADDED
@@ -0,0 +1,152 @@
+
+
+
+# external imports
+from datasets import Dataset, load_dataset, ClassLabel, Features, Value
+from transformers import AutoTokenizer
+import pandas as pd
+import numpy as np
+
+# package tools
+from .utils.ctmatch_utils import train_test_val_split, get_processed_data, get_test_rels
+from .pipeconfig import PipeConfig
+
+
+# path to ctmatch dataset on HF hub
+CTMATCH_CLASSIFICATION_DATASET_ROOT = "semaj83/ctmatch_classification"
+CTMATCH_IR_DATASET_ROOT = "semaj83/ctmatch_ir"
+CLASSIFIER_DATA_PATH = "combined_classifier_data.jsonl"
+DOC_TEXTS_PATH = "doc_texts.txt"
+DOC_CATEGORIES_VEC_PATH = "doc_categories.txt"
+DOC_EMBEDDINGS_VEC_PATH = "doc_embeddings.txt"
+INDEX2DOCID_PATH = "index2docid.txt"
+
+
+SUPPORTED_LMS = [
+    'roberta-large', 'cross-encoder/nli-roberta-base',
+    'microsoft/biogpt', 'allenai/scibert_scivocab_uncased',
+    'facebook/bart-large', 'gpt2',
+    'semaj83/scibert_finetuned_ctmatch', 'semaj83/scibert_finetuned_pruned_ctmatch'
+
+]
+
+
+class DataPrep:
+    # multiple 'datasets' need to be prepared for the pipeline
+    # 1. the dataset for the classifier model triplets and a dataframe, ~ 25k rows
+    # 2. the dataset for the category model, every doc ~200k rows
+    # 3. the dataset for the embedding model, every doc < 200k rows
+
+
+
+    def __init__(self, pipe_config: PipeConfig) -> None:
+        self.pipe_config = pipe_config
+        self.classifier_tokenizer = self.get_classifier_tokenizer()
+        self.ct_dataset = None
+        self.ct_train_dataset_df = None
+        self.index2docid = None
+        self.doc_embeddings_df = None
+        self.doc_categories_df = None
+
+        if pipe_config.ir_setup:
+            self.load_ir_data()
+        else:
+            self.load_classifier_data()
+
+
+
+
+    def get_classifier_tokenizer(self):
+        model_checkpoint = self.pipe_config.classifier_model_checkpoint
+        if model_checkpoint not in SUPPORTED_LMS:
+            raise ValueError(f"Model checkpoint {model_checkpoint} not supported. Please use one of {SUPPORTED_LMS}")
+        if 'scibert' in model_checkpoint:
+            tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased', use_fast=True)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(self.pipe_config.classifier_model_checkpoint)
+        if self.pipe_config.classifier_model_checkpoint == 'gpt2':
+            tokenizer.pad_token = tokenizer.eos_token
+        return tokenizer
+
+
+    # ------------------ Classifier Data Loading ------------------ #
+    def load_classifier_data(self) -> Dataset:
+        self.ct_dataset = load_dataset(CTMATCH_CLASSIFICATION_DATASET_ROOT, data_files=CLASSIFIER_DATA_PATH)
+        self.ct_dataset = train_test_val_split(self.ct_dataset, self.pipe_config.splits, self.pipe_config.seed)
+        self.add_features()
+        self.tokenize_dataset()
+        self.ct_dataset = self.ct_dataset.rename_column("label", "labels")
+        # self.ct_dataset = self.ct_dataset.rename_column("topic", "sentence1")
+        # self.ct_dataset = self.ct_dataset.rename_column("doc", "sentence2")
+        self.ct_dataset.set_format(type='torch', columns=['doc', 'labels', 'topic', 'input_ids', 'attention_mask'])
+        if not self.pipe_config.use_trainer:
+            self.ct_dataset = self.ct_dataset.remove_columns(['doc', 'topic'])  # removing labels for next-token prediction...
+
+        self.ct_train_dataset_df = self.ct_dataset['train'].remove_columns(['input_ids', 'attention_mask', 'token_type_ids']).to_pandas()
+
+        return self.ct_dataset
+
+
+    def add_features(self) -> None:
+        if self.pipe_config.convert_snli:
+            names = ['contradiction', 'entailment', 'neutral']
+        else:
+            names = ["not_relevant", "partially_relevant", "relevant"]
+
+        features = Features({
+            'doc': Value(dtype='string', id=None),
+            'label': ClassLabel(names=names),
+            'topic': Value(dtype='string', id=None)
+        })
+        self.ct_dataset["train"] = self.ct_dataset["train"].map(lambda x: x, batched=True, features=features)
+        self.ct_dataset["test"] = self.ct_dataset["test"].map(lambda x: x, batched=True, features=features)
+        self.ct_dataset["validation"] = self.ct_dataset["validation"].map(lambda x: x, batched=True, features=features)
+
+
+    def tokenize_function(self, examples):
+        return self.classifier_tokenizer(
+            examples["topic"], examples["doc"],
+            truncation=self.pipe_config.truncation,
+            padding=self.pipe_config.padding,
+            max_length=self.pipe_config.max_length
+        )
+
+    def tokenize_dataset(self):
+        self.ct_dataset = self.ct_dataset.map(self.tokenize_function, batched=True)
+
+
+    def get_category_data(self, vectorize=True):
+        category_data = dict()
+        sorted_cat_keys = None
+        for cdata in get_processed_data(self.pipe_config.category_path):
+
+            # cdata = {<nct_id>: {cat1: float1, cat2: float2...}}
+            cdata_id, cdata_dict = list(cdata.items())[0]
+            if sorted_cat_keys is None:
+                sorted_cat_keys = sorted(cdata_dict.keys())
+
+            if vectorize:
+                cat_vec = np.asarray([cdata_dict[k] for k in sorted_cat_keys])
+            else:
+                cat_vec = cdata_dict
+
+            category_data[cdata_id] = cat_vec
+        return category_data
+
+
+
+    # ------------------ IR Data Loading ------------------ #
+    def process_ir_data_from_hf(self, ds_path, is_text: bool = False):
+        ds = load_dataset(CTMATCH_IR_DATASET_ROOT, data_files=ds_path)
+        if is_text:
+            return pd.DataFrame(ds['train'])
+
+        arrays = [np.asarray(a['text'].split(','), dtype=float) for a in ds['train']]
+        return pd.DataFrame(arrays)
+
+    def load_ir_data(self) -> None:
+        self.index2docid = self.process_ir_data_from_hf(INDEX2DOCID_PATH, is_text=True)
+        self.doc_embeddings_df = self.process_ir_data_from_hf(DOC_EMBEDDINGS_VEC_PATH)
+        self.doc_categories_df = self.process_ir_data_from_hf(DOC_CATEGORIES_VEC_PATH)
+        self.doc_texts_df = self.process_ir_data_from_hf(DOC_TEXTS_PATH, is_text=True)
+
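
Note: DataPrep can be exercised on its own to pull down the IR artifacts hosted on the HF hub. A minimal sketch, assuming PipeConfig's defaults (PipeConfig lives in ctmatch/pipeconfig.py) already point at a supported classifier checkpoint:

    # load only the IR-side data (index, embeddings, categories, texts)
    from ctmatch.pipeconfig import PipeConfig
    from ctmatch.dataprep import DataPrep

    data = DataPrep(PipeConfig(ir_setup=True))  # skips the classifier-dataset branch
    print(data.index2docid.head())        # row index -> NCTID mapping
    print(data.doc_embeddings_df.shape)   # one embedding vector per document
    print(data.doc_categories_df.shape)   # one 14-dim category vector per document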
ctmatch/eda.py ADDED
@@ -0,0 +1,114 @@
+
+from typing import Dict, NamedTuple, Tuple
+from utils.ctmatch_utils import get_processed_data
+from collections import defaultdict
+import ct_data_paths
+import random
+
+from ctproc.scripts.vis_scripts import (
+    analyze_test_rels
+)
+from ctmatch_prep import age_match, gender_match  # defined in ctmatch_prep.py (import path assumed, flat like the others above)
+
+
+class ExplorePaths(NamedTuple):
+    doc_path: str
+    topic_path: str
+    rel_path: str
+
+
+
+# --------------------------------------------------------------------------------------------------------------- #
+# EDA functions
+# --------------------------------------------------------------------------------------------------------------- #
+
+def explore_kz_data(rand_print: float = 0.001) -> None:
+    kz_data_paths = ExplorePaths(
+        rel_path = ct_data_paths.KZ_REL_PATH,
+        doc_path = ct_data_paths.KZ_PROCESSED_DOC_PATH,
+        topic_path = ct_data_paths.KZ_RELLED_TOPIC_PATH
+    )
+
+    explore_data(kz_data_paths, rand_print=rand_print)
+
+
+def explore_trec_data(part: int = 1, rand_print: float = 0.001) -> None:
+    # post processing analysis
+    trec_data_paths = ExplorePaths(
+        rel_path = ct_data_paths.TREC_REL_PATH,
+        doc_path = f'/Users/jameskelly/Documents/cp/ctmatch/data/trec_data/processed_trec_data/processed_trec22_docs_part{part}.jsonl',
+        topic_path = ct_data_paths.TREC_RELLED_TOPIC_PATH
+    )
+
+    explore_data(trec_data_paths, rand_print=rand_print)
+
+
+
+def explore_data(data_paths: ExplorePaths, rand_print: float) -> None:
+
+    # process relevancy judgements
+    type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(data_paths.rel_path)
+
+    # get processed topics
+    id2topic = {t['id']:t for t in get_processed_data(data_paths.topic_path)}
+    print(f"number of processed topics: {len(id2topic)}")
+
+    # get relevant processed docs
+    id2docs = {doc['id']:doc for doc in get_processed_data(data_paths.doc_path, get_only=all_qrelled_docs)}
+    print(f"number of relevant processed docs: {len(id2docs)}")
+
+    explore_pairs(id2topic, id2docs, rel_dict, max_print=1000, rand_print=rand_print)
+
+
+
+
+
+def explore_pairs(id2topic: Dict[str, Dict[str, str]], id2docs: Dict[str, Dict[str, str]], rel_dict: Dict[str, Dict[str, str]], rand_print: float, max_print: int = 100000) -> None:
+    rel_scores = defaultdict(int)
+    age_mismatches, gender_mismatches = 0, 0
+    for pt_id, topic in id2topic.items():
+        for doc_id in rel_dict[pt_id]:
+            if doc_id in id2docs:
+                rel_score = rel_dict[pt_id][doc_id]
+                rel_scores[rel_score] += 1
+                if rel_score == 2:
+                    age_mismatches, gender_mismatches = check_match(
+                        topic = topic,
+                        doc = id2docs[doc_id],
+                        rel_score = rel_score,
+                        age_mismatches = age_mismatches,
+                        gender_mismatches = gender_mismatches
+                    )
+
+                if random.random() < rand_print:
+                    print_pair(topic, id2docs[doc_id], rel_score, marker='%')
+
+    print(rel_scores.items())
+    print(f"{age_mismatches=}, {gender_mismatches=}")
+
+
+
+
+def check_match(topic: Dict[str, str], doc: Dict[str, str], rel_score: int, age_mismatches: int, gender_mismatches: int) -> Tuple[int, int]:
+    age_matches = age_match(doc['elig_min_age'], doc['elig_max_age'], topic['age'])
+    if not age_matches:
+        # print_pair(topic, doc, rel_score)
+        age_mismatches += 1
+
+    gender_matches = gender_match(doc['elig_gender'], topic['gender'])
+    if not gender_matches:
+        # print_pair(topic, doc, rel_score)
+        gender_mismatches += 1
+
+    return age_mismatches, gender_mismatches
+
+
+
+def print_pair(topic: Dict[str, str], doc: Dict[str, str], rel_score: int, marker: str = '*') -> None:
+    print(marker*200)
+    print(f"topic id: {topic['id']}, nct_id: {doc['id']}, rel score: {rel_score}")
+    print(f"topic info: \nage: {topic['age']}, gender: {topic['gender']}")
+    print(topic['raw_text'])
+    print(f"doc info: gender: {doc['elig_gender']}, min age: {doc['elig_min_age']}, max age: {doc['elig_max_age']}")
+    print(doc['elig_crit']['raw_text'])
+    print(marker*200)
+    print()
ctmatch/evaluator.py ADDED
@@ -0,0 +1,154 @@
+
+import logging
+from typing import List, NamedTuple, Optional, Tuple, Union
+
+from .utils.eval_utils import (
+    calc_first_positive_rank, calc_f1, get_kz_topic2text, get_trec_topic2text
+)
+from .pipeconfig import PipeConfig
+from .match import CTMatch
+from pathlib import Path
+from tqdm import tqdm
+import numpy as np
+
+logger = logging.getLogger(__name__)
+
+
+class EvaluatorConfig(NamedTuple):
+    rel_paths: List[str]
+    trec_topic_path: Optional[Union[Path, str]] = None
+    kz_topic_path: Optional[Union[Path, str]] = None
+    max_topics: int = 200
+    openai_api_key: Optional[str] = None
+    filters: Optional[List[str]] = None
+    sanity_check_ids: Optional[List[str]] = None
+
+
+class Evaluator:
+    def __init__(self, eval_config: EvaluatorConfig) -> None:
+        self.rel_paths: List[str] = eval_config.rel_paths
+        self.trec_topic_path: Union[Path, str] = eval_config.trec_topic_path
+        self.kz_topic_path: Union[Path, str] = eval_config.kz_topic_path
+
+        self.rel_dict: dict = None
+        self.topicid2text: dict = None
+        self.ctm = None
+        self.openai_api_key = eval_config.openai_api_key
+        self.filters = eval_config.filters
+        self.sanity_check_ids = eval_config.sanity_check_ids
+
+        assert self.rel_paths is not None, "paths to relevancy judgments must be set in pipe_config if pipe_config.evaluate=True"
+        assert ((self.trec_topic_path is not None) or (self.kz_topic_path is not None)), "at least one of trec_topic_path or kz_topic_path must be set as pipe_config.evaluate=True"
+
+        self.setup()
+
+        self.max_topics: int = len(self.topicid2text) if eval_config.max_topics is None else min(len(self.topicid2text), eval_config.max_topics)
+
+
+
+    def get_combined_rel_dict(self, rel_paths: List[str]) -> dict:
+        combined_rel_dict = dict()
+        for rel_path in rel_paths:
+            with open(rel_path, 'r') as f:
+                for line in f.readlines():
+                    topic_id, _, doc_id, rel = line.split()
+                    if topic_id not in combined_rel_dict:
+                        combined_rel_dict[topic_id] = dict()
+                    combined_rel_dict[topic_id][doc_id] = int(rel)
+        return combined_rel_dict
+
+    def setup(self):
+        self.rel_dict = self.get_combined_rel_dict(self.rel_paths)
+        self.topicid2text = dict()
+        if self.kz_topic_path is not None:
+            self.topicid2text = get_kz_topic2text(self.kz_topic_path)
+
+        if self.trec_topic_path is not None:
+            self.topicid2text.update(get_trec_topic2text(self.trec_topic_path))
+
+        # loads all remaining needed datasets into memory
+        pipe_config = PipeConfig(
+            openai_api_key=self.openai_api_key,
+            ir_setup=True,
+            filters=self.filters
+        )
+        self.ctm = CTMatch(pipe_config=pipe_config)
+
+
+
+    def evaluate(self):
+        """
+        desc: run the pipeline over every topic and associated labelled set of documents,
+              and compute the mean mrr over all topics (how far down to the first relevant document)
+        """
+        frrs, f1s, fprs = [], [], []
+        for topic_id, topic_text in tqdm(list(self.topicid2text.items())[:self.max_topics]):
+
+            if topic_id not in self.rel_dict:
+                # can't evaluate with no judgments
+                continue
+
+            doc_ids = list(self.rel_dict[topic_id].keys())
+            logger.info(f"number of ranked docs: {len(doc_ids)}")
+            doc_set = self.get_indexes_from_ids(doc_ids)
+
+            # run IR pipeline on set of indexes corresponding to labelled doc_ids
+            ranked_pairs = self.ctm.match_pipeline(topic_text, doc_set=doc_set)
+
+            # get NCTIDs from ranking
+            ranked_ids = [nct_id for nct_id, doc_text in ranked_pairs]
+
+            # calculate metrics
+            fpr, frr = calc_first_positive_rank(ranked_ids, self.rel_dict[topic_id])
+            f1 = calc_f1(ranked_ids, self.rel_dict[topic_id])
+
+            if self.sanity_check_ids is not None and (topic_id in self.sanity_check_ids):
+                self.sanity_check(topic_id, topic_text, ranked_pairs, self.rel_dict[topic_id])
+
+            fprs.append(fpr)
+            frrs.append(frr)
+            f1s.append(f1)
+
+        mean_fpr = sum(fprs)/len(fprs)
+        std_fpr = np.std(fprs)
+        mean_frr = sum(frrs)/len(frrs)
+        std_frr = np.std(frrs)
+        mean_f1 = sum(f1s)/len(f1s)
+        std_f1 = np.std(f1s)
+
+        return {
+            "mean_fpr":mean_fpr, "std_fpr":std_fpr,
+            "mean_frr":mean_frr, "std_frr":std_frr,
+            "mean_f1":mean_f1, "std_f1":std_f1
+        }
+
+
+    def get_indexes_from_ids(self, doc_id_set: List[str]) -> List[int]:
+        """
+        desc: get the indexes of the documents in doc_id_set in the order they appear in the ranking
+        returns: list of indexes
+        """
+        doc_indices = []
+        for doc_id in doc_id_set:
+            index_row = np.where(self.ctm.data.index2docid['text'] == doc_id)
+            if len(index_row[0]) == 0:
+                continue
+            doc_indices.append(index_row[0][0])
+        return doc_indices
+
+    def sanity_check(self, topic_id, topic_text, ranked_pairs: List[Tuple[str, str]], rel_dict) -> None:
+        logger.info(f"{topic_id=} {topic_text}")
+        for doc_id, doc_text in ranked_pairs:
+            rel_score = rel_dict[doc_id]
+            logger.info(f"{rel_score} {doc_id} {doc_text}")
+
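
Note: the Evaluator ties the relevancy judgments to the full IR pipeline. A minimal sketch of driving it, with placeholder paths (the actual qrels and topic files are not part of this repo):

    from ctmatch.evaluator import Evaluator, EvaluatorConfig

    eval_config = EvaluatorConfig(
        rel_paths=["/path/to/trec_21_judgments.txt"],    # placeholder qrels path
        trec_topic_path="/path/to/trec_21_topics.xml",   # placeholder topics path
        max_topics=50,
        filters=["svm", "classifier"],
    )
    metrics = Evaluator(eval_config).evaluate()
    print(metrics)  # mean/std of first-positive rank, its reciprocal, and F1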
ctmatch/match.py ADDED
@@ -0,0 +1,333 @@
+
+import logging
+from typing import Any, Dict, List, Optional, Tuple
+
+
+# external imports
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline
+from numpy.linalg import norm
+from pathlib import Path
+from sklearn import svm
+import numpy as np
+import torch
+import json
+
+
+# package tools
+from .models.classifier_model import ClassifierModel
+from .utils.ctmatch_utils import get_processed_data, exclusive_argmax
+from .models.gen_model import GenModel
+from .pipeconfig import PipeConfig
+from .pipetopic import PipeTopic
+from .dataprep import DataPrep
+
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+CT_CATEGORIES = [
+    "pulmonary", "cardiac", "gastrointestinal", "renal", "psychological", "genetic", "pediatric",
+    "neurological", "cancer", "reproductive", "endocrine", "infection", "healthy", "other"
+]
+
+
+GEN_INIT_PROMPT = "I will give you a patient description and a set of clinical trial documents. Each document will have a NCTID. I would like you to return the set of NCTIDs ranked from most to least relevant for the patient in the description.\n"
+
+
+class CTMatch:
+
+    def __init__(self, pipe_config: Optional[PipeConfig] = None) -> None:
+        # default to model config with full ir setup
+        self.pipe_config = pipe_config if pipe_config is not None else PipeConfig(ir_setup=True)
+        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        self.data = DataPrep(self.pipe_config)
+        self.classifier_model = ClassifierModel(self.pipe_config, self.data, self.device)
+        self.embedding_model = SentenceTransformer(self.pipe_config.embedding_model_checkpoint)
+        self.gen_model = GenModel(self.pipe_config)
+        self.category_model = None
+        self.filters: Optional[List[str]] = self.pipe_config.filters
+
+        # filter params
+        self.sim_top_n = 10000
+        self.svm_top_n = 100
+        self.classifier_top_n = 50
+        self.gen_top_n = 10
+
+
+    # main api method
+    def match_pipeline(self, topic: str, top_k: int = 10, doc_set: Optional[List[int]] = None) -> List[Tuple[str, str]]:
+
+        if doc_set is None:
+            # start off with all doc indexes
+            doc_set = [i for i in range(len(self.data.index2docid))]
+        else:
+            self.reset_filter_params(len(doc_set))
+
+        # get topic representations for pipeline filters
+        pipe_topic = self.get_pipe_topic(topic)
+
+        if self.filters is None or ('sim' in self.filters):
+            # first filter, category + embedding similarity
+            doc_set = self.sim_filter(pipe_topic, doc_set, top_n=self.sim_top_n)
+
+        if self.filters is None or ('svm' in self.filters):
+            # second filter, SVM
+            doc_set = self.svm_filter(pipe_topic, doc_set, top_n=self.svm_top_n)
+
+        if self.filters is None or ('classifier' in self.filters):
+            # third filter, classifier-LM (reranking)
+            doc_set = self.classifier_filter(pipe_topic, doc_set, top_n=self.classifier_top_n)
+
+        if self.filters is None or ('gen' in self.filters):
+            # fourth filter, generative-LM
+            doc_set = self.gen_filter(pipe_topic, doc_set, top_n=top_k)
+
+        return self.get_return_data(doc_set[:min(top_k, len(doc_set))])
+
+
+    def reset_filter_params(self, val: int) -> None:
+        self.sim_top_n = self.svm_top_n = self.classifier_top_n = self.gen_top_n = val
+
+
+    # ------------------------------------------------------------------------------------------ #
+    # filtering methods
+    # ------------------------------------------------------------------------------------------ #
+
+    def sim_filter(self, pipe_topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
+        """
+        filter documents by similarity to topic
+        doing this with loop and cosine similarity instead of linear kernel because of memory issues
+        """
+        logger.info(f"running sim filter on {len(doc_set)} docs")
+
+        topic_cat_vec = exclusive_argmax(pipe_topic.category_vec)
+        norm_topic_emb = norm(pipe_topic.embedding_vec)
+        cosine_dists = []
+        for doc_idx in doc_set:
+            doc_cat_vec = self.redist_other_category(self.data.doc_categories_df.iloc[doc_idx].values)
+
+            # only consider strongest predicted category
+            doc_cat_vec = exclusive_argmax(doc_cat_vec)
+            doc_emb_vec = self.data.doc_embeddings_df.iloc[doc_idx].values
+
+            topic_argmax = np.argmax(topic_cat_vec)
+            doc_argmax = np.argmax(doc_cat_vec)
+            cat_dist = 0. if (topic_argmax == doc_argmax) else 1.
+            emb_dist = np.dot(pipe_topic.embedding_vec, doc_emb_vec) / (norm_topic_emb * norm(doc_emb_vec))
+            combined_dist = cat_dist + emb_dist
+            cosine_dists.append(combined_dist)
+
+        sorted_indices = list(np.argsort(cosine_dists))[:min(len(doc_set), top_n)]
+
+        # return top n doc indices by combined similarity
+        return [doc_set[i] for i in sorted_indices]
+
+
+    def svm_filter(self, topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
+        """
+        filter documents by training an SVM on topic and doc embeddings
+        """
+        logger.info(f"running svm filter on {len(doc_set)} documents")
+
+        # build training data and prediction vector of single positive class for SVM
+        topic_embedding_vec = topic.embedding_vec[np.newaxis, :]
+        x = np.concatenate([topic_embedding_vec, self.data.doc_embeddings_df.iloc[doc_set].values], axis=0)
+        y = np.zeros(len(doc_set) + 1)
+        y[0] = 1
+
+        # define and fit SVM
+        clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
+        clf.fit(x, y)
+
+        # infer for similarities
+        similarities = clf.decision_function(x)
+
+        # get top n doc indices by similarity, biggest to smallest
+        result = list(np.argsort(-similarities)[:min(len(doc_set) + 1, top_n + 1)])
+
+        # remove topic from result
+        result.remove(0)
+
+        # indexes got shifted by 1 because topic was included in doc_set
+        return [doc_set[(r - 1)] for r in result]
+
+
+
+    def classifier_filter(self, pipe_topic: PipeTopic, doc_set: List[int], top_n: int) -> List[int]:
+        """
+        filter documents by the classifier's not-relevant prediction (lowest first)
+        """
+        logger.info(f"running classifier filter on {len(doc_set)} documents")
+
+        # get doc texts
+        doc_texts = [v[0] for v in self.data.doc_texts_df.iloc[doc_set].values]
+
+        # sort by reverse irrelevant prediction
+        neg_predictions = np.asarray([p[0] for p in self.classifier_model.batch_inference(pipe_topic.topic_text, doc_texts, return_preds=True)])
+
+        # return top n doc indices by classifier
+        sorted_indices = list(np.argsort(neg_predictions)[:min(len(doc_set), top_n)])
+        return [doc_set[i] for i in sorted_indices]
+
+
+
+    def gen_filter(self, topic: PipeTopic, doc_set: List[int], top_n: int = 10) -> List[int]:
+        """
+        gen model supplies a ranking of remaining docs by evaluating the pairs of topic and doc texts
+
+        in order to overcome the context length limitation, we need to do a kind of left-binary search over multiple
+        prompts to arrive at a ranking that meets the number of documents requirement (top_n)
+
+        may take a few minutes to run through all queries and subqueries depending on size of doc_set
+
+        """
+        logger.info(f"running gen filter on {len(doc_set)} documents")
+
+        assert top_n > 0, "top_n must be greater than 0"
+
+        ranked_docs = doc_set
+        iters = 0
+        while (len(ranked_docs) > top_n) and (iters < 10) and (len(ranked_docs) // 2 > top_n):
+            query_prompts = self.get_subqueries(topic, ranked_docs)
+
+            logger.info(f"calling gen model on {len(query_prompts)} subqueries")
+
+            # get gen model response for each query_prompt
+            subrankings = []
+            for prompt in query_prompts:
+                subrank = self.gen_model.gen_response(prompt)
+
+                # keep the top half of each subranking
+                subrankings.extend(subrank[:len(subrank) // 2])
+
+            ranked_docs = subrankings
+            iters += 1
+
+        return ranked_docs[:min(len(ranked_docs), top_n)]
+
+    # ------------------------------------------------------------------------------------------ #
+    # filter helper methods
+    # ------------------------------------------------------------------------------------------ #
+
+    def get_pipe_topic(self, topic):
+        pipe_topic = PipeTopic(
+            topic_text=topic,
+            embedding_vec=self.get_embeddings([topic])[0],  # 1 x embedding_dim (default=384)
+            category_vec=self.get_categories(topic)         # 1 x 14
+        )
+        return pipe_topic
+
+
+    def get_embeddings(self, texts: List[str]) -> List[float]:
+        return self.embedding_model.encode(texts)
+
+    def get_categories(self, text: str) -> np.ndarray:
+        if self.category_model is None:
+            self.category_model = pipeline(
+                'zero-shot-classification',
+                model=self.pipe_config.category_model_checkpoint,
+                device=0
+            )
+        output = self.category_model(text, candidate_labels=CT_CATEGORIES)
+        score_dict = {output['labels'][i]:output['scores'][i] for i in range(len(output['labels']))}
+
+        # to be consistent with doc category vecs
+        sorted_keys = sorted(score_dict.keys())
+        return self.redist_other_category(np.array([score_dict[k] for k in sorted_keys]))
+
+    def redist_other_category(self, category_vec: np.ndarray, other_dim: int = 8) -> np.ndarray:
+        """
+        redistribute 'other' category weight to all other categories
+        """
+        other_wt = category_vec[other_dim]
+        other_wt_dist = other_wt / (len(category_vec) - 1)
+        redist_cat_vec = category_vec + other_wt_dist
+        redist_cat_vec[other_dim] = 0
+        return redist_cat_vec
+
+
+    def get_gen_query_prompt(self, topic: PipeTopic, doc_set: List[int]) -> Tuple[str, int]:
+        query_prompt = f"{GEN_INIT_PROMPT}Patient description: {topic.topic_text}\n"
+
+        for i, doc_text in enumerate(self.data.doc_texts_df.iloc[doc_set].values):
+            query_prompt += f"NCTID: {doc_set[i]}, "
+            query_prompt += f"Eligibility Criteria: {doc_text[0]}\n"
+
+            # not really token length bc not tokenized yet but close enough if we undershoot
+            prompt_len = len(query_prompt.split())
+            if prompt_len > self.pipe_config.max_query_length:
+                break
+
+        return query_prompt, i
+
+
+    def get_subqueries(self, topic: PipeTopic, doc_set: List[int]) -> List[str]:
+        query_prompts = []
+        i = 0
+        while i < len(doc_set) - 1:
+
+            # break the querying over remaining doc set into multiple prompts
+            query_prompt, used_i = self.get_gen_query_prompt(topic, doc_set[i:])
+            query_prompts.append(query_prompt)
+            i += used_i
+
+        return query_prompts
+
+
+    def get_return_data(self, doc_set: List[int]) -> List[Tuple[str, str]]:
+        return_data = []
+        for idx in doc_set:
+            nctid = self.data.index2docid.iloc[idx].values[0]
+            return_data.append((nctid, self.data.doc_texts_df.iloc[idx].values[0]))
+        return return_data
+
+
+
+    # ------------------------------------------------------------------------------------------ #
+    # data prep methods that rely on model in CTMatch object (not run during routine program)
+    # ------------------------------------------------------------------------------------------ #
+
+    def prep_ir_text(self, doc: Dict[str, List[str]], max_len: int = 512) -> str:
+        inc_text = ' '.join(doc['elig_crit']['include_criteria'])
+        exc_text = ' '.join(doc['elig_crit']['exclude_criteria'])
+        all_text = f"Inclusion Criteria: {inc_text}, Exclusion Criteria: {exc_text}"
+        split_text = all_text.split()
+        return ' '.join(split_text[:min(max_len, len(split_text))])
+
+
+    def prep_and_save_ir_dataset(self):
+        category_data = self.data.get_category_data()
+        with open(self.pipe_config.ir_save_path, 'w') as wf:
+            for ir_data in self.prep_ir_data():
+                ir_data['categories'] = str(category_data[ir_data['id']])
+                wf.write(json.dumps(ir_data))
+                wf.write('\n')
+
+
+    def prep_ir_data(self):
+        for data_path in self.pipe_config.processed_data_paths:
+            for i, doc in enumerate(get_processed_data(data_path)):
+                if i % 10000 == 0:
+                    logger.info(f"Prepping doc {i}")
+
+                ir_data_entry = dict()
+                ir_data_entry['id'] = doc['id']
+                doc_text = self.prep_ir_text(doc)
+                ir_data_entry['doc_text'] = doc_text
+                yield ir_data_entry
+
+
+    def save_texts(self) -> Dict[int, str]:
+        idx2id = dict()
+        with open(Path(self.pipe_config.ir_save_path).parent / 'texts', 'w', encoding='utf-8') as wf:
+            for i, doc in enumerate(get_processed_data(self.pipe_config.ir_save_path)):
+                idx2id[i] = doc['id']
+                if i % 10000 == 0:
+                    logger.info(f"Prepping doc {i}")
+
+                wf.write(doc['doc_text'])
+                wf.write('\n')
+        return idx2id
+
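
Note: match_pipeline runs the filters as a cascade (sim -> svm -> classifier -> gen), and pipe_config.filters selects which stages actually run. A short sketch mirroring the call made in app.py above:

    from ctmatch.match import CTMatch, PipeConfig

    ctm = CTMatch(PipeConfig(
        classifier_model_checkpoint='semaj83/scibert_finetuned_pruned_ctmatch',
        ir_setup=True,
        filters=["svm", "classifier"],  # skip the 'sim' and 'gen' stages
    ))
    for nct_id, trial_text in ctm.match_pipeline("62-year-old woman with newly diagnosed type 2 diabetes...", top_k=5):
        print(nct_id, trial_text[:80])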
ctmatch/models/classifier_model.py ADDED
@@ -0,0 +1,396 @@
+
+import logging
+from pathlib import Path
+from tqdm.auto import tqdm
+from typing import List, Tuple
+
+from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, get_scheduler
+from optimum.onnxruntime import ORTModelForSequenceClassification
+from optimum.onnxruntime.configuration import OptimizationConfig
+from optimum.onnxruntime import ORTOptimizer
+import evaluate
+
+from sklearn.metrics import confusion_matrix, classification_report
+from sklearn.metrics import f1_score
+from torch.utils.data import DataLoader
+from torch.optim import AdamW
+from torch import nn
+import torch
+
+from nn_pruning.patch_coordinator import ModelPatchingCoordinator, SparseTrainingArguments
+from nn_pruning.inference_model_patcher import optimize_model
+from nn_pruning.sparse_trainer import SparseTrainer
+
+
+from ..pipeconfig import PipeConfig
+from ..dataprep import DataPrep
+
+
+logger = logging.getLogger(__name__)
+
+PRUNED_HUB_MODEL_NAME = 'semaj83/scibert_finetuned_pruned_ctmatch'
+
+
+class WeightedLossTrainer(Trainer):
+    def __init__(self, label_weights, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.label_weights = label_weights
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        outputs = model(**inputs)
+        logits = outputs.get("logits")
+        labels = inputs.get("labels")
+        loss_func = nn.CrossEntropyLoss(weight=self.label_weights)
+        loss = loss_func(logits, labels)
+        return (loss, outputs) if return_outputs else loss
+
+
+
+
+class PruningTrainer(SparseTrainer, WeightedLossTrainer):
+    def __init__(self, sparse_args, *args, **kwargs):
+        WeightedLossTrainer.__init__(self, *args, **kwargs)
+        SparseTrainer.__init__(self, sparse_args)
+
+
+class ClassifierModel:
+
+    def __init__(self, model_config: PipeConfig, data: DataPrep, device: str):
+        self.model_config = model_config
+        self.dataset = data.ct_dataset
+        self.tokenizer = data.classifier_tokenizer
+        self.tokenize_func = data.tokenize_function
+        self.trainer = None
+        self.optimizer = None
+        self.lr_scheduler = None
+        self.device = device
+
+        if not self.model_config.ir_setup:
+            self.train_dataset_df = data.ct_dataset['train'].to_pandas()
+            self.num_training_steps = self.model_config.train_epochs * len(self.dataset['train'])
+
+        self.model = self.load_model()
+        self.pruned_model = None
+
+        if not self.model_config.use_trainer and not self.model_config.ir_setup:
+            self.train_dataloader, self.val_dataloader = self.get_dataloaders()
+
+
+        if self.model_config.prune:
+            self.prune_trainer = None
+            self.sparse_args = self.get_sparse_args()
+            self.mpc = self.get_model_patching_coordinator()
+
+
+    # ------------------ Model Loading ------------------ #
+    def get_model(self):
+        if self.model_config.num_classes == 0:
+            return AutoModelForSequenceClassification.from_pretrained(self.model_config.classifier_model_checkpoint)
+
+        id2label, label2id = self.get_label_mapping()
+        model = AutoModelForSequenceClassification.from_pretrained(
+            self.model_config.classifier_model_checkpoint,
+            num_labels=self.model_config.num_classes,  # makes the last head be replaced with a linear layer with num_labels outputs (fine-tuning)
+            id2label=id2label, label2id=label2id,
+            ignore_mismatched_sizes=True  # because of pruned model changes
+        )
+
+        if 'pruned' in self.model_config.classifier_model_checkpoint:
+            model = optimize_model(model, "dense")
+
+        return self.add_pad_token(model)
+
+
+    def add_pad_token(self, model):
+        if model.config.pad_token_id is None:
+            model.config.pad_token_id = model.config.eos_token_id
+        return model
+
+
+    def load_model(self):
+        self.model = self.get_model()
+
+        if self.model_config.ir_setup:
+            return self.model
+
+        self.optimizer = AdamW(self.model.parameters(), lr=self.model_config.learning_rate, weight_decay=self.model_config.weight_decay)
+        self.num_training_steps = self.model_config.train_epochs * len(self.dataset['train'])
+        self.lr_scheduler = get_scheduler(
+            name="linear",
+            optimizer=self.optimizer,
+            num_warmup_steps=self.model_config.warmup_steps,
+            num_training_steps=self.num_training_steps
+        )
+
+        if self.model_config.use_trainer and not self.model_config.prune:
+            self.trainer = self.get_trainer()
+        else:
+            self.model = self.model.to(self.device)
+
+        return self.model
+
+
+    def get_label_mapping(self):
+        # id2label = {idx:self.dataset['train'].features["labels"].int2str(idx) for idx in range(3)}
+        id2label = {'0':'not_relevant', '1':'partially_relevant', '2':'relevant'}
+        label2id = {v:k for k, v in id2label.items()}
+        return id2label, label2id
+
+    def get_label_weights(self):
+        label_weights = (1 - (self.train_dataset_df["labels"].value_counts().sort_index() / len(self.train_dataset_df))).values
+        label_weights = torch.from_numpy(label_weights).float().to(self.device)
+        return label_weights
+
+
+    def get_trainer(self):
+        return WeightedLossTrainer(
+            model=self.model,
+            optimizers=(self.optimizer, self.lr_scheduler),
+            args=self.get_training_args_obj(),
+            compute_metrics=self.compute_metrics,
+            train_dataset=self.dataset["train"],
+            eval_dataset=self.dataset["validation"],
+            tokenizer=self.tokenizer,
+            label_weights=self.get_label_weights()
+        )
+
+
+    def get_training_args_obj(self):
+        output_dir = self.model_config.output_dir if self.model_config.output_dir is not None else self.model_config.classifier_data_path.parent.parent.as_posix()
+
+        return TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=self.model_config.train_epochs,
+            learning_rate=self.model_config.learning_rate,
+            per_device_train_batch_size=self.model_config.batch_size,
+            per_device_eval_batch_size=self.model_config.batch_size,
+            weight_decay=self.model_config.weight_decay,
+            evaluation_strategy="epoch",
+            logging_steps=len(self.dataset["train"]) // self.model_config.batch_size,
+            fp16=self.model_config.fp16
+        )
+
+
+
+    def train_and_predict(self):
+        if self.trainer is not None:
+            self.trainer.train()
+            predictions = self.trainer.predict(self.dataset["test"])
+            logger.info(predictions.metrics.items())
+        else:
+            self.loss_func = nn.CrossEntropyLoss(weight=self.get_label_weights())
+            self.manual_train()
+            self.manual_eval()
+
+
+
+    # ------------------ native torch training loop ------------------ #
+    def get_dataloaders(self) -> Tuple[DataLoader, DataLoader]:
+        train_dataloader = DataLoader(self.dataset['train'], shuffle=True, batch_size=self.model_config.batch_size)
+        val_dataloader = DataLoader(self.dataset['validation'], batch_size=self.model_config.batch_size)
+        return train_dataloader, val_dataloader
+
+
+
+    # taken from ctmatch for messing about
+    def manual_train(self):
+        progress_bar = tqdm(range(self.num_training_steps))
+        self.model.train()
+        for epoch in range(self.model_config.train_epochs):
+            for batch in tqdm(self.train_dataloader):
+                batch = {k: v.to(self.model.device) for k, v in batch.items()}
+                outputs = self.model(**batch)
+                loss = self.loss_func(outputs.logits, batch['labels'])
+                # total_loss += loss.item()
+                loss.backward()
+
+                self.optimizer.step()
+                self.lr_scheduler.step()
+                self.optimizer.zero_grad()
+
+                self.manual_eval()
+                logger.info(f"{loss=}")
+                progress_bar.update(1)
+
+
+
+
+    def manual_eval(self):
+        metric = evaluate.load("f1")
+        self.model.eval()
+        for batch in self.val_dataloader:
+            batch = {k: v.to(self.model.device) for k, v in batch.items()}
+
+            # don't learn during evaluation
+            with torch.no_grad():
+                outputs = self.model(**batch)
+
+            logits = outputs.logits
+            predictions = torch.argmax(logits, dim=-1)
+            metric.add_batch(predictions=predictions, references=batch["labels"])
+
+        logger.info(metric.compute(average='weighted'))
+
+
+
+
+    def get_sklearn_metrics(self):
+        with torch.no_grad():
+            if self.model_config.use_trainer:
+                if self.model_config.prune:
+                    self.prune_trainer.model.to(self.device)
+                    logger.info("using pruned trainer model")
+                    preds = self.prune_trainer.predict(self.dataset['test']).predictions
+                else:
+                    preds = self.trainer.predict(self.dataset['test']).predictions
+
+                if "bart" in self.model_config.name:
+                    preds = preds[0]
+
+                y_preds = list(preds.argmax(axis=1))
+            else:
+
+                if self.model_config.prune:
+                    model = self.pruned_model.to(self.device)
+                else:
+                    model = self.model.to(self.device)
+                y_preds = []
+                for input_ids in self.dataset['test']['input_ids']:
+                    input_ids = torch.tensor(input_ids).unsqueeze(0).to(self.device)
+                    y_pred = model(input_ids).logits.argmax().item()
+                    y_preds.append(y_pred)
+
+        y_trues = list(self.dataset['test']['labels'])
+        return confusion_matrix(y_trues, y_preds), classification_report(y_trues, y_preds)
+
+
+    def compute_metrics(self, pred):
+        labels = pred.label_ids
+        preds = pred.predictions
+        if "bart" in self.model_config.name:
+            preds = preds[0]
+
+        preds = preds.argmax(-1)
+        f1 = f1_score(labels, preds, average="weighted")
+        return {"f1":f1}
+
+    def inference_single_example(self, topic: str, doc: str, return_preds: bool = False) -> str:
+        """
+        desc: method to predict relevance label on new topic, doc examples
279
+ """
280
+ ex = {'doc':doc, 'topic':topic}
281
+ with torch.no_grad():
282
+ inputs = torch.LongTensor(self.tokenize_func(ex)['input_ids']).unsqueeze(0)
283
+ outputs = self.model(inputs).logits
284
+ if return_preds:
285
+ return torch.nn.functional.softmax(outputs, dim=1).squeeze(0)
286
+ return str(outputs.argmax().item())
287
+
288
+
289
+ def batch_inference(self, topic: str, docs: List[str], return_preds: bool = False) -> List[str]:
290
+ topic_repeats = [topic for _ in range(len(docs))]
291
+ inputs = self.tokenizer(
292
+ topic_repeats, docs, return_tensors='pt',
293
+ truncation=self.model_config.truncation,
294
+ padding=self.model_config.padding,
295
+ max_length=self.model_config.max_length
296
+ )
297
+
298
+ with torch.no_grad():
299
+ outputs = torch.nn.functional.softmax(self.model(**inputs).logits, dim=1)
300
+
301
+ if return_preds:
302
+ return outputs
303
+
304
+ return outputs.argmax(dim=1).tolist()
305
+
306
+
307
+
308
+ # ------------------ pruning ------------------ #
309
+
310
+ def prune_model(self):
311
+ self.mpc.patch_model(self.model)
312
+ self.model.save_pretrained("models/patched")
313
+ self.prune_trainer = self.get_pruning_trainer()
314
+ self.prune_trainer.set_patch_coordinator(self.mpc)
315
+ self.prune_trainer.train()
316
+ self.mpc.compile_model(self.prune_trainer.model)
317
+ if self.model_config.push_to_hub:
318
+ # can't save the optimized model to hub
319
+ self.prune_trainer.model.push_to_hub(PRUNED_HUB_MODEL_NAME)
320
+
321
+ self.pruned_model = optimize_model(self.prune_trainer.model, "dense")
322
+
323
+
324
+
325
+ def get_sparse_args(self):
326
+ sparse_args = SparseTrainingArguments()
327
+
328
+ hyperparams = {
329
+ "dense_pruning_method": "topK:1d_alt",
330
+ "attention_pruning_method": "topK",
331
+ "initial_threshold": 1.0,
332
+ "final_threshold": 0.5,
333
+ "initial_warmup": 1,
334
+ "final_warmup": 3,
335
+ "attention_block_rows":32,
336
+ "attention_block_cols":32,
337
+ "attention_output_with_dense": 0
338
+ }
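+ # (for the topK methods above, the threshold is roughly the fraction of weights kept, so final_threshold=0.5
+ # targets about 50% density; initial_warmup/final_warmup schedule how that threshold is annealed over
+ # training -- see the nn_pruning docs for the exact semantics)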
339
+
340
+ for k,v in hyperparams.items():
341
+ if hasattr(sparse_args, k):
342
+ setattr(sparse_args, k, v)
343
+ else:
344
+ print(f"sparse_args does not have argument {k}")
345
+
346
+ return sparse_args
347
+
348
+
349
+ def get_pruning_trainer(self):
350
+ return PruningTrainer(
351
+ sparse_args=self.sparse_args,
352
+ args=self.get_training_args_obj(),
353
+ model=self.model,
354
+ train_dataset=self.dataset["train"],
355
+ eval_dataset=self.dataset["validation"],
356
+ tokenizer=self.tokenizer,
357
+ compute_metrics=self.compute_metrics,
358
+ label_weights=self.get_label_weights()
359
+ )
360
+
361
+
362
+
363
+ def get_model_patching_coordinator(self):
364
+ return ModelPatchingCoordinator(
365
+ sparse_args=self.sparse_args,
366
+ device=self.device,
367
+ cache_dir="checkpoints",
368
+ logit_names="logits",
369
+ teacher_constructor=None
370
+ )
371
+
372
+
373
+ # onnx optimization
374
+ def optimize_model(self):
375
+ onnx_path = Path("onnx")
376
+ model_id = self.model_config.classifier_model_checkpoint
377
+ #assert self.pruned_model is not None, "pruned model must be loaded before optimizing"
378
+ opt_model = ORTModelForSequenceClassification.from_pretrained(model_id, from_transformers=True)
379
+ optimizer = ORTOptimizer.from_pretrained(opt_model)
380
+ optimization_config = OptimizationConfig(optimization_level=99) # enable all optimizations
381
+ optimizer.optimize(
382
+ save_dir=onnx_path,
383
+ optimization_config=optimization_config,
384
+ )
385
+ opt_model.save_pretrained(onnx_path)
386
+ self.tokenizer.save_pretrained(onnx_path)
387
+
388
+ #optimized_model = ORTModelForSequenceClassification.from_pretrained(onnx_path, file_name="model_optimized.onnx")
389
+
390
+ return opt_model
391
+
392
+
393
+
394
+
395
+
396
+
ctmatch/models/gen_model.py ADDED
@@ -0,0 +1,83 @@
1
+
2
+ from typing import List, Optional
3
+
4
+ from ..pipeconfig import PipeConfig
5
+ import openai
6
+ import re
7
+
8
+
9
+
10
+
11
+ class GenModel:
12
+ def __init__(self, pipe_config: PipeConfig) -> None:
13
+ openai.api_key = pipe_config.openai_api_key
14
+ self.pipe_config = pipe_config
15
+
16
+
17
+ def gen_response(self, query_prompt: str, doc_set: Optional[List[int]] = None) -> List[int]:
18
+ """
19
+ uses openai model to return a ranking of ids
20
+ """
21
+ if self.pipe_config.gen_model_checkpoint == 'text-davinci-003':
22
+ response = openai.Completion.create(
23
+ model=self.pipe_config.gen_model_checkpoint,
24
+ prompt=query_prompt,
25
+ temperature=0,
26
+ max_tokens=200,
27
+ top_p=1,
28
+ frequency_penalty=0.0,
29
+ presence_penalty=0.0
30
+ )
31
+ else:
32
+ assert doc_set is not None, "doc_set must be provided for gpt-3.5-turbo"
33
+
34
+ # for gpt-3.5-turbo
35
+ response = openai.ChatCompletion.create(
36
+ model=self.pipe_config.gen_model_checkpoint,
37
+ messages = [{'role': 'user', 'content' : query_prompt}],
38
+ temperature=0.4,
39
+ max_tokens=200,
40
+ top_p=1,
41
+ frequency_penalty=0.2,
42
+ presence_penalty=0.0
43
+ )
44
+
45
+
46
+ if self.pipe_config.gen_model_checkpoint == 'text-davinci-003':
47
+ return self.post_process_chatgpt_response(response)
48
+ return self.post_process_gptturbo_response(response, doc_set=doc_set)
49
+
50
+
51
+ def post_process_chatgpt_response(self, response):
52
+ """
53
+ could be:
54
+ NCTID 6, NCTID 7, NCTID 5
55
+ NCTID: 6, 7, 5
56
+ 6, 7, 5
57
+ '1. 195155\n2. 186848\n3. 194407'
58
+ """
59
+ response_pattern = r"(?:NCTID\:?\s*)? ?(\d+)(?!\.)"
60
+ text = response['choices'][0]['text']
61
+ return [int(s) for s in re.findall(response_pattern, text)]
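+ # e.g. "NCTID 6, NCTID 7, NCTID 5" -> [6, 7, 5]; the (?!\.) lookahead skips list markers like "1." in
+ # responses such as "1. 195155\n2. 186848", so only the trial ids themselves are kept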
62
+
63
+ def post_process_gptturbo_response(self, response, doc_set: List[int]):
64
+ """
65
+ could be:
66
+ 'The most relevant clinical trial for this patient is ID 2, followed by ID 3. The remaining trials are not relevant for this patient's condition.'
67
+ """
68
+ text = response['choices'][0]['message']['content']
69
+ ranking = []
70
+ for substr in text.split():
71
+ substr = substr.strip('.,;:')  # ids may arrive with trailing punctuation, e.g. "ID 2,"
+ if substr.isdigit():
+ ranking.append(int(substr))
73
+
74
+ # the rest are arbitrarily ranked
75
+ for ncid in doc_set:
76
+ if ncid not in ranking:
77
+ ranking.append(ncid)
78
+ return ranking
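+ # e.g. with doc_set [1, 2, 3, 4] and the docstring text above this yields [2, 3, 1, 4]:
+ # ids mentioned in the response come first, the rest keep their doc_set order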
79
+
80
+
81
+
82
+
83
+
ctmatch/pipeconfig.py ADDED
@@ -0,0 +1,42 @@
1
+
2
+ from typing import Dict, List, NamedTuple, Optional
3
+ from pathlib import Path
4
+
5
+
6
+ class PipeConfig(NamedTuple):
7
+ name: str = 'scibert_finetuned_ctmatch'
8
+ classifier_model_checkpoint: str = 'semaj83/scibert_finetuned_ctmatch'
9
+ max_length: int = 512
10
+ padding: str = True
11
+ truncation: bool = True
12
+ batch_size: int = 16
13
+ learning_rate: float = 2e-5
14
+ train_epochs: int = 3
15
+ weight_decay: float = 0.01
16
+ warmup_steps: int = 500
17
+ seed: int = 42
18
+ splits: Dict[str, float] = {"train":0.8, "val":0.1}
19
+ classifier_data_path: Path = Path("combined_classifier_data.jsonl")
20
+ output_dir: Optional[str] = None
21
+ convert_snli: bool = False
22
+ use_trainer: bool = False
23
+ num_classes: int = 3
24
+ fp16: bool = False
25
+ early_stopping: bool = False
26
+ push_to_hub: bool = False
27
+ ir_save_path: Optional[str] = None
28
+ category_path: Optional[str] = None
29
+ processed_data_paths: Optional[List[str]] = None
30
+ max_query_length: int = 1200
31
+ category_model_checkpoint: str = "facebook/bart-large-mnli"
32
+ embedding_model_checkpoint: str = "sentence-transformers/all-MiniLM-L6-v2"
33
+ gen_model_checkpoint: str = 'text-davinci-003'
34
+ max_gen: int = 100
35
+ openai_api_key: Optional[str] = None
36
+ ir_setup: bool = False # if true, use the IR model setup, no classifier training or dataprep
37
+ filters: Optional[List[str]] = None # if provided, only use these filters for the IR model, options are {'sim', 'svm', 'classifier', 'gen'}
38
+ prune: bool = False # if true, creates a pruned classifier model
39
+ optimize: bool = False # if true, creates an optimized classifier model
40
+
41
+
42
+
ctmatch/pipetopic.py ADDED
@@ -0,0 +1,10 @@
1
+
2
+ from typing import Any, NamedTuple
3
+
4
+ class PipeTopic(NamedTuple):
5
+ topic_text: str
6
+ embedding_vec: Any
7
+ category_vec: Any
8
+
9
+
10
+
ctmatch/scripts/build_combined_data.py ADDED
@@ -0,0 +1,76 @@
1
+
2
+ from typing import Dict, List, Tuple
3
+ import json
4
+
5
+
6
+ COMBINED_CAT_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/combined_categories.jsonl'
7
+ CAT_SAVE_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/doc_categories.txt'
8
+ INDEX2DOCID_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/index2docid.txt'
9
+ INDEX2TOPICID_PATH = '/Users/jameskelly/Documents/cp/ctmatch/data/index2topicid.txt'
10
+
11
+ def load_category_dict(cat_path=COMBINED_CAT_PATH) -> Tuple[List, Dict[str, List[float]]]:
12
+ """
13
+ desc: gets category dict from category path
14
+ """
15
+ sorted_cat_keys = None
16
+
17
+ with open(cat_path, 'r') as json_file:
18
+ json_list = list(json_file)
19
+
20
+ all_cat_dict = {}
21
+ for s in json_list:
22
+ s_data = json.loads(s)
23
+ nct_id, cat_dict = s_data.popitem()
24
+
25
+ if sorted_cat_keys is None:
26
+ sorted_cat_keys = sorted(cat_dict.keys())
27
+
28
+ all_cat_dict[nct_id] = [cat_dict[k] for k in sorted_cat_keys]
29
+
30
+ return sorted_cat_keys, all_cat_dict
31
+
32
+
33
+
34
+ def load_index2id(index2id_path: str = INDEX2DOCID_PATH) -> Dict[str, int]:
35
+ """
36
+ desc: loads id2idx from csv path
37
+ """
38
+ index2id = {}
39
+ with open(index2id_path, 'r') as f:
40
+ for line in f:
41
+ if len(line) < 2:
42
+ continue
43
+ idx, nct_id = line.split(',')
44
+ index2id[idx] = nct_id.strip(' \n')
45
+
46
+ return index2id
47
+
48
+
49
+
50
+ def build_cat_csv(save_path: str = CAT_SAVE_PATH) -> None:
51
+ """
52
+ desc: builds csv file for category data
53
+ VERY important that the indexes (order) match the order of the embeddings (for nctid lookup in idx2id)
54
+ """
55
+ sorted_cat_keys, cat_dict = load_category_dict()
56
+ idx2id = load_index2id()
57
+
58
+ with open(save_path, 'w') as f:
59
+ f.write(','.join(sorted_cat_keys))
60
+ f.write('\n')
61
+ for _, nct_id in idx2id.items():
62
+ cat_vec = cat_dict[nct_id]
63
+ cat_vec_str = ','.join([str(c) for c in cat_vec])
64
+ f.write(cat_vec_str)
65
+ f.write('\n')
66
+
67
+
68
+ if __name__ == '__main__':
69
+ build_cat_csv()
70
+
71
+
72
+
73
+
74
+
75
+
76
+
ctmatch/scripts/gen_categories.py ADDED
@@ -0,0 +1,92 @@
1
+
2
+ from typing import Generator, List, Optional, Tuple
3
+ from ctmatch.utils.ctmatch_utils import get_processed_data
4
+ from ctmatch.ct_data_paths import get_data_tuples
5
+ from transformers import pipeline
6
+ import numpy as np
7
+ import json
8
+
9
+ CAT_GEN_MODEL = "facebook/bart-large-mnli"
10
+ #CAT_GEN_MODEL = "microsoft/biogpt"
11
+
12
+ CT_CATEGORIES = [
13
+ "pulmonary", "cardiac", "gastrointestinal", "renal", "psychological", "genetic", "pediatric",
14
+ "neurological", "cancer", "reproductive", "endocrine", "infection", "healthy", "other"
15
+ ]
16
+
17
+
18
+ # --------------------------------------------------------------------------------------------------------------- #
19
+ # this script is for applying zero-shot classification labels from 'facebook/bart-large-mnli' to the
20
+ # documents of the dataset, including test, because we can assume this is something that is realistic to pre-compute
21
+ # since the documents are available a priori
22
+ # --------------------------------------------------------------------------------------------------------------- #
23
+ GET_ONLY = None
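+ # minimal usage sketch (illustrative only, not part of this script) of what the zero-shot pipeline
+ # produces for a single condition string; the real batched path is gen_categories() below:
+ # pipe = pipeline(model=CAT_GEN_MODEL, device=0)
+ # out = pipe("anaplastic astrocytoma", candidate_labels=CT_CATEGORIES)
+ # out["labels"] and out["scores"] are aligned lists, with labels like "cancer" expected near the top here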
24
+
25
+
26
+ def stream_condition_data(data_chunk, doc_or_topic: str = 'doc') -> Generator[str, None, None]:
27
+ for d in data_chunk:
28
+ if doc_or_topic == 'topic':
29
+ yield d['raw_text']
30
+ else:
31
+ condition = d['condition']
32
+ if len(condition) == 0:
33
+ yield 'no information'
34
+ else:
35
+ yield ' '.join(condition).lower()
36
+
37
+
38
+ def add_condition_category_labels(
39
+ trec_or_kz: str = 'trec',
40
+ model_checkpoint=CAT_GEN_MODEL,
41
+ start: int = 0,
42
+ doc_tuples: Optional[List[Tuple[str, str]]] = None,
43
+ category_label='category',
44
+ doc_or_topic: str = 'doc'
45
+ ) -> None:
46
+ pipe = pipeline(model=model_checkpoint, device=0)
47
+ chunk_size = 1000
48
+
49
+ # open the processed documents and add the category labels
50
+ if doc_tuples is None:
51
+ doc_tuples, _ = get_data_tuples(trec_or_kz=trec_or_kz)
52
+
53
+ for _, target in doc_tuples:
54
+ print(f"reading and writing to: {target}")
55
+ data = [d for d in get_processed_data(target, get_only=GET_ONLY)]
56
+ print(f"got {len(data)} records from {target}...")
57
+
58
+ # overwrite with new records having inferred category feature
59
+ with open('/content/drive/MyDrive/ct_data23/processed_trec_topic_X.jsonl', 'w') as f:
60
+ i = start
61
+ print(f'starting at: {i}')
62
+ while i < len(data):
63
+ next_chunk_end = min(len(data), i+chunk_size)
64
+ conditions = stream_condition_data(data[i:next_chunk_end], doc_or_topic=doc_or_topic)
65
+ categories = gen_categories(pipe, conditions)
66
+ print(f"generated {len(categories)} categories for {chunk_size} conditions...")
67
+ for j in range(i, next_chunk_end):
68
+ data[j][category_label] = categories[j - i]
69
+ f.write(json.dumps(data[j]))
70
+ f.write('\n')
71
+
72
+ if doc_or_topic == 'doc':
73
+ print(f"{i=}, doc condition: {data[i]['condition']}, generated category: {data[i]['category'].items()}")
74
+ else:
75
+ print(f"{i=}, topic raw text condition: {data[i]['raw_text']}, generated category: {data[i]['category'].items()}")
76
+
77
+ i += chunk_size
78
+
79
+
80
+ def gen_categories(pipe, text_dataset: Generator[str, None, None]) -> List[dict]:
81
+ categories = []
82
+ for output in pipe(text_dataset, candidate_labels=CT_CATEGORIES, batch_size=64):
83
+ score_dict = {output['labels'][i]:output['scores'][i] for i in range(len(output['labels']))}
84
+ #category = max(score_dict, key=score_dict.get)
85
+ categories.append(score_dict)
86
+ return categories
87
+
88
+
89
+ def gen_single_category_vector(pipe, text: str) -> np.ndarray:
90
+ output = pipe(text, candidate_labels=CT_CATEGORIES)
91
+ score_dict = {output['labels'][i]:output['scores'][i] for i in range(len(output['labels']))}
92
+ return np.array(sorted(score_dict, key=score_dict.get, reverse=True))
ctmatch/scripts/get_web_data.py ADDED
@@ -0,0 +1,14 @@
1
+
2
+ from selenium import webdriver
+ from selenium.webdriver.common.by import By
3
+
4
+ def save_web_data(url: str) -> None:
5
+ driver = webdriver.Chrome()
6
+ driver.get(url)
7
+ button = driver.find_element(By.CLASS_NAME, "save-list")  # find_element_by_class_name was removed in Selenium 4
+ button.click()
+ driver.quit()
9
+
10
+
11
+ if __name__ == "__main__":
12
+ url = "https://clinicaltrials.gov/ct2/results?cond=Heart+Diseases"
13
+ save_web_data(url)
14
+
ctmatch/scripts/split_files.py ADDED
@@ -0,0 +1,37 @@
1
+
2
+
3
+ from pathlib import Path
4
+ import argparse
5
+ import os
6
+
7
+
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument('folder',
10
+ help="supply a folder path to be split up. if not folder, method won't do anything")
11
+
12
+ args = parser.parse_args()
13
+
14
+ MAX_FOLDER_SIZE = 2000
15
+
16
+
17
+ def split_files(folder: Path):
18
+
19
+ assert folder.is_dir()
20
+ num_dirs = 1
21
+ curr_size = 0
22
+
23
+ new_subfolder_path = folder.parent / f"{folder.as_posix()}_{num_dirs}"
24
+ new_subfolder_path.mkdir(exist_ok=True)
25
+ for file in folder.iterdir():
26
+ if curr_size > MAX_FOLDER_SIZE:
27
+ num_dirs += 1
28
+ new_subfolder_path = folder.parent / f"{folder.as_posix()}_{num_dirs}"
29
+ new_subfolder_path.mkdir(exist_ok=True)
30
+ curr_size = 0
31
+ else:
32
+ curr_size += 1
33
+ file.rename(new_subfolder_path / file.name)
34
+
35
+ if __name__ == "__main__":
36
+ split_files(Path(args.folder))
37
+
ctmatch/scripts/vis_script.py ADDED
@@ -0,0 +1,334 @@
1
+
2
+ from typing import Dict, List, NamedTuple
3
+ from ctproc.proc import CTDocument, EligCrit
4
+ from matplotlib import pyplot as plt
5
+ from collections import defaultdict
6
+ from zipfile import ZipFile
7
+ from lxml import etree
8
+ import pandas
9
+ import re
10
+
11
+
12
+ from ctmatch.utils.ctmatch_utils import *
13
+
14
+ class FieldCounter(NamedTuple):
15
+ missfld_counts: Dict[str, int] = defaultdict(int)
16
+ emptfld_counts: Dict[str, int] = defaultdict(int)
17
+ elig_form_counts: Dict[str, int] = defaultdict(int)
18
+ unit_counts: Dict[str, int] = defaultdict(int)
19
+
20
+
21
+ #----------------------------------------------------------------#
22
+ # EDA Utility Functions
23
+ #----------------------------------------------------------------#
24
+
25
+ # viewing
26
+ def print_elig_result(doc, dont_print=[]):
27
+ for k, v in doc.elig_crit.__dict__.items():
28
+ if k in dont_print:
29
+ continue
30
+ if type(v) == list:
31
+ print('\n' + k)
32
+ for v_i in v:
33
+ print(v_i)
34
+ else:
35
+ print(f"{k}: {v}")
36
+
37
+
38
+ def display_elig(docs: List[CTDocument]) -> None:
39
+ age_in_elig_text_dist = count_elig_crit_age_in_text(docs)
40
+ total = sum(age_in_elig_text_dist.values())
41
+ print(f"{total} out of {len(docs)} documents had age in eligibility text: {100 * total / len(docs):.1f}%")
42
+
43
+ age_in_elig_counts_df = pandas.DataFrame(age_in_elig_text_dist, index=[0])
44
+ age_in_elig_counts_df.plot(kind="bar", xticks=[], xlabel="include_or_exclude", ylabel="count", title="Age in Eligibility Criteria Text Distribution")
45
+ print(age_in_elig_counts_df)
46
+ inc_ratio = age_in_elig_text_dist['include'] / total
+ exc_ratio = age_in_elig_text_dist['exclude'] / total
+ print(f"{age_in_elig_text_dist['include']} instances in inclusion statements ({100 * inc_ratio:.1f}%), {age_in_elig_text_dist['exclude']} instances in exclusion statements ({100 * exc_ratio:.1f}%)")
49
+
50
+
51
+
52
+ def get_lengths(processed_docs: List[Dict[str, str]]) -> None:
53
+ no_crit, miss_inc, miss_exc = 0, 0, 0
54
+ inc_lens, exc_lens, all_lens = 0, 0, 0
55
+
56
+ for i, d in enumerate(processed_docs):
57
+ crit = d['elig_crit']['raw_text']
58
+ inc_crit = d['elig_crit']['include_criteria']
59
+ exc_crit = d['elig_crit']['exclude_criteria']
60
+
61
+ if len(inc_crit) == 0:
62
+ miss_inc += 1
63
+
64
+ if len(exc_crit) == 0:
65
+ miss_exc += 1
66
+
67
+ if (len(exc_crit) == 0) and (len(inc_crit) == 0):
68
+ no_crit += 1
69
+
70
+ #print(crit)
71
+
72
+ inc_length = sum([len(c.split()) for c in inc_crit])
73
+ exc_length = sum([len(c.split()) for c in exc_crit])
74
+ crit_len = inc_length + exc_length
75
+ inc_lens += inc_length
76
+ exc_lens += exc_length
77
+ all_lens += crit_len
78
+
79
+ print(f"{miss_inc=}, {miss_exc=}, {no_crit=}, {inc_lens / len(processed_docs)}, {exc_lens / len(processed_docs)}, {all_lens / len(processed_docs)}")
80
+
81
+
82
+
83
+ def print_ent_sent(ent_sent):
84
+ for e in ent_sent:
85
+ e_small = {}
86
+ e_small['raw_text'] = e['raw_text']
87
+ e_small['start'] = e['start']
88
+ e_small['end'] = e['end']
89
+ e_small['negation'] = e['negation']
90
+ print(e_small.items())
91
+
92
+
93
+
94
+
95
+
96
+
97
+ #--------------------------------------------------------------------------------------#
98
+ # methods for getting counts
99
+ #--------------------------------------------------------------------------------------#
100
+
101
+ def process_counts(zip_data: str) -> FieldCounter:
102
+ """
103
+ desc: main method for processing a zipped file of clinical trial XML documents from clinicaltrials.gov
104
+ parameterized by CTConfig the self ClinProc object was initialized with
105
+ returns: yields processed CTDocuments one at a time
106
+ """
107
+
108
+ counts = FieldCounter()
109
+ with ZipFile(zip_data, 'r') as zip_reader:
110
+ for i, ct_file in enumerate(zip_reader.namelist()):
111
+ if i % 1000 == 0:
112
+ print(f"{i} docs processed")
113
+
114
+ if not ct_file.endswith('xml'):
115
+ continue
116
+
117
+ counts = get_ct_file_counts(zip_reader.open(ct_file), counts)
118
+ return counts
119
+
120
+
121
+
122
+
123
+ def get_ct_file_counts(xml_filereader, counts: FieldCounter) -> FieldCounter:
124
+ doc_tree = etree.parse(xml_filereader)
125
+ root = doc_tree.getroot()
126
+
127
+ # required fields to check; some defaults are nested per-gender dicts
128
+ required_fields = {
129
+ "id":None,
130
+ "brief_title":None,
131
+ "eligibility/criteria/textblock":None,
132
+ "eligibility/gender":"Default Value",
133
+ "eligibility/minimum_age":{"male":0, "female":0},
134
+ "eligibility/maximum_age":{"male":999., "female":999.},
135
+ "detailed_description/textblock":None,
136
+ "condition":None,
137
+ "condition/condition_browse":None,
138
+ "intervention/intervention_type":None,
139
+ "intervention/intervention_name":None,
140
+ "intervention_browse/mesh_term":None,
141
+ "brief_summary/textblock":None,
142
+ }
143
+
144
+ for field in required_fields.keys():
145
+ field_tag = 'id_info/nct_id' if field == 'id' else field
146
+ try:
147
+ field_val = root.find(field_tag).text
148
+ if not EMPTY_PATTERN.fullmatch(field_val):
149
+ if field == 'eligibility/criteria/textblock':
150
+ counts.elig_form_counts = get_elig_counts(field_val, counts.elig_form_counts)
151
+ elif "age" in field:
152
+ age_match = AGE_PATTERN.match(field_val)
153
+ if age_match is not None:
154
+ unit = age_match.group('units')
155
+ if unit is not None:
156
+ counts.unit_counts[unit] += 1
157
+
158
+
159
+
160
+ except (AttributeError, TypeError):
161
+ if root.find(field_tag) is None:
162
+ counts.missfld_counts[field] += 1
163
+ elif EMPTY_PATTERN.fullmatch(root.find(field_tag).text):
164
+ counts.emptfld_counts[field] += 1
165
+
166
+ return counts
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+
176
+
177
+
178
+
179
+ def get_elig_counts(elig_text: str, elig_form_counts: Dict[str, int]) -> Dict[str, int]:
180
+ assert elig_text is not None, "Eligibility text is empty"
181
+ if re.search('[Ii]nclusion [Cc]riteria:[^\w]+\n', elig_text):
182
+ if re.search('[Ee]xclusion Criteria:[^\w]+\n', elig_text):
183
+ elig_form_counts["inc_and_exc"] += 1
184
+ return elig_form_counts
185
+ else:
186
+ elig_form_counts["inc_only"] += 1
187
+ return elig_form_counts
188
+
189
+ elif re.search('[Ee]xclusion [Cc]riteria:[^\w]+\n', elig_text):
190
+ elig_form_counts["exc_only"] += 1
191
+ return elig_form_counts
192
+
193
+ else:
194
+ elig_form_counts["textblock"] += 1
195
+ return elig_form_counts
196
+
197
+
198
+
199
+
200
+ def get_counts(docs: List[CTDocument]):
201
+ gender_dist = defaultdict(int)
202
+ min_age_dist = defaultdict(int)
203
+ max_age_dist = defaultdict(int)
204
+ for doc in docs:
205
+ gender_dist[doc.elig_gender] += 1
206
+ min_age_dist[doc.elig_crit.elig_min_age] += 1
207
+ max_age_dist[doc.elig_max_age] += 1
208
+ return gender_dist, min_age_dist, max_age_dist
209
+
210
+
211
+
212
+ def get_relled(topic_id, rel_dict):
213
+ twos, ones, zeros = set(), set(), set()
214
+ for doc_id, rel in rel_dict[topic_id].items():
215
+ if rel == 1:
216
+ ones.add(doc_id)
217
+ elif rel == 2:
218
+ twos.add(doc_id)
219
+ else:
220
+ zeros.add(doc_id)
221
+ return {"twos": twos, "ones": ones, "zeros": zeros}
222
+
223
+ def scan_for_age(
224
+ elig_crit: EligCrit,
225
+ inc_or_ex: str = 'include'
226
+ ) -> bool:
227
+ crit_to_scan = elig_crit.include_criteria if inc_or_ex == 'include' else elig_crit.exclude_criteria
228
+ for crit in crit_to_scan:
229
+ if re.search(r'\bages?\b', crit.lower()) is not None:
230
+ return True
231
+ return False
232
+
233
+
234
+ def count_elig_crit_age_in_text(docs, skip_predefined:bool = True):
235
+ age_in_elig_text_dist = defaultdict(int)
236
+ skipped = 0
237
+ for doc in docs:
238
+ if skip_predefined:
239
+ if (doc.elig_min_age != 0) or (doc.elig_max_age != 999): # author(s) have specified SOME criteria; assumes the structured field is preferred over free text in the criteria textblock
240
+ skipped += 1
241
+ continue
242
+
243
+ age_in_elig_text_dist['include'] += scan_for_age(doc.elig_crit, 'include')
244
+ age_in_elig_text_dist['exclude'] += scan_for_age(doc.elig_crit, 'exclude')
245
+
246
+ print(f"Total skipped: {skipped}")
247
+ return age_in_elig_text_dist
248
+
249
+
250
+
251
+
252
+ def get_missing_criteria(docs: List[CTDocument]):
253
+ missing_inc_ids, missing_exc_ids = set(), set()
254
+ for d in docs:
255
+
256
+ if len(d.elig_crit.include_criteria) == 0:
257
+ missing_inc_ids.add(d.nct_id)
258
+
259
+ if len(d.elig_crit.exclude_criteria) == 0:
260
+ missing_exc_ids.add(d.nct_id)
261
+
262
+ return missing_inc_ids, missing_exc_ids
263
+
264
+
265
+ # for evaluating effect of filtering
266
+ def get_doc_percent_elig(filtered_docs_by_topic: Dict[str, set]):
267
+ percents_elig = []
268
+ for topic_id, doc_list in filtered_docs_by_topic.items():
269
+ per = len(doc_list) / 3262.0
270
+ percents_elig.append(per)
271
+ print(topic_id, len(doc_list), per)
272
+ mean_elig = sum(percents_elig) / len(percents_elig)
273
+ print(f"Mean eligible fraction of docs: {mean_elig}")
274
+
275
+
276
+
277
+
278
+
279
+ # plotting
280
+
281
+ def plot_counts(missfld_counts, emptfld_counts):
282
+ miss_df = pandas.DataFrame(missfld_counts, index=[0])
283
+ miss_df.plot(kind='bar', xticks=[], title="Missing Fields", ylabel="count", xlabel="field")
284
+ plt.legend(loc=(1.04, 0))
285
+
286
+ empt_df = pandas.DataFrame(emptfld_counts, index=[0])
287
+ empt_df.plot(kind='bar', xticks=[], title="Empty Fields", ylabel="count", xlabel="field")
288
+ plt.legend(loc=(1.04, 0))
289
+
290
+
291
+
292
+
293
+ #----------------------------------------------------------------#
294
+ # EDA Test Data Utility Functions
295
+ #----------------------------------------------------------------#
296
+
297
+
298
+ def get_test_rels(test_rels):
299
+ rel_dict = defaultdict(lambda:defaultdict(int))
300
+ rel_type_dict = defaultdict(int)
301
+ for line in open(test_rels, 'r').readlines():
302
+ topic_id, _, doc_id, rel = re.split(r'\s+', line.strip())
303
+ rel_dict[topic_id][doc_id] = int(rel)
304
+ rel_type_dict[rel] += 1
305
+ return rel_dict, rel_type_dict
306
+
307
+ def analyze_test_rels(test_rels_path):
308
+ rel_dict, rel_type_dict = get_test_rels(test_rels_path)
309
+
310
+ print("Rel Type Results:")
311
+ for t, n in rel_type_dict.items():
312
+ print(t + ': ' + str(n))
313
+
314
+ lengths = dict()
315
+ all_qrelled_docs = set()
316
+ for tid in rel_dict.keys():
317
+ lengths[tid] = len(rel_dict[tid])
318
+ for d in rel_dict[tid].keys():
319
+ all_qrelled_docs.add(d)
320
+ for topic, num_relled in lengths.items():
321
+ print(topic, num_relled)
322
+ print(f"Total relled: {len(all_qrelled_docs)}")
323
+ return rel_type_dict, rel_dict, all_qrelled_docs
324
+
325
+
326
+
327
+
328
+
329
+
330
+ if __name__ == '__main__':
331
+ qrels_path = '/Users/jameskelly/Documents/cp/ctmatch/data/qrels-clinical_trials.txt'
332
+ rel_type_dict, rel_dict, all_qrelled_docs = analyze_test_rels(qrels_path)
333
+ #docs_path = '/Users/jameskelly/Documents/cp/ctproc/clinicaltrials.gov-16_dec_2015_17.zip'
334
+ #counts = process_counts(docs_path)
ctmatch/utils/__init__.py ADDED
File without changes
ctmatch/utils/ctmatch_utils.py ADDED
@@ -0,0 +1,133 @@
1
+
2
+ from typing import Any, Dict, List, Optional, Set
3
+ from sklearn.metrics.pairwise import linear_kernel
4
+ from collections import defaultdict
5
+ from numpy.linalg import norm
6
+ from datasets import Dataset
7
+ import numpy as np
8
+ import json
9
+ import re
10
+
11
+
12
+
13
+
14
+
15
+ #----------------------------------------------------------------#
16
+ # global regex patterns for use throughout the methods
17
+ #----------------------------------------------------------------#
18
+
19
+
20
+ EMPTY_PATTERN = re.compile('[\n\s]+')
21
+ """
22
+ both_inc_and_exc_pattern = re.compile(r\"\"\"[\s\n]*[Ii]nclusion [Cc]riteria:? # top line of both
23
+ (?:[ ]+[Ee]ligibility[ \w]+\:[ ])? # could contain this unneeded bit next
24
+ (?P<include_crit>[ \n\-\.\?\"\%\r\w\:\,\(\)]*) # this should get all inclusion criteria as a string
25
+ [Ee]xclusion[ ][Cc]riteria:? # delineator to exclusion criteria
26
+ (?P<exclude_crit>[\w\W ]*) # exclusion criteria as string
27
+ \"\"\", re.VERBOSE)
28
+ """
29
+ INC_ONLY_PATTERN = re.compile('[\s\n]+[Ii]nclusion [Cc]riteria:?([\w\W ]*)')
30
+ EXC_ONLY_PATTERN = re.compile('[\n\r ]+[Ee]xclusion [Cc]riteria:?([\w\W ]*)')
31
+ AGE_PATTERN = re.compile('(?P<age>\d+) *(?P<units>\w+).*')
32
+ YEAR_PATTERN = re.compile('(?P<year>[yY]ears?.*)')
33
+ MONTH_PATTERN = re.compile('(?P<month>[mM]o(?:nth)?)')
34
+ WEEK_PATTERN = re.compile('(?P<week>[wW]eeks?)')
35
+
36
+ BOTH_INC_AND_EXC_PATTERN = re.compile("[\s\n]*[Ii]nclusion [Cc]riteria:?(?: +[Ee]ligibility[ \w]+\: )?(?P<include_crit>[ \n\-\.\?\"\%\r\w\:\,\(\)]*)[Ee]xclusion [Cc]riteria:?(?P<exclude_crit>[\w\W ]*)")
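+ # e.g. a criteria textblock of the form "Inclusion Criteria: ... Exclusion Criteria: ..." is split into the
+ # named groups include_crit / exclude_crit; INC_ONLY_PATTERN / EXC_ONLY_PATTERN handle blocks with only one section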
37
+
38
+
39
+
40
+ # -------------------------------------------------------------------------------------- #
41
+ # pretokenization utils (should be in a tokenizer...)
42
+ # -------------------------------------------------------------------------------------- #
43
+
44
+ def truncate(s: str, max_tokens: Optional[int] = None) -> str:
45
+ if max_tokens is None:
46
+ return s
47
+ s_tokens = s.split()
48
+ return ' '.join(s_tokens[:min(len(s_tokens), max_tokens)])
49
+
50
+
51
+
52
+ # -------------------------------------------------------------------------------------- #
53
+ # I/O utils
54
+ # -------------------------------------------------------------------------------------- #
55
+
56
+ def save_docs_jsonl(docs: List[Any], writefile: str) -> None:
57
+ """
58
+ desc: iteratively writes contents of docs as jsonl to writefile
59
+ """
60
+ with open(writefile, "w") as outfile:
61
+ for doc in docs:
62
+ json.dump(doc, outfile)
63
+ outfile.write("\n")
64
+
65
+
66
+ def get_processed_data(proc_loc: str, get_only: Optional[Set[str]] = None):
67
+ """
68
+ proc_loc: str or path to location of docs in jsonl form
69
+ """
70
+ with open(proc_loc, 'r') as json_file:
71
+ json_list = list(json_file)
72
+
73
+ if get_only is None:
74
+ for json_str in json_list:
75
+ yield json.loads(json_str)
76
+
77
+ else:
78
+ for s in json_list:
79
+ s_data = json.loads(s)
80
+ if s_data["id"] in get_only:
81
+ yield s_data
82
+ get_only.remove(s_data['id'])
83
+ if len(get_only) == 0:
84
+ return
85
+
86
+
87
+
88
+
89
+
90
+ def train_test_val_split(dataset, splits: Dict[str, float], seed: int = 37) -> Dataset:
91
+ """
92
+ splits a dataset having only "train" into one having train, test, val, with
93
+ split sizes determined by splits["train"] and splits["val"] (dict must have those keys)
94
+
95
+ """
96
+ dataset = dataset["train"].train_test_split(train_size=splits["train"], seed=seed)
97
+ train = dataset["train"]
98
+ sub = train.train_test_split(test_size=splits["val"], seed=seed)
99
+ new_train = sub["train"]
100
+ new_val = sub["test"]
101
+ dataset["train"] = new_train
102
+ dataset["validation"] = new_val
103
+ return dataset
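+ # usage sketch (dataset names assumed): after ds = train_test_val_split(raw_ds, {"train": 0.8, "val": 0.1}),
+ # ds["train"] / ds["validation"] / ds["test"] hold roughly 72% / 8% / 20% of the rows (val is carved out of the train split)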
104
+
105
+
106
+
107
+ #----------------------------------------------------------------#
108
+ # computation methods
109
+ #----------------------------------------------------------------#
110
+
111
+ def exclusive_argmax(vector: np.ndarray) -> np.ndarray:
112
+ one_hot = np.zeros_like(vector)
+ one_hot[np.argmax(vector)] = 1
+ return one_hot
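+ # e.g. exclusive_argmax(np.array([0.2, 0.7, 0.1])) -> array([0., 1., 0.])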
117
+
118
+
119
+ #----------------------------------------------------------------#
120
+ # evaluation methods (duplicated from ctproc scripts)
121
+ #----------------------------------------------------------------#
122
+
123
+ def get_test_rels(rel_path):
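+ # each qrels line is whitespace-separated: "<topic_id> <iteration> <doc_id> <relevance>", as parsed below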
124
+ rel_dict = defaultdict(lambda:defaultdict(int))
125
+ rel_type_dict = defaultdict(int)
126
+ for line in open(rel_path, 'r').readlines():
127
+ topic_id, _, doc_id, rel = re.split(r'\s+', line.strip())
128
+ rel_dict[topic_id][doc_id] = int(rel)
129
+ rel_type_dict[rel] += 1
130
+ return rel_dict, rel_type_dict
131
+
132
+
133
+
ctmatch/utils/eval_utils.py ADDED
@@ -0,0 +1,91 @@
1
+
2
+ from typing import Dict, List, Tuple
3
+
4
+ from sklearn.metrics import f1_score
5
+ from collections import defaultdict
6
+ from lxml import etree
7
+ import numpy as np
8
+
9
+
10
+ def get_trec_topic2text(topic_path) -> Dict[str, str]:
11
+ """
12
+ desc: main method for processing a single XML file of TREC21 patient descriptions called "topics" in this sense
13
+ returns: dict of topicid: topic text
14
+ """
15
+
16
+ topic2text = {}
17
+ topic_root = etree.parse(topic_path).getroot()
18
+ for topic in topic_root:
19
+ topic2text[topic.attrib['number']] = topic.text
20
+
21
+ return topic2text
22
+
23
+
24
+
25
+ def get_kz_topic2text(topic_path) -> Dict[str, str]:
26
+ """
27
+ desc: main method for processing a single text file of KZ patient description "topics" (<TOP>/<NUM>/<TITLE> markup)
28
+ returns: dict of topicid: topic text
29
+ """
30
+
31
+ topic2text = {}
32
+ with open(topic_path, 'r') as f:
33
+ for line in f.readlines():
34
+ line = line.strip()
35
+
36
+ if line.startswith('<TOP>'):
37
+ topic_id, text = None, None
38
+ continue
39
+
40
+ if line.startswith('<NUM>'):
41
+ topic_id = line[5:-6]
42
+
43
+ elif line.startswith('<TITLE>'):
44
+ text = line[7:].strip()
45
+ topic2text[topic_id] = text
46
+
47
+ return topic2text
48
+
49
+
50
+
51
+ def calc_first_positive_rank(ranked_ids: List[str], doc2rel: Dict[str, int], pos_val: int = 2) -> Tuple[int, float]:
52
+ """
53
+ desc: compute the mean reciprocal rank of a ranking
54
+ returns: mrr
55
+ """
56
+ for i, doc_id in enumerate(ranked_ids):
57
+ if doc2rel[doc_id] == pos_val:
58
+ return i + 1, 1./float(i+1)
59
+ return len(ranked_ids) + 1, 0.0
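+ # e.g. if the third-ranked doc is the first with relevance == pos_val this returns (3, 1/3);
+ # if no ranked doc is positive it returns (len(ranked_ids) + 1, 0.0)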
60
+
61
+
62
+ def calc_f1(ranked_ids: List[str], doc2rel: Dict[str, int]) -> float:
63
+ label_counts = get_label_counts(doc2rel)
64
+ predicted, ground_truth = [], []
65
+ for doc_id in ranked_ids:
66
+ # 2, 1, 0
67
+ ground_truth.append(doc2rel[doc_id])
68
+ pred_label = get_predicted_label(label_counts)
69
+ predicted.append(pred_label)
70
+ label_counts[pred_label] -= 1
71
+
72
+ return f1_score(ground_truth, predicted, average='micro')
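+ # i.e. the "prediction" for the doc at rank r is the highest relevance label that still has unassigned
+ # ground-truth mass, so a perfect ranking (all 2s first, then 1s, then 0s) scores f1 == 1.0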
73
+
74
+
75
+
76
+ def get_label_counts(doc2rel: Dict[str, int]) -> Dict[int, int]:
77
+ """
78
+ return a dict mapping each relevance label (2, 1, 0) to its count in doc2rel
79
+ """
80
+ label_counts = defaultdict(int)
81
+ for scored_doc in doc2rel:
82
+ label = doc2rel[scored_doc]
83
+ label_counts[label] += 1
84
+ return label_counts
85
+
86
+ def get_predicted_label(label_counts: Dict[int, int]) -> int:
87
+ if label_counts[2] > 0:
88
+ return 2
89
+ if label_counts[1] > 0:
90
+ return 1
91
+ return 0
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ #ctproc uncomment if doing data prep on raw ct documents
2
+ #https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.4.0/en_core_sci_md-0.4.0.tar.gz uncomment if using ctproc
3
+ #pyserini==0.12.0 uncomment if using ctproc with indexes (not recommended)
4
+ #git+https://github.com/semajyllek/transformers.git@add-biogpt-sequenceclassifier
5
+ #sacremoses uncomment if using biogpt
6
+ sentence-transformers
7
+ huggingface_hub
8
+ scikit-learn
9
+ transformers
10
+ onnxruntime
11
+ nn_pruning
12
+ optimum
13
+ onnx
14
+
15
+ matplotlib
16
+ accelerate
17
+ datasets
18
+ evaluate
19
+ pandas
20
+ openai
21
+ lxml
22
+
23
+ gradio