toroe commited on
Commit
c661e63
·
verified ·
1 Parent(s): a421812

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +323 -0
README.md ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language:
3
+ - en
4
+ license: other
5
+ pipeline_tag: text-generation
6
+ library_name: transformers
7
+ tags:
8
+ - clinical-nlp
9
+ - medical-coding
10
+ - icd10
11
+ - icd-10-cm
12
+ - reasoning
13
+ - reinforcement-learning
14
+ - grpo
15
+ - healthcare
16
+ base_model:
17
+ - Qwen/Qwen2.5-7B-Instruct
18
+ ---
19
+
20
+ # DeepICD-R1-7B
21
+
22
+ ## Model Summary
23
+
24
+ **DeepICD-R1-7B** is a clinical reasoning language model for **ICD-10-CM diagnosis outcome prediction from admission notes**.
25
+ It is derived from **Qwen2.5-7B-Instruct** and trained using the **DeepICD-R1 framework**, which combines structured reasoning traces with reinforcement learning and hierarchical reward signals.
26
+
27
+ The model is designed to predict a **single ICD-10-CM diagnosis code** from clinical text while producing an interpretable reasoning trace explaining the decision.
28
+
29
+ The training methodology follows the approach described in the paper:
30
+
31
+ **DeepICD-R1: Medical Reasoning through Hierarchical Rewards and Unsupervised Distillation**
32
+
33
+ This work frames clinical diagnosis prediction as a **reasoning task optimized through reinforcement learning**.
34
+
35
+ ---
36
+
37
+ # Model Details
38
+
39
+ - **Model name:** DeepICD-R1-7B
40
+ - **Organization:** DATEXIS
41
+ - **Base model:** Qwen2.5-7B-Instruct
42
+ - **Parameters:** ~7B
43
+ - **Task:** Single ICD-10-CM diagnosis prediction from admission notes
44
+ - **Training paradigm:** Supervised reasoning + reinforcement learning
45
+ - **Framework:** VERL RL trainer
46
+ - **Domain:** Clinical NLP / healthcare reasoning
47
+
48
+ The Qwen2.5-7B-Instruct architecture is a **7-billion-parameter instruction-tuned language model designed for instruction following and long-form generation tasks**. :contentReference[oaicite:1]{index=1}
49
+
50
+ ---
51
+
52
+ # Intended Use
53
+
54
+ This model is intended for **research purposes**, including:
55
+
56
+ - clinical reasoning research
57
+ - ICD-10-CM coding prediction
58
+ - reinforcement learning for language models
59
+ - reasoning trace generation
60
+ - structured prediction from clinical text
61
+
62
+ ### Out-of-Scope Use
63
+
64
+ This model **must not be used for**:
65
+
66
+ - medical diagnosis
67
+ - clinical decision support
68
+ - patient triage
69
+ - automated medical coding without expert supervision
70
+ - billing or compliance workflows
71
+
72
+ ---
73
+
74
+ # Training Methodology
75
+
76
+ The **DeepICD-R1 framework** treats diagnosis prediction as a reasoning problem.
77
+
78
+ Training combines:
79
+
80
+ ### 1. Supervised reasoning traces
81
+ A dataset of reasoning chains explaining diagnosis predictions.
82
+
83
+ ### 2. Reinforcement learning optimization
84
+
85
+ Training uses **Group Relative Policy Optimization (GRPO)** to improve reasoning and prediction accuracy.
86
+
87
+ ### 3. Hierarchical reward signals
88
+
89
+ Rewards are aligned with the hierarchical structure of ICD codes.
90
+
91
+ The reward function combines:
92
+
93
+ - **format reward** — correct reasoning + diagnosis structure
94
+ - **outcome reward** — correct diagnosis prediction
95
+ - **hierarchical reward** — partial credit for correct ICD prefixes
96
+
97
+ This design encourages models to produce both **accurate diagnoses and structured reasoning**.
98
+
99
+ ---
100
+
101
+ # Training Data
102
+
103
+ The training task uses **clinical admission notes paired with ICD-10-CM diagnosis codes**, derived from de-identified electronic health record datasets such as **MIMIC-IV**.
104
+
105
+ Task formulation:
106
+
107
+ **Input**
108
+
109
+ Clinical admission note describing patient presentation.
110
+
111
+ **Output**
112
+
113
+ Structured reasoning trace and predicted ICD-10-CM code.
114
+
115
+ ---
116
+
117
+ # Output Format
118
+
119
+ The model is trained to produce structured outputs separating reasoning from the final diagnosis.
120
+
121
+ ### Example
122
+
123
+ ```text
124
+ <think>
125
+ The patient presents with ...
126
+ Symptoms and clinical history suggest ...
127
+ ...
128
+ </think>
129
+
130
+ <diagnosis>
131
+ M5116
132
+ </diagnosis>
133
+ ```
134
+ ## Training Configuration
135
+
136
+ The model was trained using the **VERL reinforcement learning trainer** with **Group Relative Policy Optimization (GRPO)**, following the DeepICD-R1 training framework.
137
+
138
+ ### Core Training Parameters
139
+
140
+ | Parameter | Value |
141
+ |-----------|------|
142
+ | Algorithm | GRPO |
143
+ | Training framework | VERL (`verl.trainer.main_ppo`) |
144
+ | Base model | Qwen2.5-7B-Instruct |
145
+ | Training batch size | 64 |
146
+ | PPO mini batch size | 64 |
147
+ | PPO micro batch size per GPU | 16 |
148
+ | Learning rate | 1e-6 |
149
+ | LR warmup steps | 80 |
150
+ | Total epochs | 1 |
151
+ | Max prompt length | 2048 tokens |
152
+ | Max response length | 1024 tokens |
153
+
154
+ ### Rollout / Generation Settings
155
+
156
+ | Parameter | Value |
157
+ |-----------|------|
158
+ | Rollout engine | vLLM |
159
+ | Samples per prompt (`n`) | 8 |
160
+ | Temperature | 0.9 |
161
+ | Top-k | disabled |
162
+ | dtype | bfloat16 |
163
+ | Tensor parallel size | 1 |
164
+ | GPU memory utilization | 0.4 |
165
+
166
+ ### Optimization Details
167
+
168
+ | Parameter | Value |
169
+ |-----------|------|
170
+ | Entropy coefficient | 0.001 |
171
+ | KL controller coefficient | 0.001 |
172
+ | KL loss | disabled |
173
+ | Gradient checkpointing | enabled |
174
+ | Torch compile | enabled |
175
+ | FSDP param offload | disabled |
176
+ | FSDP optimizer offload | disabled |
177
+
178
+ ### Hardware
179
+
180
+ | Component | Value |
181
+ |-----------|------|
182
+ | GPUs | 4 |
183
+ | Nodes | 1 |
184
+ | Precision | bfloat16 |
185
+
186
+ ### Reward Function
187
+
188
+ Training uses a **custom batched reward function** combining several reward signals:
189
+
190
+ - **Outcome reward** — correct ICD-10 prediction
191
+ - **Format reward** — correct `<think>` and `<diagnosis>` structure
192
+ - **Hierarchical reward** — partial credit for ICD prefix matches
193
+ - **Reasoning reward** — encourages meaningful reasoning traces
194
+ - **LLM-based reward** — optional external judge scoring
195
+
196
+ These rewards align the model toward producing **both accurate diagnoses and structured reasoning traces**.
197
+
198
+ The reasoning trace provides transparency into how the diagnosis was derived from the clinical note.
199
+
200
+ ---
201
+
202
+ ## Evaluation
203
+
204
+ Evaluation follows the methodology described in the **DeepICD-R1 paper**.
205
+
206
+ Performance is measured using **macro-averaged F1 scores** at multiple levels of the ICD hierarchy.
207
+
208
+ | Level | Description |
209
+ |------|-------------|
210
+ | Chapter | Broad ICD category |
211
+ | Category | First three digits |
212
+ | Full code | Complete ICD-10 code |
213
+
214
+ Hierarchical evaluation allows partial credit when the model predicts the correct high-level diagnostic category even if the full code is incorrect.
215
+
216
+ ---
217
+
218
+ ## Limitations
219
+
220
+ Models following the **DeepICD-R1 framework** share several limitations.
221
+
222
+ ### Dataset limitations
223
+
224
+ - Training data consists primarily of **English clinical notes**
225
+ - Distribution reflects **hospital-specific patient populations**
226
+ - ICD labels are **highly imbalanced**, affecting rare diagnoses
227
+
228
+ ### Model limitations
229
+
230
+ - Reasoning traces may appear convincing while being incorrect
231
+ - Predictions may fail for rare or long-tail diagnoses
232
+ - Models may demonstrate **premature diagnostic closure**
233
+ - Reinforcement learning rewards are only proxies for expert feedback
234
+
235
+ ---
236
+
237
+ ## Ethical Considerations
238
+
239
+ This model is trained on **de-identified clinical data** and intended strictly for research.
240
+
241
+ ### Potential risks
242
+
243
+ - propagation of dataset biases
244
+ - overconfidence in generated reasoning
245
+ - misuse in clinical decision making
246
+
247
+ ### Appropriate safeguards
248
+
249
+ - expert oversight
250
+ - dataset bias evaluation
251
+ - fairness audits
252
+ - controlled deployment environments
253
+
254
+ ---
255
+
256
+ ## Hardware and Training Setup
257
+
258
+ Typical training configuration for models in this family includes:
259
+
260
+ - **GPUs:** multi-GPU training (4–8 GPUs)
261
+ - **Precision:** bfloat16
262
+ - **Rollout engine:** vLLM
263
+ - **Training framework:** VERL PPO / GRPO trainer
264
+ - **Sampling:** multiple rollouts per prompt
265
+
266
+ ---
267
+
268
+ ## Usage
269
+
270
+ ### Transformers Example
271
+
272
+ ```python
273
+ from transformers import AutoTokenizer, AutoModelForCausalLM
274
+
275
+ model_id = "DATEXIS/DeepICD-R1-7B"
276
+
277
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
278
+ model = AutoModelForCausalLM.from_pretrained(
279
+ model_id,
280
+ device_map="auto",
281
+ torch_dtype="auto"
282
+ )
283
+
284
+ prompt = """
285
+ You are a clinical reasoning model.
286
+
287
+ Given the following admission note,
288
+ produce reasoning in <think> tags
289
+ and a final ICD-10 diagnosis in <diagnosis> tags.
290
+
291
+ [ADMISSION NOTE]
292
+ """
293
+
294
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
295
+
296
+ outputs = model.generate(
297
+ **inputs,
298
+ max_new_tokens=512
299
+ )
300
+
301
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
302
+ ```
303
+ ## Recommended Inference Practices
304
+
305
+ - Use prompts consistent with the training format.
306
+ - Validate predicted ICD-10 codes against official code formats.
307
+ - Always review predictions with medical experts.
308
+ - Avoid exposing reasoning traces in safety-critical settings without verification.
309
+
310
+ ---
311
+
312
+ ## Citation
313
+
314
+ If you use this model, please cite:
315
+
316
+ ```bibtex
317
+ @inproceedings{roehr2026deepicdr1,
318
+ title={DeepICD-R1: Medical Reasoning through Hierarchical Rewards and Unsupervised Distillation},
319
+ author={R{\"o}hr, Tom and Steffek, Thomas and Teucher, Roman and Bressem, Keno and others},
320
+ booktitle={Proceedings of LREC-COLING},
321
+ year={2026}
322
+ }
323
+