import os
import re
import sys
import json
import tempfile
import gradio as gr
from transformers import (
    TrainingArguments,
    HfArgumentParser,
)
from robust_deid.ner_datasets import DatasetCreator
from robust_deid.sequence_tagging import SequenceTagger
from robust_deid.sequence_tagging.arguments import (
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
)
from robust_deid.deid import TextDeid
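# Pipeline overview (as wired up below): DatasetCreator turns raw notes into
# sentence-level NER inputs, SequenceTagger runs the token-level PHI tagging
# model, and TextDeid maps the predictions back onto the original note text
# using either a "remove" or a "replace_informative" strategy.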
class App(object):
    """Wraps the robust_deid pipeline: dataset creation, sequence tagging and text de-identification."""

    def __init__(
        self,
        model,
        threshold,
        span_constraint='super_strict',
        sentencizer='en_core_sci_sm',
        tokenizer='clinical',
        max_tokens=128,
        max_prev_sentence_token=32,
        max_next_sentence_token=32,
        default_chunk_size=32,
        ignore_label='NA'
    ):
        # Create the dataset creator object
        self._dataset_creator = DatasetCreator(
            sentencizer=sentencizer,
            tokenizer=tokenizer,
            max_tokens=max_tokens,
            max_prev_sentence_token=max_prev_sentence_token,
            max_next_sentence_token=max_next_sentence_token,
            default_chunk_size=default_chunk_size,
            ignore_label=ignore_label
        )
        parser = HfArgumentParser((ModelArguments, DataTrainingArguments, EvaluationArguments, TrainingArguments))
        model_config = App._get_model_config()
        model_config['model_name_or_path'] = App._get_model_map()[model]
        if threshold == 'No threshold':
            model_config['post_process'] = 'argmax'
            model_config['threshold'] = None
        else:
            model_config['post_process'] = 'threshold_max'
            model_config['threshold'] = App._get_threshold_map()[model_config['model_name_or_path']][threshold]
        print(model_config)
        # Write the config to a temporary JSON file so HfArgumentParser can parse it
        # into the model/data/evaluation/training argument dataclasses.
        with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
            tmp.write(json.dumps(model_config) + '\n')
            tmp.seek(0)
            self._model_args, self._data_args, self._evaluation_args, self._training_args = \
                parser.parse_json_file(json_file=tmp.name)
        # Initialize the text deid object
        self._text_deid = TextDeid(notation=self._data_args.notation, span_constraint=span_constraint)
        # Initialize the sequence tagger
        self._sequence_tagger = SequenceTagger(
            task_name=self._data_args.task_name,
            notation=self._data_args.notation,
            ner_types=self._data_args.ner_types,
            model_name_or_path=self._model_args.model_name_or_path,
            config_name=self._model_args.config_name,
            tokenizer_name=self._model_args.tokenizer_name,
            post_process=self._model_args.post_process,
            cache_dir=self._model_args.cache_dir,
            model_revision=self._model_args.model_revision,
            use_auth_token=self._model_args.use_auth_token,
            threshold=self._model_args.threshold,
            do_lower_case=self._data_args.do_lower_case,
            fp16=self._training_args.fp16,
            seed=self._training_args.seed,
            local_rank=self._training_args.local_rank
        )
        # Load the required functions of the sequence tagger
        self._sequence_tagger.load()
    def get_ner_dataset(self, notes_file):
        """Convert a JSON-lines notes file into sentence-level NER records for prediction."""
        ner_notes = self._dataset_creator.create(
            input_file=notes_file,
            mode='predict',
            notation=self._data_args.notation,
            token_text_key='text',
            metadata_key='meta',
            note_id_key='note_id',
            label_key='label',
            span_text_key='spans'
        )
        return ner_notes
    def get_predictions(self, ner_notes_file):
        """Run the sequence tagger on the NER dataset file and return token-level predictions."""
        self._sequence_tagger.set_predict(
            test_file=ner_notes_file,
            max_test_samples=self._data_args.max_predict_samples,
            preprocessing_num_workers=self._data_args.preprocessing_num_workers,
            overwrite_cache=self._data_args.overwrite_cache
        )
        self._sequence_tagger.setup_trainer(training_args=self._training_args)
        predictions = self._sequence_tagger.predict()
        return predictions
    def get_deid_text_removed(self, notes_file, predictions_file):
        """De-identify the notes by removing the detected PHI spans from the text."""
        deid_notes = self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='remove',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )
        return deid_notes
    def get_deid_text_replaced(self, notes_file, predictions_file):
        """De-identify the notes by replacing detected PHI spans with tagged placeholders of the form <<TYPE:...>>."""
        deid_notes = self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='replace_informative',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )
        return deid_notes
    @staticmethod
    def _get_highlights(deid_text):
        """Yield (text, label) tuples from de-identified text containing <<TYPE:...>> tags.

        Untagged text is yielded with a None label; tagged spans are yielded with their PHI type.
        """
        pattern = re.compile(r'<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):(.*?)>>')
        tag_pattern = re.compile(r'<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):')
        current_start = 0
        for match in re.finditer(pattern, deid_text):
            full_start, full_end = match.span()
            sub_text = deid_text[full_start:full_end]
            sub_match = re.search(tag_pattern, sub_text)
            sub_span = sub_match.span()
            # Text between the previous tag and this one carries no label
            yield (deid_text[current_start:full_start], None)
            # The tagged content (without the surrounding <<TYPE: and >>) is labelled with its PHI type
            yield (deid_text[full_start + sub_span[1]:full_end - 2], sub_match.group(1))
            current_start = full_end
        # Trailing text after the last tag (or the whole text if there were no tags)
        yield (deid_text[current_start:], None)
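    # Note: the (text, label) tuples yielded above match the list-of-tuples format
    # accepted by Gradio's HighlightedText component, which is presumably how the
    # interface code renders the de-identified note. That consumer is an assumption;
    # _get_highlights itself only yields the tuples.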
    @staticmethod
    def _get_model_map():
        """Map the UI model names to their Hugging Face Hub model identifiers."""
        return {
            'OBI-RoBERTa De-ID': 'obi/deid_roberta_i2b2',
            'OBI-ClinicalBERT De-ID': 'obi/deid_bert_i2b2'
        }

    @staticmethod
    def _get_threshold_map():
        """Per-model post-processing thresholds keyed by the UI threshold option ('99.5' / '99.7')."""
        return {
            'obi/deid_bert_i2b2': {"99.5": 4.656325975101986e-06, "99.7": 1.8982457699258832e-06},
            'obi/deid_roberta_i2b2': {"99.5": 2.4362972672812125e-05, "99.7": 2.396420546444644e-06}
        }
    @staticmethod
    def _get_model_config():
        """Base config parsed by HfArgumentParser; the model and threshold fields are filled in at runtime."""
        return {
            "post_process": None,
            "threshold": None,
            "model_name_or_path": None,
            "task_name": "ner",
            "notation": "BILOU",
            "ner_types": ["PATIENT", "STAFF", "AGE", "DATE", "PHONE", "ID", "EMAIL", "PATORG", "LOC", "HOSP", "OTHERPHI"],
            "truncation": True,
            "max_length": 512,
            "label_all_tokens": False,
            "return_entity_level_metrics": True,
            "text_column_name": "tokens",
            "label_column_name": "labels",
            "output_dir": "./run/models",
            "logging_dir": "./run/logs",
            "overwrite_output_dir": False,
            "do_train": False,
            "do_eval": False,
            "do_predict": True,
            "report_to": [],
            "per_device_train_batch_size": 0,
            "per_device_eval_batch_size": 16,
            "logging_steps": 1000
        }
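# Usage sketch (an illustration, not part of the original file): the model name and
# threshold strings passed to App must match the keys in _get_model_map() and
# _get_threshold_map(), e.g. App('OBI-RoBERTa De-ID', '99.5') or
# App('OBI-ClinicalBERT De-ID', 'No threshold'). The deid() helper below drives
# this flow for a single note, presumably as the Gradio callback.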
def deid(text, model, threshold):
    """De-identify a single note with the selected model and threshold (presumably the Gradio callback)."""
    notes = [{"text": text, "meta": {"note_id": "note_1", "patient_id": "patient_1"}, "spans": []}]
    app = App(model, threshold)
    # Create temp notes file
    with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
        for note in notes:
            tmp.write(json.dumps(note) + '\n')
        tmp.seek(0)
        ner_notes = app.get_ner_dataset(tmp.name)
    # Create temp ner_notes file
    with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
        for ner_sentence in ner_notes:
            tmp.write(json.dumps(ner_sentence) + '\n')