NDIJayant/gemma-4-E4B-comorbities

A Gemma 4 E4B-it model fine-tuned to extract patient comorbidities from unstructured clinical notes and return them as JSON in a strict schema. Each comorbidity is a structured entry with the fields comorbidity_name, status, comorbidity_date, evidence, and reason, and the response also carries an issues audit array.

Intended use

  • Input: an unstructured clinical note (ROS, PMH, Assessment, Imaging, Problem List, narrative). Up to ~15k tokens.
  • Output: JSON with an eligibility_status flag, a list of comorbidities (each with its onset or diagnosis date normalized to ISO 8601), and an issues audit array.
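For type-checked downstream code, the output shape (the schema embedded in the system prompt in the Usage section) can be written as Python `TypedDict`s. This is a sketch under assumptions: the class names are hypothetical, treating `status` as a four-value `Literal` assumes the `"Pre-existing | Active | Historical | Unknown"` placeholder lists the allowed values, the element type of `issues` is not specified on the card (plain strings are assumed), and the example values are invented for illustration.

```python
from typing import List, Literal, TypedDict

# Status values as listed in the schema placeholder of the system prompt (assumed exhaustive).
Status = Literal["Pre-existing", "Active", "Historical", "Unknown"]


class Comorbidity(TypedDict):
    comorbidity_name: str
    status: Status
    comorbidity_date: str  # ISO 8601 date, or "" when the note gives none
    evidence: str          # presumably the supporting text from the note
    reason: str            # the model's justification for the entry


class Reason(TypedDict):
    comorbidities: List[Comorbidity]


class Extraction(TypedDict):
    eligibility_status: bool
    reason: Reason
    issues: List[str]  # assumption: audit notes are plain strings


# Hypothetical example instance (made-up values).
example: Extraction = {
    "eligibility_status": True,
    "reason": {
        "comorbidities": [
            {
                "comorbidity_name": "Type 2 diabetes mellitus",
                "status": "Active",
                "comorbidity_date": "2015-03-01",
                "evidence": "PMH: T2DM diagnosed 03/2015, on metformin",
                "reason": "Listed in PMH with ongoing treatment",
            }
        ]
    },
    "issues": [],
}
```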

Not a medical device. For research and internal tooling only — every extraction must be reviewed by a qualified clinician before any clinical decision.

Training

Two-stage supervised fine-tuning, full parameter (no LoRA).

Stage 1 — base SFT:

  • 115,459 train / 12,829 eval records
  • Labels: Gemini 2.5 Flash generations using a detailed clinical-extraction system prompt
  • 8 × RTX PRO 6000 Blackwell (95 GB), bf16, DeepSpeed ZeRO-2
  • 1 epoch, LR 5e-6 cosine w/ 3% warmup, effective batch size 128
  • Custom chunked lm_head + cross-entropy, so the full [B, S, 262K] logit tensor is never materialized at once (avoids OOM)
  • Output: checkpoint-810
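The chunked lm_head + cross-entropy idea can be illustrated with a minimal NumPy sketch of the forward pass. Instead of projecting all B × S hidden states onto the 262K-entry vocabulary at once, it handles `chunk_size` tokens at a time and accumulates the summed loss, so peak logit memory is [chunk_size, V] rather than [B·S, V]. The card does not show its implementation: the function name, the chunk size, and the `ignore_index=-100` convention (borrowed from PyTorch) are assumptions, and a real PyTorch training version would also need the backward pass to avoid keeping every chunk's logits alive, for example by recomputing them per chunk.

```python
import numpy as np


def chunked_cross_entropy(hidden: np.ndarray, lm_head: np.ndarray, labels: np.ndarray,
                          chunk_size: int = 1024, ignore_index: int = -100) -> float:
    """Mean next-token cross-entropy, computing logits one chunk of tokens at a time.

    hidden:  [N, D] final hidden states, batch and sequence flattened (N = B * S)
    lm_head: [V, D] output projection weights (V = vocabulary size)
    labels:  [N] target ids, already shifted to align with `hidden`; ignore_index is skipped
    """
    loss_sum, n_tokens = 0.0, 0
    for start in range(0, hidden.shape[0], chunk_size):
        y = labels[start:start + chunk_size]
        keep = y != ignore_index
        if not keep.any():
            continue
        h = hidden[start:start + chunk_size][keep]        # [C, D]
        logits = h @ lm_head.T                            # [C, V], never [N, V]
        # Numerically stable log-sum-exp over the vocabulary axis.
        row_max = logits.max(axis=1, keepdims=True)
        log_z = row_max[:, 0] + np.log(np.exp(logits - row_max).sum(axis=1))
        target_logit = logits[np.arange(len(h)), y[keep]]
        loss_sum += float((log_z - target_logit).sum())   # sum of -log softmax(target)
        n_tokens += int(keep.sum())
    return loss_sum / max(n_tokens, 1)
```

Because the loss is a sum over tokens divided by the token count, chunking changes only peak memory: the result matches the unchunked computation up to floating-point rounding.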

Stage 2 — teacher-distillation alignment on refined data:

  • Selected the top 40K records by output length from the stage-1 train set
  • Passed each (clinical note, Gemini draft) pair through Gemma 4 31B-it, served locally via vLLM with the same clinical system prompt, which refined the draft through additions, removals, and date/status corrections
  • Kept the 33,315 refined records whose output was valid, parseable JSON
  • 1 additional epoch from the stage-1 checkpoint, same optimizer config
  • Output: checkpoint-260 (this release)
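The two filtering steps around the refinement pass (top 40K by output length, then keep only parseable JSON) might look like the following sketch. Assumptions: records are dicts with hypothetical `note` and `output` keys, length is measured in characters (the card does not say whether characters or tokens were used), and the vLLM refinement call to Gemma 4 31B-it between the two steps is omitted; `refine_with_teacher` below is a hypothetical placeholder.

```python
import json
from typing import Dict, List

Record = Dict[str, str]  # hypothetical shape: {"note": ..., "output": ...}


def select_longest(records: List[Record], k: int) -> List[Record]:
    """Return the k records with the longest output; ties keep their original order."""
    return sorted(records, key=lambda r: len(r["output"]), reverse=True)[:k]


def keep_parseable(records: List[Record]) -> List[Record]:
    """Drop records whose output is not valid JSON."""
    kept = []
    for r in records:
        try:
            json.loads(r["output"])
        except json.JSONDecodeError:
            continue
        kept.append(r)
    return kept


# Pipeline shape (the refinement step itself is not shown):
# candidates = select_longest(stage1_train, 40_000)
# refined = refine_with_teacher(candidates)   # hypothetical vLLM call to Gemma 4 31B-it
# stage2_train = keep_parseable(refined)      # 33,315 records on the card's data
```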

Converged: eval loss was essentially unchanged between step 130 and step 260 (0.1323 vs 0.1324), which suggests that additional epochs on this data are unlikely to help.
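Both stages use the same schedule: peak LR 5e-6, linear warmup over the first 3% of steps, then cosine decay. A minimal sketch of that schedule follows, assuming the common shape of Hugging Face's `get_cosine_schedule_with_warmup` (linear ramp from 0, half-cosine decay to 0) and rounding the warmup step count up. The card names neither the scheduler implementation nor a final LR, and `cosine_lr_with_warmup` is a hypothetical name. For stage 2, 33,315 records at an effective batch of 128 give 260 optimizer steps, assuming the final partial batch is dropped, which matches checkpoint-260.

```python
import math


def cosine_lr_with_warmup(step: int, total_steps: int, peak_lr: float = 5e-6,
                          warmup_ratio: float = 0.03) -> float:
    """LR at optimizer step `step`: linear warmup to peak_lr, then cosine decay to 0."""
    warmup_steps = max(1, math.ceil(warmup_ratio * total_steps))  # rounding up is an assumption
    if step < warmup_steps:
        return peak_lr * step / warmup_steps                        # linear ramp from 0
    # Fraction of the post-warmup phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))   # half-cosine from 1 to 0


# Stage 2: 33,315 records // effective batch 128 = 260 steps (final partial batch dropped).
total_steps = 33_315 // 128
for step in (0, 8, 130, 259):
    print(step, f"{cosine_lr_with_warmup(step, total_steps):.2e}")
```

Under these assumptions the LR at step 130, the midpoint checkpoint compared above, is about 2.6e-6, roughly half the peak.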

Evaluation

GPT-5.2 served as an independent single-model judge. There is no head-to-head comparison: each model's output is scored on its own merits against the clinical note, not against a Gemini-generated ground truth, which would bias the results in Gemini's favor. The 50 samples were drawn from the eval split with a fixed seed.

| Dimension | gemma-enrich (this model, ckpt-260) | gemma-base (stage 1, ckpt-810) | gemini-2.5-flash | gemma-4-E4B-it (zero-shot) |
|---|---|---|---|---|
| completeness | 8.29 | 7.79 | 7.82 | 6.90 |
| accuracy | 7.68 | 7.63 | 7.40 | 6.68 |
| no_hallucination | 9.10 | 9.26 | 8.92 | 8.11 |
| json_format | 9.18 | 9.16 | 5.79 | 8.92 |
| overall | 7.95 | 7.63 | 7.11 | 6.84 |

Reading the numbers:

  • Two-stage SFT lifted overall 6.84 → 7.63 → 7.95 (+1.11 from base Gemma, +0.32 from the stage-2 refinement alignment pass).
  • Beats the stage-1 teacher (Gemini 2.5 Flash) by +0.84 overall as judged by an independent LLM: on this 50-sample evaluation, the student outperforms the model that produced its stage-1 labels.
  • completeness climbs by +1.39 from zero-shot to the final model, the single biggest axis of improvement (and the one the stage-2 refinement targeted).
  • Gemini's main weakness is json_format (5.79): it frequently wraps its output in markdown code fences. Both fine-tuned Gemma checkpoints score above 9.1.
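The deltas quoted in these bullets follow from the table; the short script below recomputes them so they are easy to re-check. The dictionary keys (`zero_shot`, `stage1`, `stage2`, `gemini`) are shorthand introduced here for the four columns, with `stage2` standing for this model.

```python
# Judge scores copied from the evaluation table; "stage2" is this release (ckpt-260).
scores = {
    "completeness":     {"zero_shot": 6.90, "stage1": 7.79, "stage2": 8.29, "gemini": 7.82},
    "accuracy":         {"zero_shot": 6.68, "stage1": 7.63, "stage2": 7.68, "gemini": 7.40},
    "no_hallucination": {"zero_shot": 8.11, "stage1": 9.26, "stage2": 9.10, "gemini": 8.92},
    "json_format":      {"zero_shot": 8.92, "stage1": 9.16, "stage2": 9.18, "gemini": 5.79},
    "overall":          {"zero_shot": 6.84, "stage1": 7.63, "stage2": 7.95, "gemini": 7.11},
}


def delta(dim: str, frm: str, to: str) -> float:
    """Score change for one dimension between two columns, rounded to 2 decimals."""
    return round(scores[dim][to] - scores[dim][frm], 2)


for dim in scores:
    print(f"{dim:17s} zero-shot->final {delta(dim, 'zero_shot', 'stage2'):+.2f}  "
          f"stage1->stage2 {delta(dim, 'stage1', 'stage2'):+.2f}  "
          f"vs Gemini {delta(dim, 'gemini', 'stage2'):+.2f}")

# Largest zero-shot -> final gain among the four per-dimension scores.
best = max((d for d in scores if d != "overall"), key=lambda d: delta(d, "zero_shot", "stage2"))
print("largest gain:", best)
```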

Usage

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

MODEL_ID = "NDIJayant/gemma-4-E4B-comorbities"

tok = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto"
)

# System prompt: the extraction instruction followed by the target JSON schema.
SYSTEM = """Extract comorbidities from the clinical note. Return only valid JSON, no markdown.

{
  "eligibility_status": true,
  "reason": {
    "comorbidities": [
      {
        "comorbidity_name": "",
        "status": "Pre-existing | Active | Historical | Unknown",
        "comorbidity_date": "",
        "evidence": "",
        "reason": ""
      }
    ]
  },
  "issues": []
}"""

note = "..."  # the unstructured clinical note to process

# Build the prompt with the model's turn markers. The string already starts with <bos>,
# so the tokenizer is told not to add special tokens again.
prompt = f"<bos><|turn>system\n{SYSTEM}<turn|>\n<|turn>user\n{note}<turn|>\n<|turn>model\n"
ids = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)

# Greedy decoding, for deterministic extraction.
out = model.generate(**ids, max_new_tokens=2048, do_sample=False)

# Decode only the newly generated tokens, skipping the echoed prompt.
print(tok.decode(out[0][ids["input_ids"].shape[1]:], skip_special_tokens=True))
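The snippet above only prints the raw generation. A caller will usually want to parse it and check it against the schema before use. A minimal sketch follows; `parse_extraction` is a hypothetical helper, not shipped with this repository. It tolerates a markdown fence around the JSON (the failure mode noted for Gemini in the evaluation), and it assumes `comorbidity_date` is either empty or an ISO 8601 date at year, month, or day precision, which the card does not spell out.

```python
import json
import re

ALLOWED_STATUS = {"Pre-existing", "Active", "Historical", "Unknown"}
REQUIRED_FIELDS = {"comorbidity_name", "status", "comorbidity_date", "evidence", "reason"}
# Optional ```json ... ``` wrapper around the entire response.
FENCE_RE = re.compile(r"^```(?:json)?\s*(.*?)\s*```$", re.DOTALL)
# ISO 8601 date at year, month, or day precision (format check only, not calendar validity).
ISO_DATE_RE = re.compile(r"\d{4}(?:-\d{2}(?:-\d{2})?)?")


def parse_extraction(raw: str) -> dict:
    """Parse one model response and check it against the card's output schema.

    Raises ValueError (json.JSONDecodeError is a subclass) when the output is unusable.
    """
    text = raw.strip()
    fenced = FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)  # tolerate a markdown fence despite the "no markdown" instruction
    data = json.loads(text)
    if not isinstance(data, dict) or not isinstance(data.get("eligibility_status"), bool):
        raise ValueError("expected an object with a boolean eligibility_status")
    if not isinstance(data.get("issues"), list):
        raise ValueError("issues must be a list")
    reason = data.get("reason")
    if not isinstance(reason, dict) or not isinstance(reason.get("comorbidities"), list):
        raise ValueError("reason.comorbidities must be a list")
    for item in reason["comorbidities"]:
        if not isinstance(item, dict):
            raise ValueError("each comorbidity must be an object")
        missing = REQUIRED_FIELDS.difference(item)
        if missing:
            raise ValueError(f"comorbidity missing fields: {sorted(missing)}")
        if item["status"] not in ALLOWED_STATUS:
            raise ValueError(f"unexpected status: {item['status']!r}")
        date = item["comorbidity_date"]
        if date and not ISO_DATE_RE.fullmatch(str(date)):
            raise ValueError(f"comorbidity_date is not ISO 8601: {date!r}")
    return data


# Usage with the generation above:
# result = parse_extraction(tok.decode(out[0][ids["input_ids"].shape[1]:], skip_special_tokens=True))
```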

Limitations

  • English-language clinical notes only; not validated on other languages.
  • Performance degrades on notes > ~15k tokens (training cap).
  • Date extraction is most reliable when the date is stated explicitly in the note; for conditions mentioned only in narrative text, the model may leave comorbidity_date empty.
  • Inherits teacher biases from Gemini 2.5 Flash and Gemma 4 31B (e.g. judgment calls on status classification and RA/seronegative handling).
  • No PHI or safety guarantees: outputs are not screened for patient identifiers.

Citation

Please cite the base model: google/gemma-4-E4B-it.

Safetensors · Model size: 9B params · Tensor type: BF16