|
|
|
|
|
|
|
import json
import os
import re

import datasets
import gradio as gr
import huggingface_hub
import nltk
import pdfplumber
import torch
from accelerate import Accelerator
from nltk.tokenize import sent_tokenize
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoProcessor,
    Llama4ForConditionalGeneration,
    Trainer,
    TrainingArguments,
)
|
|
|
def _ensure_nltk_resource(resource_path, package):
    """Download the NLTK *package* unless *resource_path* is already installed.

    ``nltk.data.find`` raises ``LookupError`` when the resource is missing,
    which is the only condition that triggers the download.
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)


# Sentence tokenization needs the Punkt data; fetch it once at start-up.
_ensure_nltk_resource('tokenizers/punkt', 'punkt')
|
|
|
|
|
from document_analyzer import HealthcareFraudAnalyzer |
|
|
|
|
|
# Log only the *names* of the environment variables. Dumping the values would
# write every secret (including the Hugging Face token read below) into the
# application logs.
print("Environment variables:", sorted(os.environ))

# The Hugging Face access token is supplied through the "LLama" secret.
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found. Set it in Hugging Face Space secrets as 'LLama'.")

# Confirm the token was found without echoing any part of it.
print("Retrieved LLama token from environment.")

# Authenticate this process so the gated model weights can be downloaded.
huggingface_hub.login(token=LLama)
|
|
|
|
|
# Checkpoint used for both the processor and the model weights.
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct"

processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# LoRA adapter settings: rank-16 adapters on the four attention projections.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Load the base model with 8-bit quantization requested, prepare it for
# k-bit training, then wrap it with the LoRA adapters defined above.
model = get_peft_model(
    prepare_model_for_kbit_training(
        Llama4ForConditionalGeneration.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config={"load_in_8bit": True},
            attn_implementation="flex_attention",
        )
    ),
    peft_config,
)

model.print_trainable_parameters()
|
|
|
|
|
# Keyword rules as (compiled pattern, question template, canned answer).
# The patterns are compiled once at import time. They deliberately have no
# leading/trailing ".*?" wrapper: finditer() already scans forward for the
# first keyword, so the captured groups are the same, while a wrapper makes
# every failed scan of a line quadratic. "." does not match a newline, so both
# keywords of a rule must appear on the same line.
_TRAINING_PAIR_RULES = [
    (
        re.compile(
            r"\b(haloperidol|lorazepam|ativan)\b.*?\b(daily|routine|regular)\b",
            re.IGNORECASE,
        ),
        "Patient receives {} on a {} basis. Is this appropriate medication management?",
        "This may indicate inappropriate medication management. Regular use of psychotropic medications without documented need assessment, behavior monitoring, and attempted dose reductions may violate care standards.",
    ),
    (
        re.compile(
            r"\b(missing|omitted|absent|lacking)\b.*?\b(documentation|records|logs|notes)\b",
            re.IGNORECASE,
        ),
        "Facility has {} {} for patient care. Is this a documentation concern?",
        "Yes, incomplete documentation is a significant red flag. Missing records may indicate attempts to conceal care issues or fraudulent billing for services not provided.",
    ),
    (
        re.compile(
            r"\b(restrict|limit|prevent|block)\b.*?\b(visits|visitation|access|family)\b",
            re.IGNORECASE,
        ),
        "Facility {} family {} without documented medical necessity. Is this suspicious?",
        "Yes, unjustified visitation restrictions may indicate attempts to conceal care issues and prevent family oversight. This can constitute fraud when facilities bill for care while violating resident rights.",
    ),
    (
        re.compile(
            r"\b(hospice|terminal|end.of.life)\b.*?\b(not|without|lacking)\b.*?\b(evidence|decline|documentation)\b",
            re.IGNORECASE,
        ),
        "Patient placed on {} care {} supporting {}. Is this fraudulent?",
        "Yes, hospice enrollment without documented terminal decline may indicate Medicare fraud. Hospice certification requires genuine clinical determination of terminal status with prognosis of six months or less.",
    ),
    (
        re.compile(
            r"\b(different|contradicts|conflicts|inconsistent)\b.*?\b(records|documentation|testimony|statements)\b",
            re.IGNORECASE,
        ),
        "Records show {} {} about patient condition. Is this fraudulent documentation?",
        "Yes, contradictory documentation is a strong indicator of fraudulent record-keeping designed to misrepresent care quality or patient condition, particularly when official records differ from internal communications.",
    ),
]

# Generic pairs used only when no rule above matched: (trigger substrings, pair).
_FALLBACK_TRAINING_PAIRS = [
    (
        ("medication", "prescribed", "administered"),
        {
            "input": "Medication records show inconsistencies in administration times. Is this concerning?",
            "output": "Yes, inconsistent medication administration timing may indicate fraudulent documentation or medication mismanagement that could harm patients.",
        },
    ),
    (
        ("visit", "family", "spouse"),
        {
            "input": "Staff documents family visits inconsistently. Is this suspicious?",
            "output": "Yes, selective documentation of family visits indicates fraudulent record-keeping designed to create a false narrative about family involvement and patient responses.",
        },
    ),
    (
        ("hospice", "terminal", "prognosis"),
        {
            "input": "Patient remained on hospice for extended period without documented decline. Is this Medicare fraud?",
            "output": "Yes, maintaining hospice services without documented decline suggests fraudulent hospice certification to obtain Medicare benefits inappropriately.",
        },
    ),
]


def extract_training_pairs_from_text(text):
    """Build question/answer training pairs from free text.

    Each keyword rule is applied to *text*; every match yields one pair whose
    question is the rule's template filled with the matched keywords (as they
    appear in the text) and whose answer is the rule's canned response. If no
    rule matches at all, generic pairs are returned for each topic whose
    trigger words occur anywhere in the lower-cased text.

    Args:
        text: Document text to scan. ``None`` or an empty string is allowed.

    Returns:
        A list of ``{"input": ..., "output": ...}`` dicts, ordered by rule and
        then by position in the text. Empty when nothing applies.
    """
    if not text:
        return []

    pairs = []
    for pattern, input_template, output_text in _TRAINING_PAIR_RULES:
        for match in pattern.finditer(text):
            pairs.append({
                "input": input_template.format(*match.groups()),
                "output": output_text,
            })

    if pairs:
        return pairs

    # Lower-case once; every fallback topic is checked against the same string.
    lowered = text.lower()
    return [
        dict(pair)  # copy so callers never share the module-level template
        for keywords, pair in _FALLBACK_TRAINING_PAIRS
        if any(keyword in lowered for keyword in keywords)
    ]
|
|
|
|
|
def _collect_training_sources(files):
    """Read the uploaded files and return ``(pdf_text, json_pairs)``.

    *files* may be ``None``, a single upload, or a list of uploads; each upload
    is either a path string or an object exposing the path as ``.name``.
    PDF pages are joined with newlines so words at page boundaries stay
    separate. JSON files may hold a list of pairs, a single pair object, or an
    object with a ``"training_pairs"`` list. Files with other extensions are
    ignored.

    Raises:
        ValueError: if a JSON entry is not an object with ``input``/``output``.
    """
    if files is None:
        files = []
    elif not isinstance(files, (list, tuple)):
        files = [files]

    page_texts = []
    json_pairs = []
    for file in files:
        path = getattr(file, "name", file)
        extension = os.path.splitext(str(path))[1].lower()
        if extension == ".pdf":
            with pdfplumber.open(path) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        page_texts.append(page_text)
        elif extension == ".json":
            with open(path, "r", encoding="utf-8") as f:
                raw_data = json.load(f)
            if isinstance(raw_data, dict):
                raw_data = raw_data.get("training_pairs", raw_data)
            candidates = raw_data if isinstance(raw_data, list) else [raw_data]
            for index, pair in enumerate(candidates):
                if not (isinstance(pair, dict) and "input" in pair and "output" in pair):
                    raise ValueError(
                        f"Training pair #{index} in {os.path.basename(str(path))} "
                        "must be an object with 'input' and 'output' keys"
                    )
                json_pairs.append({"input": pair["input"], "output": pair["output"]})

    return "\n".join(page_texts), json_pairs


def _tokenize_training_pair(example):
    """Turn one ``{"input", "output"}`` row into fixed-length model tensors."""
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": example["input"]}],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["output"]}],
        },
    ]
    formatted_text = processor.apply_chat_template(
        messages, add_generation_prompt=False, tokenize=False
    )
    # Pass the prompt by keyword: a multimodal processor's first positional
    # parameter is not guaranteed to be the text.
    inputs = processor(
        text=formatted_text,
        padding="max_length",
        truncation=True,
        max_length=4096,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].squeeze(0)
    attention_mask = inputs["attention_mask"].squeeze(0)
    labels = input_ids.clone()
    # -100 is the ignore index of the loss, so padding does not contribute.
    labels[attention_mask == 0] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


def _stack_training_batch(features):
    """Collate per-example tensors into one batch of stacked tensors."""
    return {
        key: torch.stack([feature[key] for feature in features])
        for key in ("input_ids", "attention_mask", "labels")
    }


def train_ui(files):
    """Fine-tune the global model on pairs taken from the uploaded files.

    Pairs from every JSON upload are combined with pairs extracted from the
    text of every PDF upload. The adapters and processor are saved to
    ``./fine_tuned_llama4_healthcare``.

    Args:
        files: Uploads from the Gradio file component (PDF and/or JSON).

    Returns:
        A status string for the UI. Failures are reported as an ``"Error: ..."``
        string rather than raised, because this is the UI boundary.
    """
    try:
        raw_text, training_pairs = _collect_training_sources(files)

        if not raw_text and not training_pairs:
            return "Error: No valid PDF or JSON data found."

        if raw_text:
            training_pairs.extend(extract_training_pairs_from_text(raw_text))

        if not training_pairs:
            return "Error: No training pairs could be extracted from the uploaded documents."

        # Build the dataset in memory so each pair becomes a row with
        # "input"/"output" columns and no shared temp file is written.
        dataset = datasets.Dataset.from_list(training_pairs)

        # The tokenizer works on one example at a time, so do not batch the map.
        tokenized_dataset = dataset.map(
            _tokenize_training_pair, remove_columns=dataset.column_names
        )
        # Rows come back as plain lists unless a torch format is set; the
        # collator below stacks tensors.
        tokenized_dataset.set_format(type="torch")

        # NOTE(review): warmup_steps=100 can exceed the total number of
        # optimizer steps on small uploads; confirm this schedule is intended.
        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama4_healthcare",
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=_stack_training_batch,
        )

        trainer.train()
        model.save_pretrained("./fine_tuned_llama4_healthcare")
        processor.save_pretrained("./fine_tuned_llama4_healthcare")
        return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"

    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."
|
|
|
|
|
def analyze_document_ui(files):
    """Analyze one uploaded PDF and return the analyzer's summary text.

    Args:
        files: The upload from the Gradio file component. A single-file
            component passes one file rather than a list, so both shapes are
            accepted; for a list only the first file is analyzed. Each file
            is either a path string or an object exposing the path as
            ``.name``.

    Returns:
        The ``"summary"`` entry of the analyzer result, or an ``"Error..."``
        string. Failures are reported as text rather than raised, because
        this is the UI boundary.
    """
    try:
        if not files:
            return "Error: No file uploaded. Please upload a PDF to analyze."

        file = files[0] if isinstance(files, (list, tuple)) else files
        pdf_path = getattr(file, "name", file)

        # Compare case-insensitively so "REPORT.PDF" is accepted too.
        if not str(pdf_path).lower().endswith(".pdf"):
            return "Error: Please upload a PDF file for analysis."

        with pdfplumber.open(pdf_path) as pdf:
            page_texts = [page.extract_text() or "" for page in pdf.pages]

        # Join pages with a newline so the last word of one page is not glued
        # to the first word of the next.
        raw_text = "\n".join(page_text for page_text in page_texts if page_text)

        if not raw_text.strip():
            return "Error: Could not extract text from the PDF. The file may be corrupt or contain only images."

        analyzer = HealthcareFraudAnalyzer(model, processor)
        results = analyzer.analyze_document(raw_text)
        return results["summary"]

    except Exception as e:
        return f"Error during document analysis: {str(e)}"
|
|
|
|
|
# Footer text shown below both tabs.
_ABOUT_MARKDOWN = """
    ### About This Tool
    This tool uses Llama 4 Maverick to identify patterns of potential fraud, neglect, and abuse in healthcare documentation.
    The fine-tuning tab allows model customization with your examples or automatic extraction from documents.
    The analysis tab scans documents for suspicious patterns, generating detailed reports.
    **Note:** All analysis is performed locally - no data is shared externally.
    """

# Two-tab interface: one tab fine-tunes the model, the other analyzes a PDF.
with gr.Blocks(title="Healthcare Fraud Detection Suite") as demo:
    gr.Markdown("# Healthcare Fraud Detection Suite")

    with gr.Tabs():
        # Tab 1: upload PDFs/JSON and run train_ui on them.
        with gr.TabItem("Fine-Tune Model"):
            gr.Markdown("## Train Llama 4 for Healthcare Fraud Detection")
            gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
            train_file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
            train_button = gr.Button("Start Fine-Tuning")
            train_output = gr.Textbox(label="Training Status", lines=5)
            train_button.click(train_ui, train_file_input, train_output)

        # Tab 2: upload one PDF and render the result of analyze_document_ui.
        with gr.TabItem("Analyze Document"):
            gr.Markdown("## Analyze Document for Healthcare Fraud Indicators")
            gr.Markdown("Upload a PDF document to analyze for potential fraud, neglect, or abuse indicators.")
            analyze_file_input = gr.File(label="Upload PDF Document")
            analyze_button = gr.Button("Analyze Document")
            analyze_output = gr.Markdown(label="Analysis Results")
            analyze_button.click(analyze_document_ui, analyze_file_input, analyze_output)

    gr.Markdown(_ABOUT_MARKDOWN)

# Start the web server for the interface.
demo.launch()