import os
import re
import sys
import json
import tempfile

import gradio as gr
from transformers import (
    TrainingArguments,
    HfArgumentParser,
)

from robust_deid.ner_datasets import DatasetCreator
from robust_deid.sequence_tagging import SequenceTagger
from robust_deid.sequence_tagging.arguments import (
    ModelArguments,
    DataTrainingArguments,
    EvaluationArguments,
)
from robust_deid.deid import TextDeid


class App(object):

    def __init__(
            self,
            model,
            threshold,
            span_constraint='super_strict',
            sentencizer='en_core_sci_sm',
            tokenizer='clinical',
            max_tokens=128,
            max_prev_sentence_token=32,
            max_next_sentence_token=32,
            default_chunk_size=32,
            ignore_label='NA'
    ):
        # Create the dataset creator object
        self._dataset_creator = DatasetCreator(
            sentencizer=sentencizer,
            tokenizer=tokenizer,
            max_tokens=max_tokens,
            max_prev_sentence_token=max_prev_sentence_token,
            max_next_sentence_token=max_next_sentence_token,
            default_chunk_size=default_chunk_size,
            ignore_label=ignore_label
        )
        parser = HfArgumentParser((
            ModelArguments,
            DataTrainingArguments,
            EvaluationArguments,
            TrainingArguments
        ))
        # Resolve the selected model and threshold into a model config
        model_config = App._get_model_config()
        model_config['model_name_or_path'] = App._get_model_map()[model]
        if threshold == 'No threshold':
            model_config['post_process'] = 'argmax'
            model_config['threshold'] = None
        else:
            model_config['post_process'] = 'threshold_max'
            model_config['threshold'] = \
                App._get_threshold_map()[model_config['model_name_or_path']][threshold]
        print(model_config)
        # Write the config to a temporary JSON file and let HfArgumentParser
        # parse it into the model/data/evaluation/training argument dataclasses.
        with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
            tmp.write(json.dumps(model_config) + '\n')
            tmp.seek(0)
            self._model_args, self._data_args, self._evaluation_args, self._training_args = \
                parser.parse_json_file(json_file=tmp.name)
        # Initialize the text deid object
        self._text_deid = TextDeid(
            notation=self._data_args.notation,
            span_constraint=span_constraint
        )
        # Initialize the sequence tagger
        self._sequence_tagger = SequenceTagger(
            task_name=self._data_args.task_name,
            notation=self._data_args.notation,
            ner_types=self._data_args.ner_types,
            model_name_or_path=self._model_args.model_name_or_path,
            config_name=self._model_args.config_name,
            tokenizer_name=self._model_args.tokenizer_name,
            post_process=self._model_args.post_process,
            cache_dir=self._model_args.cache_dir,
            model_revision=self._model_args.model_revision,
            use_auth_token=self._model_args.use_auth_token,
            threshold=self._model_args.threshold,
            do_lower_case=self._data_args.do_lower_case,
            fp16=self._training_args.fp16,
            seed=self._training_args.seed,
            local_rank=self._training_args.local_rank
        )
        # Load the required functions of the sequence tagger
        self._sequence_tagger.load()

    def get_ner_dataset(self, notes_file):
        # Convert raw notes into the sentence/token records the tagger expects
        ner_notes = self._dataset_creator.create(
            input_file=notes_file,
            mode='predict',
            notation=self._data_args.notation,
            token_text_key='text',
            metadata_key='meta',
            note_id_key='note_id',
            label_key='label',
            span_text_key='spans'
        )
        return ner_notes

    def get_predictions(self, ner_notes_file):
        # Run the sequence tagger on the NER-formatted notes
        self._sequence_tagger.set_predict(
            test_file=ner_notes_file,
            max_test_samples=self._data_args.max_predict_samples,
            preprocessing_num_workers=self._data_args.preprocessing_num_workers,
            overwrite_cache=self._data_args.overwrite_cache
        )
        self._sequence_tagger.setup_trainer(training_args=self._training_args)
        predictions = self._sequence_tagger.predict()
        return predictions
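
    # The two de-identification methods below differ only in deid_strategy:
    # 'remove' drops the predicted PHI tokens from the text outright, while
    # 'replace_informative' substitutes surrogate tags of the form <<TYPE:...>>,
    # which _get_highlights() below parses back out for highlighted display.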
    def get_deid_text_removed(self, notes_file, predictions_file):
        deid_notes = self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='remove',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )
        return deid_notes

    def get_deid_text_replaced(self, notes_file, predictions_file):
        deid_notes = self._text_deid.run_deid(
            input_file=notes_file,
            predictions_file=predictions_file,
            deid_strategy='replace_informative',
            keep_age=False,
            metadata_key='meta',
            note_id_key='note_id',
            tokens_key='tokens',
            predictions_key='predictions',
            text_key='text',
        )
        return deid_notes

    @staticmethod
    def _get_highlights(deid_text):
        # Yield (text, label) pairs: plain text gets a None label, each PHI
        # span between '<<TYPE:' and '>>' gets its PHI type as the label.
        pattern = re.compile(r'<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):.*?>>')
        tag_pattern = re.compile(r'<<(PATIENT|STAFF|AGE|DATE|LOCATION|PHONE|ID|EMAIL|PATORG|HOSPITAL|OTHERPHI):')
        current_start = 0
        for match in re.finditer(pattern, deid_text):
            full_start, full_end = match.span()
            sub_match = re.search(tag_pattern, deid_text[full_start:full_end])
            sub_span = sub_match.span()
            # Unlabelled text between the previous match and this one
            yield deid_text[current_start:full_start], None
            # The tagged span, stripped of '<<TYPE:' and '>>', labelled with TYPE
            yield (
                deid_text[full_start + sub_span[1]:full_end - 2],
                sub_match.string[sub_span[0] + 2:sub_span[1] - 1]
            )
            current_start = full_end
        # Trailing text after the last match (the whole text if nothing matched)
        yield deid_text[current_start:], None

    @staticmethod
    def _get_model_map():
        return {
            'OBI-RoBERTa De-ID': 'obi/deid_roberta_i2b2',
            'OBI-ClinicalBERT De-ID': 'obi/deid_bert_i2b2'
        }

    @staticmethod
    def _get_threshold_map():
        return {
            'obi/deid_bert_i2b2': {"99.5": 4.656325975101986e-06, "99.7": 1.8982457699258832e-06},
            'obi/deid_roberta_i2b2': {"99.5": 2.4362972672812125e-05, "99.7": 2.396420546444644e-06}
        }

    @staticmethod
    def _get_model_config():
        return {
            "post_process": None,
            "threshold": None,
            "model_name_or_path": None,
            "task_name": "ner",
            "notation": "BILOU",
            "ner_types": ["PATIENT", "STAFF", "AGE", "DATE", "PHONE", "ID", "EMAIL", "PATORG", "LOC", "HOSP", "OTHERPHI"],
            "truncation": True,
            "max_length": 512,
            "label_all_tokens": False,
            "return_entity_level_metrics": True,
            "text_column_name": "tokens",
            "label_column_name": "labels",
            "output_dir": "./run/models",
            "logging_dir": "./run/logs",
            "overwrite_output_dir": False,
            "do_train": False,
            "do_eval": False,
            "do_predict": True,
            "report_to": [],
            "per_device_train_batch_size": 0,
            "per_device_eval_batch_size": 16,
            "logging_steps": 1000
        }


def deid(text, model, threshold):
    # Wrap the input text as a single note record
    notes = [{"text": text, "meta": {"note_id": "note_1", "patient_id": "patient_1"}, "spans": []}]
    app = App(model, threshold)
    # Create temp notes file
    with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
        for note in notes:
            tmp.write(json.dumps(note) + '\n')
        tmp.seek(0)
        ner_notes = app.get_ner_dataset(tmp.name)
    # Create temp ner_notes file
    with tempfile.NamedTemporaryFile("w+", delete=False) as tmp:
        for ner_sentence in ner_notes:
            tmp.write(json.dumps(ner_sentence) + '\n')
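        # Each record is written as one JSON object per line (JSONL), the same
        # per-line format that get_predictions() hands to the sequence tagger
        # as its test file.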