Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from transformers import AutoTokenizer, AutoModel
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Define the model class
|
8 |
+
class MedicalCodePredictor(torch.nn.Module):
    """Encoder plus two linear heads: one scoring ICD codes, one scoring CPT codes.

    Both heads read the encoder's ``last_hidden_state`` at sequence position 0,
    after dropout. The head widths come from the module-level ``ICD_CODES`` and
    ``CPT_CODES`` tables at construction time, so those tables must already be
    populated when an instance is created.
    """

    def __init__(self, bert_model):
        """Wrap ``bert_model`` and create the two classification heads.

        Args:
            bert_model: Encoder called as
                ``bert_model(input_ids=..., attention_mask=...)`` whose result
                exposes ``last_hidden_state``. When it carries
                ``config.hidden_size`` that value sizes the heads; otherwise
                the previous fixed width of 768 is used.
        """
        super().__init__()
        self.bert = bert_model
        self.dropout = torch.nn.Dropout(0.1)

        # Derive the feature width from the encoder instead of assuming 768,
        # so an encoder with a different hidden size does not fail in forward.
        encoder_config = getattr(bert_model, "config", None)
        hidden_size = getattr(encoder_config, "hidden_size", 768)

        self.icd_classifier = torch.nn.Linear(hidden_size, len(ICD_CODES))
        self.cpt_classifier = torch.nn.Linear(hidden_size, len(CPT_CODES))

    def forward(self, input_ids, attention_mask):
        """Return ``(icd_logits, cpt_logits)`` for a tokenized batch.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.

        Returns:
            A 2-tuple of raw (pre-softmax) scores, one tensor per head.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Use the hidden state of the first sequence position as the summary
        # vector for the whole input.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)

        icd_logits = self.icd_classifier(pooled_output)
        cpt_logits = self.cpt_classifier(pooled_output)

        return icd_logits, cpt_logits
|
25 |
+
|
26 |
+
# Load ICD codes from files
|
27 |
+
def load_icd_codes_from_files(directory_path="./codes/icd_txt_files/"):
    """Read ICD codes from every ``.txt`` file in ``directory_path``.

    Each line is expected to hold ``code<TAB>description``; any further
    tab-separated fields are ignored, and lines with fewer than two fields
    (including blank lines) are skipped. If a code appears more than once, the
    last description read wins.

    Args:
        directory_path: Directory to scan. Defaults to the location the
            application has always used.

    Returns:
        A dict mapping code to description. Files are read in sorted name
        order, so the dict's insertion order is the same on every run. The
        dict is empty when the directory is missing.
    """
    icd_codes = {}

    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return icd_codes

    # os.listdir() returns names in arbitrary order; sorting keeps the
    # code order (and anything derived from it) stable between runs.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue

        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split("\t")  # Adjust delimiter as needed
                if len(parts) < 2:
                    continue
                icd_codes[parts[0].strip()] = parts[1].strip()

    return icd_codes
|
45 |
+
|
46 |
+
# Load CPT codes from files
|
47 |
+
def load_cpt_codes_from_files(directory_path="./codes/cpt_txt_files/"):
    """Read CPT codes from every ``.txt`` file in ``directory_path``.

    Each line is expected to hold ``code<TAB>description``; any further
    tab-separated fields are ignored, and lines with fewer than two fields
    (including blank lines) are skipped. If a code appears more than once, the
    last description read wins.

    Args:
        directory_path: Directory to scan. Defaults to the location the
            application has always used.

    Returns:
        A dict mapping code to description. Files are read in sorted name
        order, so the dict's insertion order is the same on every run. The
        dict is empty when the directory is missing.
    """
    cpt_codes = {}

    if not os.path.isdir(directory_path):
        print(f"Directory {directory_path} does not exist!")
        return cpt_codes

    # os.listdir() returns names in arbitrary order; sorting keeps the
    # code order (and anything derived from it) stable between runs.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue

        file_path = os.path.join(directory_path, file_name)
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split("\t")  # Adjust delimiter as needed
                if len(parts) < 2:
                    continue
                cpt_codes[parts[0].strip()] = parts[1].strip()

    return cpt_codes
|
65 |
+
|
66 |
+
# Load ICD and CPT codes dynamically
|
67 |
+
# Both tables are built once, at import time. They must exist before any
# MedicalCodePredictor is constructed, because its heads are sized from
# len(ICD_CODES) and len(CPT_CODES).
ICD_CODES = load_icd_codes_from_files()
CPT_CODES = load_cpt_codes_from_files()
|
69 |
+
|
70 |
+
# Load models
|
71 |
+
@torch.no_grad()
def load_models():
    """Load the Bio_ClinicalBERT tokenizer and wrap its encoder for prediction.

    Returns:
        A ``(tokenizer, model)`` tuple, where ``model`` is a
        ``MedicalCodePredictor`` around the pretrained encoder, already
        switched to evaluation mode.
    """
    # Single source of truth so tokenizer and encoder cannot drift apart.
    model_name = "emilyalsentzer/Bio_ClinicalBERT"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModel.from_pretrained(model_name)
    model = MedicalCodePredictor(base_model)

    # NOTE(review): only the encoder weights come from the checkpoint. The two
    # classifier heads are created fresh by MedicalCodePredictor and nothing
    # here loads trained weights for them - confirm whether a fine-tuned
    # state dict is supposed to be loaded at this point.

    # Inference-only app: disable dropout as soon as the model exists rather
    # than relying on every caller to remember to do it.
    model.eval()
    return tokenizer, model
|
77 |
+
|
78 |
+
# Prediction function
|
79 |
+
def predict_codes(text):
    """Return a formatted list of the top ICD-10 and CPT codes for ``text``.

    Uses the module-level ``tokenizer``, ``model``, ``ICD_CODES`` and
    ``CPT_CODES``.

    Args:
        text: Free-text medical summary. ``None``, empty, or whitespace-only
            input yields a prompt message instead of a prediction.

    Returns:
        A multi-line string with up to three ICD-10 entries followed by up to
        three CPT entries, each with its softmax confidence.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize input
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Get predictions; no gradients are needed for inference, so skip
    # building the autograd graph.
    model.eval()
    with torch.no_grad():
        icd_logits, cpt_logits = model(inputs["input_ids"], inputs["attention_mask"])

    # Get probabilities
    icd_probs = F.softmax(icd_logits, dim=1)
    cpt_probs = F.softmax(cpt_logits, dim=1)

    # Format results
    pieces = ["Recommended ICD-10 Codes:\n"]
    pieces.extend(line + "\n" for line in _format_top_codes(icd_probs, ICD_CODES))
    pieces.append("\nRecommended CPT Codes:\n")
    pieces.extend(line + "\n" for line in _format_top_codes(cpt_probs, CPT_CODES))
    return "".join(pieces)


def _format_top_codes(probs, code_table, k=3):
    """Return numbered display lines for the ``k`` most probable codes.

    ``code_table`` is keyed by code string, while the model emits integer
    class indices, so the index is resolved by position in the table's
    insertion order. Looking the index up as a dict key never matches and
    labelled every prediction 'Unknown'.

    NOTE(review): this assumes class index i corresponds to the i-th entry of
    the table - confirm against however the heads were (or will be) trained.
    """
    entries = list(code_table.items())

    # torch.topk raises if k exceeds the number of classes.
    k = min(k, probs.shape[1])
    if k == 0:
        return []

    top = torch.topk(probs, k=k)
    lines = []
    for rank, (prob, idx) in enumerate(zip(top.values[0], top.indices[0]), start=1):
        position = idx.item()
        if position < len(entries):
            code, description = entries[position]
            label = f"{code} - {description}"
        else:
            label = "Unknown"
        lines.append(f"{rank}. {label} (Confidence: {prob.item():.2f})")
    return lines
|
112 |
+
|
113 |
+
# Load models globally
|
114 |
+
tokenizer, model = load_models()

# Create Gradio interface
iface = gr.Interface(
    fn=predict_codes,
    inputs=gr.Textbox(
        lines=5,
        placeholder="Enter medical summary here...",
        label="Medical Summary",
    ),
    outputs=gr.Textbox(
        label="Predicted Codes",
        lines=8,
    ),
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    examples=[
        ["Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension."],
        ["Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes."],
        ["Patient complains of chronic lower back pain, worse with movement. No radiation to legs."],
    ],
)

# Launch the interface only when run as a script, so importing this module
# (for tests or reuse of predict_codes) does not start a web server.
if __name__ == "__main__":
    # NOTE(review): share=True requests a public tunnel URL when run locally;
    # confirm that is intended for an app that accepts clinical text.
    iface.launch(share=True)
|