Spaces:
Runtime error
Runtime error
ki-ki13
committed on
Commit
•
d26f6fb
1
Parent(s):
8dadc22
first init
Browse files- app.py +18 -0
- format_entity.py +19 -0
- model.py +43 -0
- model/config.json +49 -0
- model/pytorch_model.bin +3 -0
- model/special_tokens_map.json +7 -0
- model/tokenizer.json +0 -0
- model/tokenizer_config.json +14 -0
- model/training_args.bin +3 -0
- model/vocab.txt +0 -0
- sample_app.py +41 -0
app.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from model import predict_ner
|
3 |
+
|
4 |
+
|
5 |
+
iface = gr.Interface(
|
6 |
+
fn = predict_ner,
|
7 |
+
inputs="text",
|
8 |
+
outputs="html",
|
9 |
+
title="Named Entity Recognition for Electronic Medical Record using Bidirectional Long Short Term Memory with ClinicalBERT",
|
10 |
+
description="This tool identifies and highlights entities in text. <br><span style='background-color:red;color: white;'>Red</span> is for problems, <br><span style='background-color:blue;color:white;'>Blue</span> is for tests, and <br><span style='background-color:green;color: white;'>Green</span> is for treatments.",
|
11 |
+
css="span { font-weight: bold; } .problem { background-color: red } .test { background-color: blue } .treatment { background-color: green}",
|
12 |
+
examples=[
|
13 |
+
["The patient presented with symptoms of fever and cough. A chest X-ray was performed to assess the condition."],
|
14 |
+
["After diagnosis, the physician prescribed antibiotics for the treatment of the infection."],
|
15 |
+
["The patient, a 55-year-old male, presented to the emergency department with complaints of chest pain and shortness of breath. He has a history of hypertension and diabetes. On physical examination, the patient appeared diaphoretic, and his blood pressure was elevated at 160/90 mmHg. An electrocardiogram (ECG) was performed, which showed ST-segment elevation in the anterior leads, consistent with an acute myocardial infarction. The patient was immediately started on aspirin, nitroglycerin, and clopidogrel and was taken for emergent cardiac catheterization. Coronary angiography revealed a critical stenosis in the left anterior descending artery, which was successfully stented. The patient's chest pain improved, and he was admitted to the cardiac care unit for further monitoring. Laboratory tests showed elevated troponin levels, confirming the myocardial infarction. The patient was counseled on lifestyle modifications, including diet and exercise, and was prescribed medications for long-term management. He was discharged home in stable condition with instructions to follow up with his cardiologist in one week."]
|
16 |
+
])
|
17 |
+
|
18 |
+
iface.launch()
|
format_entity.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def format_entities(input_text, predicted_labels):
|
2 |
+
formatted_text = ""
|
3 |
+
entity_open = False
|
4 |
+
for token, label in zip(input_text, predicted_labels):
|
5 |
+
if label.startswith("B-"):
|
6 |
+
if entity_open:
|
7 |
+
formatted_text += "</span>"
|
8 |
+
formatted_text += f'<span class="{label[2:]}">{token} ({label}) '
|
9 |
+
entity_open = True
|
10 |
+
elif label.startswith("I-"):
|
11 |
+
formatted_text += token + " "
|
12 |
+
else:
|
13 |
+
if entity_open:
|
14 |
+
formatted_text += "</span>"
|
15 |
+
formatted_text += token + " "
|
16 |
+
entity_open = False
|
17 |
+
if entity_open:
|
18 |
+
formatted_text += "</span>"
|
19 |
+
return formatted_text
|
model.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from format_entity import format_entities
|
3 |
+
from transformers import DistilBertForTokenClassification, DistilBertTokenizer
|
4 |
+
|
5 |
+
DRIVE_BASE_PATH = "model/"
|
6 |
+
model_path = f"{DRIVE_BASE_PATH}"
|
7 |
+
|
8 |
+
model = DistilBertForTokenClassification.from_pretrained(model_path)
|
9 |
+
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
|
10 |
+
|
11 |
+
def predict_ner(input_text):
|
12 |
+
# Tokenize the input text
|
13 |
+
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
14 |
+
|
15 |
+
# Make predictions
|
16 |
+
with torch.no_grad():
|
17 |
+
outputs = model(**inputs)
|
18 |
+
|
19 |
+
# Process the NER results
|
20 |
+
labels = outputs.logits.argmax(dim=2)
|
21 |
+
predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]
|
22 |
+
# probabilities = torch.nn.functional.softmax(outputs.logits, dim=2)
|
23 |
+
|
24 |
+
# Exclude [SEP] and [CLS] tokens from tokenized_text and predicted_labels
|
25 |
+
tokenized_text = tokenizer.tokenize(tokenizer.decode(inputs["input_ids"][0]))
|
26 |
+
token_label_pairs = [
|
27 |
+
(token, label) for token, label in zip(tokenized_text, predicted_labels)
|
28 |
+
if token not in ["[SEP]", "[CLS]"]
|
29 |
+
]
|
30 |
+
|
31 |
+
# Format the results vertically, excluding [SEP] and [CLS]
|
32 |
+
formatted_results = format_entities(
|
33 |
+
[pair[0] for pair in token_label_pairs],
|
34 |
+
[pair[1] for pair in token_label_pairs]
|
35 |
+
)
|
36 |
+
|
37 |
+
# Get top 3 probabilities and labels
|
38 |
+
# top_n_probs, top_n_labels = get_top_n_probs(probabilities[0], 6, list(model.config.id2label.values()))
|
39 |
+
|
40 |
+
return formatted_results
|
41 |
+
|
42 |
+
|
43 |
+
|
model/config.json
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "nlpie/clinical-distilbert",
|
3 |
+
"activation": "gelu",
|
4 |
+
"adapters": {
|
5 |
+
"adapters": {},
|
6 |
+
"config_map": {},
|
7 |
+
"fusion_config_map": {},
|
8 |
+
"fusions": {}
|
9 |
+
},
|
10 |
+
"architectures": [
|
11 |
+
"DistilBertForTokenClassification"
|
12 |
+
],
|
13 |
+
"attention_dropout": 0.1,
|
14 |
+
"dim": 768,
|
15 |
+
"dropout": 0.1,
|
16 |
+
"hidden_dim": 3072,
|
17 |
+
"id2label": {
|
18 |
+
"0": "I-treatment",
|
19 |
+
"1": "O",
|
20 |
+
"2": "B-test",
|
21 |
+
"3": "I-problem",
|
22 |
+
"4": "B-treatment",
|
23 |
+
"5": "I-test",
|
24 |
+
"6": "B-problem"
|
25 |
+
},
|
26 |
+
"initializer_range": 0.02,
|
27 |
+
"label2id": {
|
28 |
+
"I-treatment": 0,
|
29 |
+
"O": 1,
|
30 |
+
"B-test": 2,
|
31 |
+
"I-problem": 3,
|
32 |
+
"B-treatment": 4,
|
33 |
+
"I-test": 5,
|
34 |
+
"B-problem": 6
|
35 |
+
},
|
36 |
+
"max_position_embeddings": 512,
|
37 |
+
"model_type": "distilbert",
|
38 |
+
"n_heads": 12,
|
39 |
+
"n_layers": 6,
|
40 |
+
"output_past": true,
|
41 |
+
"pad_token_id": 0,
|
42 |
+
"qa_dropout": 0.1,
|
43 |
+
"seq_classif_dropout": 0.2,
|
44 |
+
"sinusoidal_pos_embds": false,
|
45 |
+
"tie_weights_": true,
|
46 |
+
"torch_dtype": "float32",
|
47 |
+
"transformers_version": "4.27.0.dev0",
|
48 |
+
"vocab_size": 28996
|
49 |
+
}
|
model/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:081f66e415a24009bcdebad613beebe627a1a7b07855549ea02494fe0ce23f74
|
3 |
+
size 260818805
|
model/special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
model/tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
model/tokenizer_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"do_lower_case": false,
|
4 |
+
"mask_token": "[MASK]",
|
5 |
+
"model_max_length": 512,
|
6 |
+
"name_or_path": "distilbert-base-cased",
|
7 |
+
"pad_token": "[PAD]",
|
8 |
+
"sep_token": "[SEP]",
|
9 |
+
"special_tokens_map_file": null,
|
10 |
+
"strip_accents": null,
|
11 |
+
"tokenize_chinese_chars": true,
|
12 |
+
"tokenizer_class": "DistilBertTokenizer",
|
13 |
+
"unk_token": "[UNK]"
|
14 |
+
}
|
model/training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:145be841037ec5f2d5313b58d1c40166d68fc7cf901a253a8827bcbb768cfae8
|
3 |
+
size 3503
|
model/vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
sample_app.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import DistilBertForTokenClassification, DistilBertTokenizer
|
2 |
+
import torch
|
3 |
+
|
4 |
+
DRIVE_BASE_PATH = "model/"
|
5 |
+
model_path = f"{DRIVE_BASE_PATH}"
|
6 |
+
|
7 |
+
model = DistilBertForTokenClassification.from_pretrained(model_path)
|
8 |
+
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
|
9 |
+
|
10 |
+
def predict_ner(input_text):
|
11 |
+
# Tokenize the input text
|
12 |
+
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128)
|
13 |
+
|
14 |
+
# Make predictions
|
15 |
+
with torch.no_grad():
|
16 |
+
outputs = model(**inputs)
|
17 |
+
|
18 |
+
# Process the NER results
|
19 |
+
labels = outputs.logits.argmax(dim=2)
|
20 |
+
predicted_labels = [model.config.id2label[label_id] for label_id in labels[0].tolist()]
|
21 |
+
# predicted_labels = [label_mapping.get(model.config.id2label[label_id], "O") for label_id in labels[0].tolist()]
|
22 |
+
tokenized_text = tokenizer.tokenize(tokenizer.decode(inputs["input_ids"][0]))
|
23 |
+
|
24 |
+
# Pair tokens with their labels, excluding [SEP] and [CLS]
|
25 |
+
token_label_pairs = [(token, label) for token, label in zip(tokenized_text, predicted_labels) if token not in ["[SEP]", "[CLS]"]]
|
26 |
+
|
27 |
+
# Format the results vertically, excluding [SEP] and [CLS]
|
28 |
+
formatted_results = []
|
29 |
+
for token, label in token_label_pairs:
|
30 |
+
formatted_results.append(f"Token: {token}, Label: {label}")
|
31 |
+
|
32 |
+
return {"text": input_text, "formatted_results": formatted_results}
|
33 |
+
|
34 |
+
|
35 |
+
input_text = """Also , due to worsening renal function , she was started on octreotide / midodrine / albumin for hepatorenal
|
36 |
+
syndrome ( Cr 3.3 at its worst ) which resolved prior to her discharge ."""
|
37 |
+
result = predict_ner(input_text)
|
38 |
+
print(result['text'])
|
39 |
+
# print(result['token_probabilities'])
|
40 |
+
for item in result['formatted_results']:
|
41 |
+
print(item)
|