MedGemma 1.5 - Multiple Myeloma Progression Tracking (Module 4)

📈 Model Overview

This repository contains a PEFT/LoRA adapter fine-tuned on MedGemma 1.5 4B-IT. It is designed to analyze raw clinical text, extract longitudinal M-Spike metrics (measured by serum protein electrophoresis, SPEP), and assess whether the data indicate rapid disease progression in Multiple Myeloma patients.
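To make "longitudinal M-Spike metrics" concrete, here is a small deterministic sketch. It pulls the M-protein value out of each timeline line, using the "- Day N: ... M Protein: X g/dL ..." format from the usage example. It then computes the percent change of the latest value relative to the lowest earlier value. This is a baseline one might use to sanity-check the adapter's answer, not the adapter's own method. `extract_m_protein`, `change_from_nadir`, and the regular expressions are hypothetical, and they assume the exact line format shown in this card.

```python
from __future__ import annotations

import re

# Hypothetical patterns; they assume lines like "- Day 78: Calcium: 2.25 mmol/L, M Protein: 1.97 g/dL, ..."
LINE_RE = re.compile(r"^-\s*Day\s+(-?\d+):\s*(.*)$")
M_PROTEIN_RE = re.compile(r"M Protein:\s*([\d.]+)\s*g/dL")


def extract_m_protein(timeline: str) -> list[tuple[int, float]]:
    """Return (day, M-protein in g/dL) pairs sorted by day; lines without an M Protein value are skipped."""
    series = []
    for raw in timeline.splitlines():
        line_match = LINE_RE.match(raw.strip())
        if not line_match:
            continue
        day, measurements = int(line_match.group(1)), line_match.group(2)
        m_match = M_PROTEIN_RE.search(measurements)
        if m_match:
            series.append((day, float(m_match.group(1))))
    return sorted(series)


def change_from_nadir(series: list[tuple[int, float]]) -> float | None:
    """Percent change of the latest M-protein value relative to the lowest earlier value (None if undefined)."""
    if len(series) < 2:
        return None
    nadir = min(value for _, value in series[:-1])
    if nadir == 0:
        return None
    latest = series[-1][1]
    return 100.0 * (latest - nadir) / nadir


example = (
    "- Day -17: Platelets: 329.0 x10^9 cells/L, Hemoglobin: 6.7 mmol/L, Creatinine: 68.07 umol/L, "
    "M Protein: 2.98 g/dL, Calcium: 2.45 mmol/L\n"
    "- Day 78: Calcium: 2.25 mmol/L, M Protein: 1.97 g/dL, Hemoglobin: 7.01 mmol/L, "
    "Platelets: 286.0 x10^9 cells/L, Creatinine: 53.92 umol/L"
)
print(extract_m_protein(example))  # → [(-17, 2.98), (78, 1.97)]
```

A falling value, as in this example, is the opposite of progression. Any decision threshold applied to the percent change is left to the reader and is not taken from this card.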

This adapter was developed as part of a broader agentic AI application catering to Multiple Myeloma patients. It acts as Module 4, operating alongside upstream risk assessment and vision modules to feed structured progression data into a RAG-enabled clinical dashboard.
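One way a module like this might hand structured progression data to the downstream dashboard is a small serializable record. The sketch is illustrative only. `ProgressionRecord` and its field names are hypothetical and are not the schema used by the actual pipeline in the GitHub repository.

```python
from __future__ import annotations

import json
from dataclasses import asdict, dataclass, field


@dataclass
class ProgressionRecord:
    """Hypothetical Module 4 output record; field names are illustrative, not the pipeline's real schema."""
    patient_id: str
    m_protein_series: list[tuple[int, float]] = field(default_factory=list)  # (day, g/dL) pairs
    model_assessment: str = ""  # free-text answer returned by the adapter

    def to_json(self) -> str:
        # json.dumps writes the (day, value) tuples as JSON arrays
        return json.dumps(asdict(self))


record = ProgressionRecord("P001", [(-17, 2.98), (78, 1.97)], "(model output goes here)")
print(record.to_json())
```

A plain JSON payload keeps the adapter decoupled from whatever retrieval or dashboard layer consumes it.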

πŸ”— Associated Code Repository

The complete source code for data preparation, training, and validation of this adapter, as well as the full agentic AI pipeline, is available on GitHub: here

Base Model Dependency

This is an adapter model. It requires the base weights from Google's MedGemma 1.5 4B-IT.

⚠️ License and Terms of Use

  • LoRA Adapter Weights: The adapter weights and associated code in this repository are open-sourced under the Apache 2.0 license.
  • Base Model: To use this adapter, you must agree to the Google Health AI Developer Foundations Terms of Use to access the underlying MedGemma 1.5 weights.
  • Clinical Disclaimer: This model is for educational and research purposes only. It is not a medical device, is not intended for clinical use, and should not be used to diagnose, treat, or offer medical advice for any disease or condition.

πŸ’» How to Use

Because this model uses the MedGemma vision-language architecture, loading it in 4-bit NF4 quantization is strongly recommended. A dummy image tensor is also recommended to stabilize the vision components during text-only generation. The example below covers NF4 loading and a text-only prompt.

from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel
from PIL import Image
import torch

# 1. Load Base Model in 4-bit NF4
model_id = "google/medgemma-1.5-4b-it"
processor = AutoProcessor.from_pretrained(model_id)

quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

base_model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=quant_config
)

# 2. Load this LoRA Adapter
model = PeftModel.from_pretrained(base_model, "shrishSVaidya/medgemma-1.5-mm-progression-module4")

# 3. Format Prompt
patient_id = ""  # fill in the patient ID
# Fill in the longitudinal lab history, one line per visit, for example:
# - Day -17: Platelets: 329.0 x10^9 cells/L, Hemoglobin: 6.7 mmol/L, Creatinine: 68.07 umol/L, M Protein: 2.98 g/dL, Calcium: 2.45 mmol/L
# - Day 78: Calcium: 2.25 mmol/L, M Protein: 1.97 g/dL, Hemoglobin: 7.01 mmol/L, Platelets: 286.0 x10^9 cells/L, Creatinine: 53.92 umol/L
timeline = ""
prompt = (
    f"Review the following longitudinal biomarker history for patient {patient_id}. "
    "Predict the disease trajectory: Is this patient showing biochemical progression toward Active Myeloma?\n"
    f"Timeline:{timeline}"
)
messages = [{"role": "user", "content": prompt}]
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = processor(
    text=formatted_prompt, 
    return_tensors="pt", 
    padding=True
).to(model.device)
inputs.pop("token_type_ids", None)

# 4. Generate
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=300, do_sample=False)

input_length = inputs["input_ids"].shape[1]
generated_tokens = outputs[0, input_length:]
pred_text = processor.decode(generated_tokens, skip_special_tokens=True).strip()

print("Model Prediction:\n", pred_text)
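
If the lab history is stored in structured form rather than as text, a small helper can render it into the "- Day N: Analyte: value unit, ..." lines used in step 3 and rebuild the same prompt. This is a sketch under stated assumptions. `format_timeline`, `build_prompt`, and the (day, {analyte: (value, unit)}) input shape are hypothetical. The whitespace around `Timeline:` is copied from the snippet rather than confirmed against the training data in the GitHub repository.

```python
from __future__ import annotations


def format_timeline(visits: list[tuple[int, dict[str, tuple[float, str]]]]) -> str:
    """Render (day, {analyte: (value, unit)}) visits as one '- Day N: ...' line per visit, ordered by day."""
    lines = []
    for day, labs in sorted(visits, key=lambda visit: visit[0]):
        measurements = ", ".join(f"{name}: {value} {unit}" for name, (value, unit) in labs.items())
        lines.append(f"- Day {day}: {measurements}")
    return "\n".join(lines)


def build_prompt(patient_id: str, timeline: str) -> str:
    """Rebuild the prompt text from step 3 of the usage example (hypothetical helper)."""
    return (
        f"Review the following longitudinal biomarker history for patient {patient_id}. "
        "Predict the disease trajectory: Is this patient showing biochemical progression toward Active Myeloma?\n"
        f"Timeline:{timeline}"
    )


example_visits = [
    (78, {"M Protein": (1.97, "g/dL"), "Calcium": (2.25, "mmol/L")}),
    (-17, {"M Protein": (2.98, "g/dL"), "Calcium": (2.45, "mmol/L")}),
]
print(build_prompt("P001", format_timeline(example_visits)))
```

Sorting by day inside `format_timeline` keeps the timeline chronological even when visits arrive out of order.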
Model repository: shrishSVaidya/medgemma-1.5-mm-progression-module4