ki-ki13 committed on
Commit
d26f6fb
1 Parent(s): 8dadc22

first init

app.py ADDED
@@ -0,0 +1,18 @@
+ import gradio as gr
+ from model import predict_ner
+
+
+ iface = gr.Interface(
+     fn=predict_ner,
+     inputs="text",
+     outputs="html",
+     title="Named Entity Recognition for Electronic Medical Records using Bidirectional Long Short-Term Memory with ClinicalBERT",
+     description="This tool identifies and highlights entities in text. <br><span style='background-color:red;color:white;'>Red</span> is for problems, <br><span style='background-color:blue;color:white;'>Blue</span> is for tests, and <br><span style='background-color:green;color:white;'>Green</span> is for treatments.",
+     css="span { font-weight: bold; } .problem { background-color: red } .test { background-color: blue } .treatment { background-color: green }",
+     examples=[
+         ["The patient presented with symptoms of fever and cough. A chest X-ray was performed to assess the condition."],
+         ["After diagnosis, the physician prescribed antibiotics for the treatment of the infection."],
+         ["The patient, a 55-year-old male, presented to the emergency department with complaints of chest pain and shortness of breath. He has a history of hypertension and diabetes. On physical examination, the patient appeared diaphoretic, and his blood pressure was elevated at 160/90 mmHg. An electrocardiogram (ECG) was performed, which showed ST-segment elevation in the anterior leads, consistent with an acute myocardial infarction. The patient was immediately started on aspirin, nitroglycerin, and clopidogrel and was taken for emergent cardiac catheterization. Coronary angiography revealed a critical stenosis in the left anterior descending artery, which was successfully stented. The patient's chest pain improved, and he was admitted to the cardiac care unit for further monitoring. Laboratory tests showed elevated troponin levels, confirming the myocardial infarction. The patient was counseled on lifestyle modifications, including diet and exercise, and was prescribed medications for long-term management. He was discharged home in stable condition with instructions to follow up with his cardiologist in one week."]
+     ])
+
+ iface.launch()
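
For a quick smoke test outside the Gradio UI, the underlying `predict_ner` function can be called directly. A minimal sketch, assuming the `model/` weights added in this commit are available locally:

```python
from model import predict_ner

# predict_ner returns an HTML string in which detected entities are wrapped in
# <span class="problem">, <span class="test">, or <span class="treatment">.
html = predict_ner("The patient was given aspirin after an abnormal ECG.")
print(html)
```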
format_entity.py ADDED
@@ -0,0 +1,19 @@
+ def format_entities(input_text, predicted_labels):
+     formatted_text = ""
+     entity_open = False
+     for token, label in zip(input_text, predicted_labels):
+         if label.startswith("B-"):
+             if entity_open:
+                 formatted_text += "</span>"
+             formatted_text += f'<span class="{label[2:]}">{token} ({label}) '
+             entity_open = True
+         elif label.startswith("I-"):
+             formatted_text += token + " "
+         else:
+             if entity_open:
+                 formatted_text += "</span>"
+             formatted_text += token + " "
+             entity_open = False
+     if entity_open:
+         formatted_text += "</span>"
+     return formatted_text
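
A small worked example of what `format_entities` produces, using hand-written tokens and BIO labels rather than real model output:

```python
from format_entity import format_entities

tokens = ["fever", "and", "a", "chest", "x", "-", "ray"]
labels = ["B-problem", "O", "O", "B-test", "I-test", "I-test", "I-test"]

print(format_entities(tokens, labels))
# <span class="problem">fever (B-problem) </span>and a <span class="test">chest (B-test) x - ray </span>
```

Note that the opening BIO tag (e.g. `(B-problem)`) is echoed inside the span, and that the emitted CSS classes match the `.problem` / `.test` / `.treatment` rules declared in app.py.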
model.py ADDED
@@ -0,0 +1,43 @@
+ import torch
+ from format_entity import format_entities
+ from transformers import DistilBertForTokenClassification, DistilBertTokenizer
+
+ DRIVE_BASE_PATH = "model/"
+ model_path = f"{DRIVE_BASE_PATH}"
+
+ model = DistilBertForTokenClassification.from_pretrained(model_path)
+ tokenizer = DistilBertTokenizer.from_pretrained(model_path)
+
+ def predict_ner(input_text):
+     # Tokenize the input text
+     inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+
+     # Make predictions
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Process the NER results
+     labels = outputs.logits.argmax(dim=2)
+     predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]
+     # probabilities = torch.nn.functional.softmax(outputs.logits, dim=2)
+
+     # Exclude [SEP] and [CLS] tokens from tokenized_text and predicted_labels
+     tokenized_text = tokenizer.tokenize(tokenizer.decode(inputs["input_ids"][0]))
+     token_label_pairs = [
+         (token, label) for token, label in zip(tokenized_text, predicted_labels)
+         if token not in ["[SEP]", "[CLS]"]
+     ]
+
+     # Format the results as highlighted HTML, excluding [SEP] and [CLS]
+     formatted_results = format_entities(
+         [pair[0] for pair in token_label_pairs],
+         [pair[1] for pair in token_label_pairs]
+     )
+
+     # Get top-n probabilities and labels (currently disabled)
+     # top_n_probs, top_n_labels = get_top_n_probs(probabilities[0], 6, list(model.config.id2label.values()))
+
+     return formatted_results
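
One caveat with this pipeline: `DistilBertTokenizer` produces WordPiece subtokens (continuation pieces prefixed with `##`), so the highlighted HTML shows word pieces rather than whole words. A hypothetical helper, not part of this commit, that merges pieces back into words before formatting (keeping each word's first-piece label) could look like:

```python
def merge_wordpieces(tokens, labels):
    # Glue "##" continuation pieces back onto the preceding token and keep
    # the label of each word's first piece.
    words, word_labels = [], []
    for tok, lab in zip(tokens, labels):
        if tok.startswith("##") and words:
            words[-1] += tok[2:]
        else:
            words.append(tok)
            word_labels.append(lab)
    return words, word_labels
```

Applied to `token_label_pairs` before `format_entities`, this would yield whole-word highlighting.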
model/config.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "_name_or_path": "nlpie/clinical-distilbert",
+   "activation": "gelu",
+   "adapters": {
+     "adapters": {},
+     "config_map": {},
+     "fusion_config_map": {},
+     "fusions": {}
+   },
+   "architectures": [
+     "DistilBertForTokenClassification"
+   ],
+   "attention_dropout": 0.1,
+   "dim": 768,
+   "dropout": 0.1,
+   "hidden_dim": 3072,
+   "id2label": {
+     "0": "I-treatment",
+     "1": "O",
+     "2": "B-test",
+     "3": "I-problem",
+     "4": "B-treatment",
+     "5": "I-test",
+     "6": "B-problem"
+   },
+   "initializer_range": 0.02,
+   "label2id": {
+     "I-treatment": 0,
+     "O": 1,
+     "B-test": 2,
+     "I-problem": 3,
+     "B-treatment": 4,
+     "I-test": 5,
+     "B-problem": 6
+   },
+   "max_position_embeddings": 512,
+   "model_type": "distilbert",
+   "n_heads": 12,
+   "n_layers": 6,
+   "output_past": true,
+   "pad_token_id": 0,
+   "qa_dropout": 0.1,
+   "seq_classif_dropout": 0.2,
+   "sinusoidal_pos_embds": false,
+   "tie_weights_": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.27.0.dev0",
+   "vocab_size": 28996
+ }
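
The `id2label` map above follows the standard BIO scheme: `B-` opens an entity span, `I-` continues it, and `O` marks tokens outside any entity. A tiny illustration with made-up tags:

```python
tokens = ["complains", "of", "chest", "pain", ";", "ordered", "an", "ECG"]
tags = ["O", "O", "B-problem", "I-problem", "O", "O", "O", "B-test"]
# "chest pain" forms one problem span; "ECG" is a single-token test span.
```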
model/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:081f66e415a24009bcdebad613beebe627a1a7b07855549ea02494fe0ce23f74
+ size 260818805
model/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
+ {
+   "cls_token": "[CLS]",
+   "do_lower_case": false,
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "name_or_path": "distilbert-base-cased",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "special_tokens_map_file": null,
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "DistilBertTokenizer",
+   "unk_token": "[UNK]"
+ }
model/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:145be841037ec5f2d5313b58d1c40166d68fc7cf901a253a8827bcbb768cfae8
+ size 3503
model/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
sample_app.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import DistilBertForTokenClassification, DistilBertTokenizer
+ import torch
+
+ DRIVE_BASE_PATH = "model/"
+ model_path = f"{DRIVE_BASE_PATH}"
+
+ model = DistilBertForTokenClassification.from_pretrained(model_path)
+ tokenizer = DistilBertTokenizer.from_pretrained(model_path)
+
+ def predict_ner(input_text):
+     # Tokenize the input text
+     inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
+
+     # Make predictions
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Process the NER results
+     labels = outputs.logits.argmax(dim=2)
+     predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]
+     # predicted_labels = [label_mapping.get(model.config.id2label[label_id], "O") for label_id in labels[0].tolist()]
+     tokenized_text = tokenizer.tokenize(tokenizer.decode(inputs["input_ids"][0]))
+
+     # Pair tokens with their labels, excluding [SEP] and [CLS]
+     token_label_pairs = [(token, label) for token, label in zip(tokenized_text, predicted_labels) if token not in ["[SEP]", "[CLS]"]]
+
+     # Format the results vertically, one "Token: ..., Label: ..." line per token
+     formatted_results = []
+     for token, label in token_label_pairs:
+         formatted_results.append(f"Token: {token}, Label: {label}")
+
+     return {"text": input_text, "formatted_results": formatted_results}
+
+
+ input_text = """Also , due to worsening renal function , she was started on octreotide / midodrine / albumin for hepatorenal
+ syndrome ( Cr 3.3 at its worst ) which resolved prior to her discharge ."""
+ result = predict_ner(input_text)
+ print(result['text'])
+ # print(result['token_probabilities'])
+ for item in result['formatted_results']:
+     print(item)