ajitrajasekharan commited on
Commit
9c98606
1 Parent(s): 981717f

Upload batch_main.py

Browse files
Files changed (1) hide show
  1. batch_main.py +63 -0
batch_main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ import string
4
+ import pdb
5
+ import argparse
6
+
7
+ from transformers import BertTokenizer, BertForMaskedLM
8
+ import BatchInference as bd
9
+ import batched_main_NER as ner
10
+ import aggregate_server_json as aggr
11
+ import json
12
+
13
+
14
+ DEFAULT_TOP_K = 20
15
+ SPECIFIC_TAG=":__entity__"
16
+ DEFAULT_MODEL_PATH="ajitrajasekharan/biomedical"
17
+ DEFAULT_RESULTS="results.txt"
18
+
19
+
20
+ def perform_inference(text,bio_model,ner_bio,aggr_ner):
21
+ print("Getting predictions from BIO model...")
22
+ bio_descs = bio_model.get_descriptors(text,None)
23
+ print("Computing BIO results...")
24
+ bio_ner = ner_bio.tag_sentence_service(text,bio_descs)
25
+ obj = json.loads(bio_ner)
26
+ combined_arr = [obj,obj]
27
+ aggregate_results = aggr_ner.fetch_all(text,combined_arr)
28
+ return aggregate_results
29
+
30
+
31
+ def process_input(results):
32
+ try:
33
+ input_file = results.input
34
+ output_file = results.output
35
+ print("Initializing BIO module...")
36
+ bio_model = bd.BatchInference("bio/desc_a100_config.json",'ajitrajasekharan/biomedical',False,False,DEFAULT_TOP_K,True,True, "bio/","bio/a100_labels.txt",False)
37
+ ner_bio = ner.UnsupNER("bio/ner_a100_config.json")
38
+ print("Initializing Aggregation module...")
39
+ aggr_ner = aggr.AggregateNER("./ensemble_config.json")
40
+ wfp = open(output_file,"w")
41
+ with open(input_file) as fp:
42
+ for line in fp:
43
+ text_input = line.strip().split()
44
+ print(text_input)
45
+ text_input = [t + ":__entity__" for t in text_input]
46
+ text_input = ' '.join(text_input)
47
+ start = time.time()
48
+ results = perform_inference(text_input,bio_model,ner_bio,aggr_ner)
49
+ print(f"prediction took {time.time() - start:.2f}s")
50
+ pdb.set_trace()
51
+ wfp.write(json.dumps(results))
52
+ wfp.write("\n\n")
53
+ wfp.close()
54
+ except Exception as e:
55
+ print("Some error occurred in batch processing")
56
+
57
+ if __name__ == "__main__":
58
+ parser = argparse.ArgumentParser(description='Batch handling of NER ',formatter_class=argparse.ArgumentDefaultsHelpFormatter)
59
+ parser.add_argument('-model', action="store", dest="model", default=DEFAULT_MODEL_PATH,help='BERT pretrained models, or custom model path')
60
+ parser.add_argument('-input', action="store", dest="input", required=True,help='Input file with sentences')
61
+ parser.add_argument('-output', action="store", dest="output", default=DEFAULT_RESULTS,help='Output file with sentences')
62
+ results = parser.parse_args()
63
+ process_input(results)