# Spaces:
# Sleeping
# Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
import torch.nn.functional as F | |
import os | |
# Load ICD and CPT codes from files
def load_codes_from_files(directory_path, code_type):
    """Build a ``{code: description}`` mapping from every ``.txt`` file in a directory.

    Each line is split on the first run of whitespace into a code and its
    description (e.g. ``I10 Essential (primary) hypertension``). Lines that do
    not yield exactly those two parts (blank lines, a bare code) are skipped.
    If a code appears more than once, the last occurrence wins.

    Files are read in sorted name order so the key order of the returned dict
    is the same on every machine. That matters because the prediction code
    maps model label indices to codes by position in ``list(codes.keys())``.

    Args:
        directory_path: Directory containing the ``.txt`` code files.
        code_type: Label for the code family (e.g. ``"ICD"`` or ``"CPT"``);
            not used by the loader itself.

    Returns:
        dict[str, str]: Code -> description. Empty when the path is not an
        existing directory or no parseable lines were found.
    """
    codes = {}
    if not os.path.isdir(directory_path):
        # Also covers a path that exists but is a regular file, which would
        # otherwise make os.listdir raise NotADirectoryError.
        print(f"Directory {directory_path} does not exist!")
        return codes

    # sorted(): os.listdir order is arbitrary and would make label order unstable.
    for file_name in sorted(os.listdir(directory_path)):
        if not file_name.endswith(".txt"):
            continue
        file_path = os.path.join(directory_path, file_name)
        if not os.path.isfile(file_path):
            continue  # e.g. a sub-directory that happens to end in ".txt"
        with open(file_path, "r", encoding="utf-8") as file:
            for line in file:
                parts = line.strip().split(maxsplit=1)
                if len(parts) == 2:
                    code, description = parts[0].strip(), parts[1].strip()
                    codes[code] = description
    return codes
# Build the code -> description lookup tables from the bundled text files.
ICD_CODES = load_codes_from_files("./codes/icd_txt_files/", "ICD")
CPT_CODES = load_codes_from_files("./codes/cpt_txt_files/", "CPT")

# Refuse to start unless both tables contain at least one entry.
if not (ICD_CODES and CPT_CODES):
    raise ValueError("No ICD or CPT codes were loaded. Please check your files and directory structure.")

# The classifier is sized to emit one logit per ICD code; predict_codes maps
# logit index i to the i-th key of ICD_CODES.
# NOTE(review): emilyalsentzer/Bio_ClinicalBERT appears to be a base encoder
# checkpoint. Unless it ships a classification head of exactly
# len(ICD_CODES) outputs, that head is freshly initialized here and the
# predictions carry no clinical meaning. Confirm, or point at a fine-tuned model.
_MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    _MODEL_NAME,
    num_labels=len(ICD_CODES),
)
def _format_code_section(header, code_table, scores, indices):
    """Render one ranked section of the prediction report.

    Args:
        header: Section title, including any leading/trailing newlines.
        code_table: ``{code: description}`` mapping; its key order defines
            which code a label index refers to.
        scores: Confidence values (floats), aligned with *indices*.
        indices: Label indices (ints), aligned with *scores*.

    Returns:
        str: *header* followed by one
        ``"N. code: description (Confidence: p)"`` line per index that falls
        inside *code_table*. Out-of-range indices are skipped rather than
        raising ``IndexError``.
    """
    codes = list(code_table)  # positional lookup, built once per section
    lines = []
    for score, idx in zip(scores, indices):
        if not 0 <= idx < len(codes):
            continue
        code = codes[idx]
        lines.append(f"{len(lines) + 1}. {code}: {code_table[code]} (Confidence: {score:.2f})\n")
    return header + "".join(lines)


# Prediction function
def predict_codes(text):
    """Suggest ICD-10 and CPT codes for a free-text medical summary.

    Args:
        text: Summary entered in the UI. ``None``, empty or whitespace-only
            input gets a prompt message instead of a prediction.

    Returns:
        str: A plain-text report with up to three ICD-10 entries followed by
        up to three CPT entries, or the prompt message.
    """
    if not text or not text.strip():
        return "Please enter a medical summary."

    # Tokenize; summaries longer than 512 tokens are truncated.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    model.eval()
    with torch.no_grad():
        logits = model(**inputs).logits

    probs = F.softmax(logits, dim=1)
    # torch.topk raises if k exceeds the number of model outputs.
    top_k = min(3, probs.size(-1))
    top = torch.topk(probs, k=top_k)
    # A single input string means row 0 of the batch holds the result.
    scores = top.values[0].tolist()
    indices = top.indices[0].tolist()

    # NOTE(review): the model is sized to ICD_CODES only. The CPT section
    # reuses the ICD label indices and confidences, so it is not an
    # independent CPT prediction. Indices with no matching CPT code are
    # skipped rather than raising IndexError.
    return (
        _format_code_section("Recommended ICD-10 Codes:\n", ICD_CODES, scores, indices)
        + _format_code_section("\nRecommended CPT Codes:\n", CPT_CODES, scores, indices)
    )
# Sample summaries offered as one-click examples in the UI.
_EXAMPLE_SUMMARIES = [
    "Patient presents with blood pressure 150/90. Complains of occasional headaches. History of hypertension.",
    "Patient has elevated blood sugar levels. A1C is 7.8. History of type 2 diabetes.",
    "Patient complains of chronic lower back pain, worse with movement. No radiation to legs.",
]

# Gradio UI: one multi-line text input wired through predict_codes to one text output.
_summary_box = gr.Textbox(
    label="Medical Summary",
    placeholder="Enter medical summary here...",
    lines=5,
)
_result_box = gr.Textbox(lines=10, label="Predicted Codes")

iface = gr.Interface(
    fn=predict_codes,
    inputs=_summary_box,
    outputs=_result_box,
    title="AutoRCM - Medical Code Predictor",
    description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
    # One inner list per example, one value per input component.
    examples=[[summary] for summary in _EXAMPLE_SUMMARIES],
)

# Start the web app.
# NOTE(review): share=True asks Gradio for a public share link. Gradio
# documents this as unsupported (ignored) when running on Hugging Face
# Spaces; confirm the intended deployment target.
iface.launch(share=True)