# batch_main.py -- uploaded by ajitrajasekharan (commit 9c98606, 2.59 kB).
# Hosting page reported: no virus found.
import time
import torch
import string
import pdb
import argparse
from transformers import BertTokenizer, BertForMaskedLM
import BatchInference as bd
import batched_main_NER as ner
import aggregate_server_json as aggr
import json
# Forwarded positionally to bd.BatchInference when the BIO model is built;
# presumably the number of top predictions to keep -- confirm against
# BatchInference's signature.
DEFAULT_TOP_K = 20

# Marker text for tokens that should be tagged. process_input appends this
# same text to every whitespace-separated token of an input line.
SPECIFIC_TAG = ":__entity__"

# Default value of the `-model` command-line option.
DEFAULT_MODEL_PATH = "ajitrajasekharan/biomedical"

# Default value of the `-output` command-line option.
DEFAULT_RESULTS = "results.txt"
def perform_inference(text, bio_model, ner_bio, aggr_ner):
    """Run one tagged sentence through the BIO model, tagger and aggregator.

    Args:
        text: The sentence to process, passed unchanged to every stage.
        bio_model: Object exposing ``get_descriptors(text, None)``.
        ner_bio: Object exposing ``tag_sentence_service(text, descriptors)``
            that returns a JSON string.
        aggr_ner: Object exposing ``fetch_all(text, results)``.

    Returns:
        Whatever ``aggr_ner.fetch_all`` returns for this sentence.
    """
    print("Getting predictions from BIO model...")
    descriptors = bio_model.get_descriptors(text, None)

    print("Computing BIO results...")
    tagged = json.loads(ner_bio.tag_sentence_service(text, descriptors))

    # The aggregator is handed a two-element list holding the *same* parsed
    # result twice, since only one model's output is available here.
    return aggr_ner.fetch_all(text, [tagged, tagged])
def process_input(results):
    """Tag every line of an input file and write the results as JSON.

    Each non-blank line of ``results.input`` is split on whitespace, every
    token gets ``":__entity__"`` appended, and the re-joined sentence is sent
    through ``perform_inference``. The JSON-serialised prediction for each
    line is written to ``results.output`` followed by a blank line.

    Args:
        results: Parsed command-line arguments; the ``input`` and ``output``
            attributes are read.

    Returns:
        None. Any exception raised during processing is caught and reported
        on stdout/stderr (with a traceback) instead of propagating.
    """
    import traceback  # local import: only needed on the error path

    try:
        input_file = results.input
        output_file = results.output

        print("Initializing BIO module...")
        # NOTE(review): the model path is fixed here, so the `-model`
        # command-line option is not consulted. The accompanying config and
        # label files look model-specific -- confirm before wiring it in.
        bio_model = bd.BatchInference(
            "bio/desc_a100_config.json",
            'ajitrajasekharan/biomedical',
            False,
            False,
            DEFAULT_TOP_K,
            True,
            True,
            "bio/",
            "bio/a100_labels.txt",
            False,
        )
        ner_bio = ner.UnsupNER("bio/ner_a100_config.json")

        print("Initializing Aggregation module...")
        aggr_ner = aggr.AggregateNER("./ensemble_config.json")

        # Open the input first so a missing input file does not truncate an
        # existing output file; both handles are closed even on error.
        with open(input_file) as fp, open(output_file, "w") as wfp:
            for line in fp:
                tokens = line.strip().split()
                if not tokens:
                    # Blank line: nothing to tag, so skip the model call.
                    continue
                print(tokens)
                text_input = ' '.join(t + ":__entity__" for t in tokens)

                start = time.time()
                prediction = perform_inference(text_input, bio_model, ner_bio, aggr_ner)
                print(f"prediction took {time.time() - start:.2f}s")

                wfp.write(json.dumps(prediction))
                wfp.write("\n\n")
    except Exception:
        # Top-level boundary of the batch run: keep the original best-effort
        # behaviour (no re-raise) but surface what actually went wrong.
        print("Some error occurred in batch processing")
        traceback.print_exc()
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the batch job.
    arg_parser = argparse.ArgumentParser(
        description='Batch handling of NER ',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        '-model',
        action="store",
        dest="model",
        default=DEFAULT_MODEL_PATH,
        help='BERT pretrained models, or custom model path',
    )
    arg_parser.add_argument(
        '-input',
        action="store",
        dest="input",
        required=True,
        help='Input file with sentences',
    )
    arg_parser.add_argument(
        '-output',
        action="store",
        dest="output",
        default=DEFAULT_RESULTS,
        help='Output file with sentences',
    )
    process_input(arg_parser.parse_args())