Add files using upload-large-folder tool
Browse files- BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/README.md +202 -0
- BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/adapter_config.json +39 -0
- BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/additional_config.json +1 -0
- BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/args.json +364 -0
- BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/val_dataset.jsonl +57 -0
- BIO/sft/qwen-production-08022302/v0-20250802-230250/images/train_train_steps_per_second.png +0 -0
- BioReason/.gitignore +180 -0
- BioReason/LICENSE +201 -0
- BioReason/README.md +148 -0
- BioReason/bioreason/__init__.py +0 -0
- BioReason/bioreason/dataset/__init__.py +11 -0
- BioReason/bioreason/dataset/kegg.py +382 -0
- BioReason/bioreason/dataset/utils.py +59 -0
- BioReason/bioreason/dataset/variant_effect.py +98 -0
- BioReason/bioreason/dna_modules/__init__.py +4 -0
- BioReason/bioreason/dna_modules/dna_module.py +49 -0
- BioReason/bioreason/dna_modules/esm_protein_module.py +649 -0
- BioReason/bioreason/dna_modules/nucleotide_module.py +263 -0
- BioReason/bioreason/dna_modules/protein_module.py +200 -0
- BioReason/bioreason/models/__init__.py +9 -0
- BioReason/bioreason/models/dl/__init__.py +1 -0
- BioReason/bioreason/models/dl/chat_template_dl.py +1 -0
- BioReason/bioreason/models/dl/configuration_dl.py +232 -0
- BioReason/bioreason/models/dl/processing_dl.py +275 -0
- BioReason/bioreason/models/dna_llm.py +306 -0
- BioReason/bioreason/models/dna_only.py +203 -0
- BioReason/bioreason/models/esm_tokenizer.py +330 -0
- BioReason/bioreason/models/evo2_tokenizer.py +219 -0
- BioReason/bioreason/models/pl/chat_template_pl.py +1 -0
- BioReason/bioreason/models/pl/configuration_pl.py +234 -0
- BioReason/bioreason/models/pl/processing_pl.py +276 -0
- BioReason/bioreason/models/protein_llm.py +375 -0
- BioReason/bioreason/models/protein_utils.py +233 -0
- BioReason/bioreason/trainer/__init__.py +7 -0
- BioReason/bioreason/trainer/demo_grpo.py +811 -0
- BioReason/bioreason/trainer/grpo_config.py +365 -0
- BioReason/bioreason/trainer/grpo_trainer.py +905 -0
- BioReason/bioreason/utils/__init__.py +0 -0
- BioReason/bioreason/utils/dna_utils.py +12 -0
- BioReason/data/BioReasoning_DataCuration_KEGG.ipynb +0 -0
- BioReason/data/Clinvar_Coding.ipynb +2481 -0
- BioReason/pyproject.toml +57 -0
- BioReason/reason.py +636 -0
- BioReason/reason_protein.py +696 -0
- BioReason/requirements.txt +13 -0
- BioReason/sh_reason.sh +57 -0
- BioReason/sh_train_dna_only.sh +138 -0
- BioReason/sh_train_dna_qwen.sh +191 -0
- BioReason/train_dna_only.py +502 -0
- BioReason/train_dna_qwen.py +1064 -0
BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: /oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.15.2
|
BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/adapter_config.json
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 32,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.05,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": [],
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"r": 8,
|
| 24 |
+
"rank_pattern": {},
|
| 25 |
+
"revision": null,
|
| 26 |
+
"target_modules": [
|
| 27 |
+
"up_proj",
|
| 28 |
+
"o_proj",
|
| 29 |
+
"gate_proj",
|
| 30 |
+
"down_proj",
|
| 31 |
+
"v_proj",
|
| 32 |
+
"k_proj",
|
| 33 |
+
"q_proj"
|
| 34 |
+
],
|
| 35 |
+
"task_type": "CAUSAL_LM",
|
| 36 |
+
"trainable_token_indices": null,
|
| 37 |
+
"use_dora": false,
|
| 38 |
+
"use_rslora": false
|
| 39 |
+
}
|
BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/additional_config.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"lora_dtype": null, "lorap_lr_ratio": null, "lorap_emb_lr": 1e-06}
|
BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/checkpoint-250/args.json
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model": "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
|
| 3 |
+
"model_type": "qwen2_5",
|
| 4 |
+
"model_revision": null,
|
| 5 |
+
"task_type": "causal_lm",
|
| 6 |
+
"torch_dtype": "bfloat16",
|
| 7 |
+
"attn_impl": null,
|
| 8 |
+
"num_labels": null,
|
| 9 |
+
"problem_type": null,
|
| 10 |
+
"rope_scaling": null,
|
| 11 |
+
"device_map": null,
|
| 12 |
+
"max_memory": {},
|
| 13 |
+
"local_repo_path": null,
|
| 14 |
+
"template": "qwen2_5",
|
| 15 |
+
"system": null,
|
| 16 |
+
"max_length": 8192,
|
| 17 |
+
"truncation_strategy": "delete",
|
| 18 |
+
"max_pixels": null,
|
| 19 |
+
"tools_prompt": "react_en",
|
| 20 |
+
"norm_bbox": null,
|
| 21 |
+
"response_prefix": null,
|
| 22 |
+
"padding_side": "right",
|
| 23 |
+
"loss_scale": "default",
|
| 24 |
+
"sequence_parallel_size": 1,
|
| 25 |
+
"use_chat_template": true,
|
| 26 |
+
"template_backend": "swift",
|
| 27 |
+
"dataset": [
|
| 28 |
+
"/oss/wangyujia/ProtT3/ProtT3/data/sft/dataset/metal_ion_binding/train.jsonl"
|
| 29 |
+
],
|
| 30 |
+
"val_dataset": [],
|
| 31 |
+
"split_dataset_ratio": 0.01,
|
| 32 |
+
"data_seed": 42,
|
| 33 |
+
"dataset_num_proc": 128,
|
| 34 |
+
"dataset_shuffle": true,
|
| 35 |
+
"val_dataset_shuffle": false,
|
| 36 |
+
"streaming": false,
|
| 37 |
+
"interleave_prob": null,
|
| 38 |
+
"stopping_strategy": "first_exhausted",
|
| 39 |
+
"shuffle_buffer_size": 1000,
|
| 40 |
+
"enable_cache": false,
|
| 41 |
+
"download_mode": "reuse_dataset_if_exists",
|
| 42 |
+
"columns": {},
|
| 43 |
+
"strict": false,
|
| 44 |
+
"remove_unused_columns": true,
|
| 45 |
+
"model_name": [
|
| 46 |
+
"qwen_bio_sft_deeplocbinary-08022035"
|
| 47 |
+
],
|
| 48 |
+
"model_author": [
|
| 49 |
+
"swift"
|
| 50 |
+
],
|
| 51 |
+
"custom_dataset_info": [],
|
| 52 |
+
"quant_method": null,
|
| 53 |
+
"quant_bits": null,
|
| 54 |
+
"hqq_axis": null,
|
| 55 |
+
"bnb_4bit_compute_dtype": "bfloat16",
|
| 56 |
+
"bnb_4bit_quant_type": "nf4",
|
| 57 |
+
"bnb_4bit_use_double_quant": true,
|
| 58 |
+
"bnb_4bit_quant_storage": null,
|
| 59 |
+
"max_new_tokens": 64,
|
| 60 |
+
"temperature": 0.0,
|
| 61 |
+
"top_k": null,
|
| 62 |
+
"top_p": null,
|
| 63 |
+
"repetition_penalty": null,
|
| 64 |
+
"num_beams": 1,
|
| 65 |
+
"stream": false,
|
| 66 |
+
"stop_words": [],
|
| 67 |
+
"logprobs": false,
|
| 68 |
+
"top_logprobs": null,
|
| 69 |
+
"ckpt_dir": "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
|
| 70 |
+
"load_dataset_config": null,
|
| 71 |
+
"lora_modules": [],
|
| 72 |
+
"tuner_backend": "peft",
|
| 73 |
+
"train_type": "lora",
|
| 74 |
+
"adapters": [],
|
| 75 |
+
"external_plugins": [],
|
| 76 |
+
"seed": 42,
|
| 77 |
+
"model_kwargs": {},
|
| 78 |
+
"load_args": false,
|
| 79 |
+
"load_data_args": false,
|
| 80 |
+
"use_hf": false,
|
| 81 |
+
"hub_token": null,
|
| 82 |
+
"custom_register_path": [],
|
| 83 |
+
"ignore_args_error": false,
|
| 84 |
+
"use_swift_lora": false,
|
| 85 |
+
"output_dir": "/nas/shared/kilab/wangyujia/BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822",
|
| 86 |
+
"overwrite_output_dir": false,
|
| 87 |
+
"do_train": false,
|
| 88 |
+
"do_eval": false,
|
| 89 |
+
"do_predict": false,
|
| 90 |
+
"eval_strategy": "steps",
|
| 91 |
+
"prediction_loss_only": false,
|
| 92 |
+
"per_device_train_batch_size": 2,
|
| 93 |
+
"per_device_eval_batch_size": 2,
|
| 94 |
+
"per_gpu_train_batch_size": null,
|
| 95 |
+
"per_gpu_eval_batch_size": null,
|
| 96 |
+
"gradient_accumulation_steps": 4,
|
| 97 |
+
"eval_accumulation_steps": null,
|
| 98 |
+
"eval_delay": 0,
|
| 99 |
+
"torch_empty_cache_steps": null,
|
| 100 |
+
"learning_rate": 1e-05,
|
| 101 |
+
"weight_decay": 0.1,
|
| 102 |
+
"adam_beta1": 0.9,
|
| 103 |
+
"adam_beta2": 0.95,
|
| 104 |
+
"adam_epsilon": 1e-08,
|
| 105 |
+
"max_grad_norm": 1.0,
|
| 106 |
+
"num_train_epochs": 3.0,
|
| 107 |
+
"max_steps": -1,
|
| 108 |
+
"lr_scheduler_type": "cosine",
|
| 109 |
+
"lr_scheduler_kwargs": null,
|
| 110 |
+
"warmup_ratio": 0.05,
|
| 111 |
+
"warmup_steps": 0,
|
| 112 |
+
"log_level": "passive",
|
| 113 |
+
"log_level_replica": "warning",
|
| 114 |
+
"log_on_each_node": true,
|
| 115 |
+
"logging_dir": "/nas/shared/kilab/wangyujia/BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/runs",
|
| 116 |
+
"logging_strategy": "steps",
|
| 117 |
+
"logging_first_step": true,
|
| 118 |
+
"logging_steps": 1,
|
| 119 |
+
"logging_nan_inf_filter": true,
|
| 120 |
+
"save_strategy": "steps",
|
| 121 |
+
"save_steps": 5.0,
|
| 122 |
+
"save_total_limit": 5,
|
| 123 |
+
"save_safetensors": true,
|
| 124 |
+
"save_on_each_node": false,
|
| 125 |
+
"save_only_model": true,
|
| 126 |
+
"restore_callback_states_from_checkpoint": false,
|
| 127 |
+
"no_cuda": false,
|
| 128 |
+
"use_cpu": false,
|
| 129 |
+
"use_mps_device": false,
|
| 130 |
+
"jit_mode_eval": false,
|
| 131 |
+
"use_ipex": false,
|
| 132 |
+
"bf16": true,
|
| 133 |
+
"fp16": false,
|
| 134 |
+
"fp16_opt_level": "O1",
|
| 135 |
+
"half_precision_backend": "auto",
|
| 136 |
+
"bf16_full_eval": false,
|
| 137 |
+
"fp16_full_eval": false,
|
| 138 |
+
"tf32": null,
|
| 139 |
+
"local_rank": 0,
|
| 140 |
+
"ddp_backend": null,
|
| 141 |
+
"tpu_num_cores": null,
|
| 142 |
+
"tpu_metrics_debug": false,
|
| 143 |
+
"debug": null,
|
| 144 |
+
"dataloader_drop_last": false,
|
| 145 |
+
"eval_steps": 5.0,
|
| 146 |
+
"dataloader_num_workers": 1,
|
| 147 |
+
"dataloader_prefetch_factor": null,
|
| 148 |
+
"past_index": -1,
|
| 149 |
+
"run_name": "construct",
|
| 150 |
+
"disable_tqdm": null,
|
| 151 |
+
"label_names": null,
|
| 152 |
+
"load_best_model_at_end": false,
|
| 153 |
+
"metric_for_best_model": "loss",
|
| 154 |
+
"greater_is_better": false,
|
| 155 |
+
"ignore_data_skip": false,
|
| 156 |
+
"fsdp": "",
|
| 157 |
+
"fsdp_min_num_params": 0,
|
| 158 |
+
"fsdp_config": null,
|
| 159 |
+
"tp_size": 0,
|
| 160 |
+
"fsdp_transformer_layer_cls_to_wrap": null,
|
| 161 |
+
"accelerator_config": {
|
| 162 |
+
"dispatch_batches": false
|
| 163 |
+
},
|
| 164 |
+
"deepspeed": {
|
| 165 |
+
"fp16": {
|
| 166 |
+
"enabled": "auto",
|
| 167 |
+
"loss_scale": 0,
|
| 168 |
+
"loss_scale_window": 1000,
|
| 169 |
+
"initial_scale_power": 16,
|
| 170 |
+
"hysteresis": 2,
|
| 171 |
+
"min_loss_scale": 1
|
| 172 |
+
},
|
| 173 |
+
"bf16": {
|
| 174 |
+
"enabled": "auto"
|
| 175 |
+
},
|
| 176 |
+
"zero_optimization": {
|
| 177 |
+
"stage": 3,
|
| 178 |
+
"offload_optimizer": {
|
| 179 |
+
"device": "none",
|
| 180 |
+
"pin_memory": true
|
| 181 |
+
},
|
| 182 |
+
"offload_param": {
|
| 183 |
+
"device": "none",
|
| 184 |
+
"pin_memory": true
|
| 185 |
+
},
|
| 186 |
+
"overlap_comm": false,
|
| 187 |
+
"contiguous_gradients": true,
|
| 188 |
+
"sub_group_size": 1000000000.0,
|
| 189 |
+
"reduce_bucket_size": "auto",
|
| 190 |
+
"zero_quantized_weights": false,
|
| 191 |
+
"zero_quantized_gradients": false,
|
| 192 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 193 |
+
"stage3_param_persistence_threshold": "auto",
|
| 194 |
+
"stage3_max_live_parameters": 1000000000.0,
|
| 195 |
+
"stage3_max_reuse_distance": 1000000000.0,
|
| 196 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 197 |
+
},
|
| 198 |
+
"gradient_accumulation_steps": "auto",
|
| 199 |
+
"gradient_clipping": "auto",
|
| 200 |
+
"steps_per_print": 2000,
|
| 201 |
+
"train_batch_size": "auto",
|
| 202 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 203 |
+
"wall_clock_breakdown": false
|
| 204 |
+
},
|
| 205 |
+
"label_smoothing_factor": 0.0,
|
| 206 |
+
"optim": "adamw_torch",
|
| 207 |
+
"optim_args": null,
|
| 208 |
+
"adafactor": false,
|
| 209 |
+
"group_by_length": false,
|
| 210 |
+
"length_column_name": "length",
|
| 211 |
+
"report_to": [
|
| 212 |
+
"tensorboard"
|
| 213 |
+
],
|
| 214 |
+
"ddp_find_unused_parameters": null,
|
| 215 |
+
"ddp_bucket_cap_mb": null,
|
| 216 |
+
"ddp_broadcast_buffers": null,
|
| 217 |
+
"dataloader_pin_memory": true,
|
| 218 |
+
"dataloader_persistent_workers": false,
|
| 219 |
+
"skip_memory_metrics": true,
|
| 220 |
+
"use_legacy_prediction_loop": false,
|
| 221 |
+
"push_to_hub": false,
|
| 222 |
+
"resume_from_checkpoint": null,
|
| 223 |
+
"hub_model_id": null,
|
| 224 |
+
"hub_strategy": "every_save",
|
| 225 |
+
"hub_private_repo": null,
|
| 226 |
+
"hub_always_push": false,
|
| 227 |
+
"gradient_checkpointing": true,
|
| 228 |
+
"gradient_checkpointing_kwargs": null,
|
| 229 |
+
"include_inputs_for_metrics": false,
|
| 230 |
+
"include_for_metrics": [],
|
| 231 |
+
"eval_do_concat_batches": true,
|
| 232 |
+
"fp16_backend": "auto",
|
| 233 |
+
"push_to_hub_model_id": null,
|
| 234 |
+
"push_to_hub_organization": null,
|
| 235 |
+
"push_to_hub_token": null,
|
| 236 |
+
"mp_parameters": "",
|
| 237 |
+
"auto_find_batch_size": false,
|
| 238 |
+
"full_determinism": false,
|
| 239 |
+
"torchdynamo": null,
|
| 240 |
+
"ray_scope": "last",
|
| 241 |
+
"ddp_timeout": 1800,
|
| 242 |
+
"torch_compile": false,
|
| 243 |
+
"torch_compile_backend": null,
|
| 244 |
+
"torch_compile_mode": null,
|
| 245 |
+
"include_tokens_per_second": false,
|
| 246 |
+
"include_num_input_tokens_seen": false,
|
| 247 |
+
"neftune_noise_alpha": null,
|
| 248 |
+
"optim_target_modules": null,
|
| 249 |
+
"batch_eval_metrics": false,
|
| 250 |
+
"eval_on_start": false,
|
| 251 |
+
"use_liger_kernel": false,
|
| 252 |
+
"eval_use_gather_object": false,
|
| 253 |
+
"average_tokens_across_devices": false,
|
| 254 |
+
"sortish_sampler": false,
|
| 255 |
+
"predict_with_generate": false,
|
| 256 |
+
"generation_max_length": null,
|
| 257 |
+
"generation_num_beams": null,
|
| 258 |
+
"generation_config": null,
|
| 259 |
+
"check_model": true,
|
| 260 |
+
"acc_strategy": "token",
|
| 261 |
+
"train_dataloader_shuffle": true,
|
| 262 |
+
"metric_warmup_step": 0,
|
| 263 |
+
"fsdp_num": 1,
|
| 264 |
+
"acc_steps": 1,
|
| 265 |
+
"eval_use_evalscope": false,
|
| 266 |
+
"eval_datasets": [],
|
| 267 |
+
"eval_limit": null,
|
| 268 |
+
"eval_datasets_args": null,
|
| 269 |
+
"eval_generation_config": null,
|
| 270 |
+
"freeze_parameters": [],
|
| 271 |
+
"freeze_parameters_ratio": 0.0,
|
| 272 |
+
"trainable_parameters": [],
|
| 273 |
+
"freeze_llm": false,
|
| 274 |
+
"freeze_vit": true,
|
| 275 |
+
"freeze_aligner": true,
|
| 276 |
+
"target_modules": [
|
| 277 |
+
"all-linear"
|
| 278 |
+
],
|
| 279 |
+
"target_regex": null,
|
| 280 |
+
"modules_to_save": [],
|
| 281 |
+
"lora_rank": 8,
|
| 282 |
+
"lora_alpha": 32,
|
| 283 |
+
"lora_dropout": 0.05,
|
| 284 |
+
"lora_bias": "none",
|
| 285 |
+
"lora_dtype": null,
|
| 286 |
+
"lorap_lr_ratio": null,
|
| 287 |
+
"use_rslora": false,
|
| 288 |
+
"use_dora": false,
|
| 289 |
+
"lora_ga_batch_size": 2,
|
| 290 |
+
"lora_ga_iters": 2,
|
| 291 |
+
"lora_ga_max_length": 1024,
|
| 292 |
+
"lora_ga_direction": "ArB2r",
|
| 293 |
+
"lora_ga_scale": "stable",
|
| 294 |
+
"lora_ga_stable_gamma": 16,
|
| 295 |
+
"init_weights": true,
|
| 296 |
+
"fourier_n_frequency": 2000,
|
| 297 |
+
"fourier_scaling": 300.0,
|
| 298 |
+
"boft_block_size": 4,
|
| 299 |
+
"boft_block_num": 0,
|
| 300 |
+
"boft_n_butterfly_factor": 1,
|
| 301 |
+
"boft_dropout": 0.0,
|
| 302 |
+
"vera_rank": 256,
|
| 303 |
+
"vera_projection_prng_key": 0,
|
| 304 |
+
"vera_dropout": 0.0,
|
| 305 |
+
"vera_d_initial": 0.1,
|
| 306 |
+
"adapter_act": "gelu",
|
| 307 |
+
"adapter_length": 128,
|
| 308 |
+
"use_galore": false,
|
| 309 |
+
"galore_target_modules": null,
|
| 310 |
+
"galore_rank": 128,
|
| 311 |
+
"galore_update_proj_gap": 50,
|
| 312 |
+
"galore_scale": 1.0,
|
| 313 |
+
"galore_proj_type": "std",
|
| 314 |
+
"galore_optim_per_parameter": false,
|
| 315 |
+
"galore_with_embedding": false,
|
| 316 |
+
"galore_quantization": false,
|
| 317 |
+
"galore_proj_quant": false,
|
| 318 |
+
"galore_proj_bits": 4,
|
| 319 |
+
"galore_proj_group_size": 256,
|
| 320 |
+
"galore_cos_threshold": 0.4,
|
| 321 |
+
"galore_gamma_proj": 2,
|
| 322 |
+
"galore_queue_size": 5,
|
| 323 |
+
"adalora_target_r": 8,
|
| 324 |
+
"adalora_init_r": 12,
|
| 325 |
+
"adalora_tinit": 0,
|
| 326 |
+
"adalora_tfinal": 0,
|
| 327 |
+
"adalora_deltaT": 1,
|
| 328 |
+
"adalora_beta1": 0.85,
|
| 329 |
+
"adalora_beta2": 0.85,
|
| 330 |
+
"adalora_orth_reg_weight": 0.5,
|
| 331 |
+
"llamapro_num_new_blocks": 4,
|
| 332 |
+
"llamapro_num_groups": null,
|
| 333 |
+
"lisa_activated_layers": 0,
|
| 334 |
+
"lisa_step_interval": 20,
|
| 335 |
+
"reft_layer_key": null,
|
| 336 |
+
"reft_layers": null,
|
| 337 |
+
"reft_rank": 4,
|
| 338 |
+
"reft_intervention_type": "LoreftIntervention",
|
| 339 |
+
"reft_args": null,
|
| 340 |
+
"swanlab_token": null,
|
| 341 |
+
"swanlab_project": null,
|
| 342 |
+
"swanlab_workspace": null,
|
| 343 |
+
"swanlab_exp_name": null,
|
| 344 |
+
"swanlab_mode": "cloud",
|
| 345 |
+
"add_version": true,
|
| 346 |
+
"resume_only_model": false,
|
| 347 |
+
"create_checkpoint_symlink": false,
|
| 348 |
+
"packing": false,
|
| 349 |
+
"lazy_tokenize": false,
|
| 350 |
+
"loss_type": null,
|
| 351 |
+
"optimizer": null,
|
| 352 |
+
"metric": null,
|
| 353 |
+
"zero_hpz_partition_size": null,
|
| 354 |
+
"rank": 0,
|
| 355 |
+
"global_world_size": 8,
|
| 356 |
+
"local_world_size": 8,
|
| 357 |
+
"model_suffix": "checkpoint-50-merged",
|
| 358 |
+
"model_info": "ModelInfo(model_type='qwen2_5', model_dir='/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged', torch_dtype=torch.bfloat16, max_model_len=32768, quant_method=None, quant_bits=None, rope_scaling=None, config=None, task_type='causal_lm', num_labels=None)",
|
| 359 |
+
"model_meta": "ModelMeta(model_type='qwen2_5', model_groups=[ModelGroup(models=[Model(ms_model_id='Qwen/Qwen2.5-0.5B-Instruct', hf_model_id='Qwen/Qwen2.5-0.5B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-1.5B-Instruct', hf_model_id='Qwen/Qwen2.5-1.5B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-3B-Instruct', hf_model_id='Qwen/Qwen2.5-3B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-7B-Instruct', hf_model_id='Qwen/Qwen2.5-7B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-14B-Instruct', hf_model_id='Qwen/Qwen2.5-14B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-32B-Instruct', hf_model_id='Qwen/Qwen2.5-32B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-72B-Instruct', hf_model_id='Qwen/Qwen2.5-72B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-0.5B', hf_model_id='Qwen/Qwen2.5-0.5B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-1.5B', hf_model_id='Qwen/Qwen2.5-1.5B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-3B', hf_model_id='Qwen/Qwen2.5-3B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-7B', hf_model_id='Qwen/Qwen2.5-7B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-14B', hf_model_id='Qwen/Qwen2.5-14B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-32B', hf_model_id='Qwen/Qwen2.5-32B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-72B', hf_model_id='Qwen/Qwen2.5-72B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-0.5B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-0.5B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-1.5B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-1.5B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-3B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-3B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-7B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-7B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-14B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-14B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-32B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-32B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-72B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-72B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None)], ignore_patterns=None, requires=None, tags=[]), ModelGroup(models=[Model(ms_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-3B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-3B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-7B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-7B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-14B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-14B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-32B-Instruct', hf_model_id='Qwen/Qwen2.5-Coder-32B-Instruct', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-0.5B', hf_model_id='Qwen/Qwen2.5-Coder-0.5B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-1.5B', hf_model_id='Qwen/Qwen2.5-Coder-1.5B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-3B', hf_model_id='Qwen/Qwen2.5-Coder-3B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-7B', hf_model_id='Qwen/Qwen2.5-Coder-7B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-14B', hf_model_id='Qwen/Qwen2.5-Coder-14B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-32B', hf_model_id='Qwen/Qwen2.5-Coder-32B', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', hf_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-AWQ', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-0.5B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-3B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-14B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', hf_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int4', model_path=None, ms_revision=None, hf_revision=None), Model(ms_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', hf_model_id='Qwen/Qwen2.5-Coder-32B-Instruct-GPTQ-Int8', model_path=None, ms_revision=None, hf_revision=None)], ignore_patterns=None, requires=None, tags=['coding'])], template='qwen2_5', get_function=<function get_model_tokenizer_with_flash_attn at 0x7fbf7497de10>, model_arch='llama', architectures=['Qwen2ForCausalLM'], additional_saved_files=[], torch_dtype=None, is_multimodal=False, is_reward=False, task_type=None, ignore_patterns=[], requires=['transformers>=4.37'], tags=[])",
|
| 360 |
+
"model_dir": "/oss/wangyujia/BIO/construction_finetuning/alpaca/v1-20250609-141541/checkpoint-50-merged",
|
| 361 |
+
"hub": "<class 'swift.hub.hub.MSHub'>",
|
| 362 |
+
"evaluation_strategy": "steps",
|
| 363 |
+
"training_args": "Seq2SeqTrainingArguments(output_dir='/nas/shared/kilab/wangyujia/BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822', overwrite_output_dir=False, do_train=False, do_eval=True, do_predict=False, eval_strategy=<IntervalStrategy.STEPS: 'steps'>, prediction_loss_only=False, per_device_train_batch_size=2, per_device_eval_batch_size=2, per_gpu_train_batch_size=None, per_gpu_eval_batch_size=None, gradient_accumulation_steps=4, eval_accumulation_steps=None, eval_delay=0, torch_empty_cache_steps=None, learning_rate=1e-05, weight_decay=0.1, adam_beta1=0.9, adam_beta2=0.95, adam_epsilon=1e-08, max_grad_norm=1.0, num_train_epochs=3.0, max_steps=-1, lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>, lr_scheduler_kwargs=None, warmup_ratio=0.05, warmup_steps=0, log_level='passive', log_level_replica='warning', log_on_each_node=True, logging_dir='/nas/shared/kilab/wangyujia/BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/runs', logging_strategy=<IntervalStrategy.STEPS: 'steps'>, logging_first_step=True, logging_steps=1, logging_nan_inf_filter=True, save_strategy=<SaveStrategy.STEPS: 'steps'>, save_steps=5, save_total_limit=5, save_safetensors=True, save_on_each_node=False, save_only_model=True, restore_callback_states_from_checkpoint=False, no_cuda=False, use_cpu=False, use_mps_device=False, seed=42, data_seed=42, jit_mode_eval=False, use_ipex=False, bf16=True, fp16=False, fp16_opt_level='O1', half_precision_backend='auto', bf16_full_eval=False, fp16_full_eval=False, tf32=None, local_rank=0, ddp_backend=None, tpu_num_cores=None, tpu_metrics_debug=False, debug=[], dataloader_drop_last=False, eval_steps=5, dataloader_num_workers=1, dataloader_prefetch_factor=10, past_index=-1, run_name='/nas/shared/kilab/wangyujia/BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822', disable_tqdm=False, remove_unused_columns=False, label_names=None, load_best_model_at_end=False, metric_for_best_model='loss', greater_is_better=False, ignore_data_skip=False, fsdp=[], fsdp_min_num_params=0, fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}, tp_size=0, fsdp_transformer_layer_cls_to_wrap=None, accelerator_config=AcceleratorConfig(split_batches=False, dispatch_batches=False, even_batches=True, use_seedable_sampler=True, non_blocking=False, gradient_accumulation_kwargs=None, use_configured_state=False), deepspeed={'fp16': {'enabled': 'auto', 'loss_scale': 0, 'loss_scale_window': 1000, 'initial_scale_power': 16, 'hysteresis': 2, 'min_loss_scale': 1}, 'bf16': {'enabled': 'auto'}, 'zero_optimization': {'stage': 3, 'offload_optimizer': {'device': 'none', 'pin_memory': True}, 'offload_param': {'device': 'none', 'pin_memory': True}, 'overlap_comm': False, 'contiguous_gradients': True, 'sub_group_size': 1000000000.0, 'reduce_bucket_size': 'auto', 'zero_quantized_weights': False, 'zero_quantized_gradients': False, 'stage3_prefetch_bucket_size': 'auto', 'stage3_param_persistence_threshold': 'auto', 'stage3_max_live_parameters': 1000000000.0, 'stage3_max_reuse_distance': 1000000000.0, 'stage3_gather_16bit_weights_on_model_save': True}, 'gradient_accumulation_steps': 'auto', 'gradient_clipping': 'auto', 'steps_per_print': 2000, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 'auto', 'wall_clock_breakdown': False}, label_smoothing_factor=0.0, optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>, optim_args=None, adafactor=False, group_by_length=False, length_column_name='length', report_to=['tensorboard'], ddp_find_unused_parameters=None, ddp_bucket_cap_mb=None, ddp_broadcast_buffers=None, dataloader_pin_memory=True, dataloader_persistent_workers=False, skip_memory_metrics=True, use_legacy_prediction_loop=False, push_to_hub=False, resume_from_checkpoint=None, hub_model_id=None, hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>, hub_token=None, hub_private_repo=None, hub_always_push=False, gradient_checkpointing=True, gradient_checkpointing_kwargs=None, include_inputs_for_metrics=False, include_for_metrics=[], eval_do_concat_batches=True, fp16_backend='auto', push_to_hub_model_id=None, push_to_hub_organization=None, push_to_hub_token=None, mp_parameters='', auto_find_batch_size=False, full_determinism=False, torchdynamo=None, ray_scope='last', ddp_timeout=1800, torch_compile=False, torch_compile_backend=None, torch_compile_mode=None, include_tokens_per_second=None, include_num_input_tokens_seen=None, neftune_noise_alpha=None, optim_target_modules=None, batch_eval_metrics=False, eval_on_start=False, use_liger_kernel=False, eval_use_gather_object=False, average_tokens_across_devices=None, sortish_sampler=False, predict_with_generate=False, generation_max_length=None, generation_num_beams=None, generation_config=None, check_model=True, acc_strategy='token', train_dataloader_shuffle=True, metric_warmup_step=0, fsdp_num=1, acc_steps=1, eval_use_evalscope=False, eval_datasets=[], eval_limit=None, eval_datasets_args=None, eval_generation_config=None, train_type='lora', optimizer=None, local_repo_path=None, galore_config=None)"
|
| 364 |
+
}
|
BIO/sft/qwen-metal_ion_binding-08022141/v0-20250802-215822/val_dataset.jsonl
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: TVPRFGVQTDQEEQLAKELEDTNKWGLDVFKVAELSGNRPLTAIIFSIFQERDLLKTFQIPADTLATYLLMLEGHYHANVAYHNSLHAADVAQSTHVLLATPALEAVFTDLEILAALFASAIHDVDHPGVSNQFLINTNSELALMYNDASVLENHHLAVGFKLLQAENCDIFQNLSAKQRLSLRRMVIDMVLATDMSKHMNLLADLKTMVETKKVTSLGVLLLDNYSDRIQVLQNLVHCADLSNPTKPLPLYRQWTDRIMAEFFQQGDRERESGLDISPMCDKHTASVEKSQVGFIDYIAHPLWETWADLVHPDAQDLLDTLEDNREWYQSKIPRSPSDLTNPERDGPDRFQFELTLE\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 2 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MLTAEDKKLIQQAWEKAASHQEEFGAEALTRMFTTYPQTKTYFPHFDLSPGSDQVRGHGKKVLGALGNAVKNVDNLSQAMAELSNLHAYNLRVDPVNFKLLSQCIQVVLAVHMGKDYTPEVHAAFDKFLSAVSAVLAEKYR\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 3 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: PQVTLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGIGGFIKVRQYDQILIEICGHKAIGTVLVGPTPVNIIGRDLLTQIGMTLNFGGSSGPQVTLWQRPLVTIKIGGQLKEALLDTGADDTVLEEMSLPGRWKPKMIGGIGGFIKVRQYDQILIEICGHKAIGTVLVGPTPVNIIGRDLLTQIGMTLNF\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 4 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: ADKELKFLVVDDESTMRRIVRNLLKELGFNNVEEAEDGVDALNKLQAGGYGFVISDWMMPNMDGLELLKTIRADGAMSALPVLMVTALAKKENIIAAAQAGASGYVVKPFTAATLEEKLNKIFEKLGM\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 5 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MRTRSTISTPNGITWYYEQEGTGPDIVLVPDGLGECQMFDSSVSQIAAQGFRVTTFDMPGMSRSAKAPAETYTEVTAQKLASYVISILDALDIKHATVWGCSSGASTVVALLLGYPDRIRNAMCHELPTKLLDHLSNTAVLEDEEISNILANVMLNDVSGGSEAWQALGVEVHARLHKNYPVWARGYPRTIPPSAPVQDVEALRGKPLDWTVGAATPTESFFDNIVTATKAGVNIGLLPGMHFPYVSHPDVFAKYVVETTQKHLWNSSSVDKLAAALEHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 6 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MGSSHHHHHHSSGLVPRGSHMTEIATTSGARSVGLLSVGAYRPERVVTNDEICQHIDSSDEWIYTRTGIKTRRFAADDESAASMATEACRRALSNAGLSAADIDGVIVTTNTHFLQTPPAAPMVAASLGAKGILGFDLSAGCAGFGYALGAAADMIRGGGAATMLVVGTEKLSPTIDMYDRGNCFIFADGAAAVVVGETPFQGIGPTVAGSDGEQADAIRQDIDWITFAQNPSGPRPFVRLEGPAVFRWAAFKMGDVGRRAMDAAGVRPDQIDVFVPHQANSRINELLVKNLQLRPDAVVANDIEHTGNTSAASIPLAMAELLTTGAAKPGDLALLIGYGAGLSYAAQVVRMPKG\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 7 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GLKLDLTWFDKSTEDFKGEEYSKDFGDDGSVMESLGVPFKDNVNNGCFDVIAEWVPLLQPYFNHQIDISDNEYFVSFDYRDGDW\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 8 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MGSSHHHHHHSSGLVPRGSHMPAKKPYNKIVSHLLVAEPEKIYAMPDPTVPDSDIKALTTLCDLADRELVVIIGWAKHIPGFSTLSLADQMSLLQSAWMEILILGVVYRSLSFEDELVYADDYIMDEDQSKLAGLLDLNNAILQLVKKYKSMKLEKEEFVTLKAIALANSDSMHIEDVEAVQKLQDVLHEALQDYEAGQHMEDPRRAGKMLMTLPLLRQTSTKAVQHFYNIKLEGKVPMHKLFLEMLEAKV\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 9 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: VLSGTDKTNVKSIFSKIGGQADDYGAEALERMFVTYPQTKTYFPHFDVSPGSAQVKAHGKKVAGGLSEAANHIDDIATSLSKLSDLHAQKLRVDPVNFKLLGQCFLVVVAIHNPSALTPEAHASLDKFLCAVGLVLTAKYR\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 10 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: TTPLVHVASVEKGRSYEDFQKVYNAIALKLREDDEYDNYIGYGPVLVRLAWHTSGTWDKHDNTGGSYGGTYRFKKEFNDPSNAGLQNGFKFLEPIHKEFPWISSGDLFSLGGVTAVQEMQGPKIPWRCGRVDTPEDTTPDNGRLPDADKDADYVRTFFQRLNMNDREVVALMGAHTLGKTHLKNSGYEGPWTANNNVFDNSFYLNLLNEDWKLEKNDANNEQWDSKSGYLQLPTDYSLIQDPKYLSIVKEYANDQDKFFKDFSKAFEKLLENGITFPKDAPSPFIFKTLEEQGL\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 11 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: DFDCIPGWSAYDRYCYQAFSKPKNWEDAESFCEEGVKTSHLVSIESSGEGDFVAQLVAEKIKTSFQYVWIGLRIQNKEQQCRSEWSDASSVNYENLVKQFSKKCYALKKGTELRTWFNVYCGTENPEVCKYTPEC\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 12 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: QNNVPNTFTDPDSGITFNTWGLDEDSPQTQGGFTFGVALPSDALTTDASEFIGYLKCARNDESGWCGISLGGPMTNSLLITAWPHEDTVYTSLRFATGYAMPDVYEGDAEITQVSSSVNSTHFSLIFRCKNCLQWSHGGSSGGASTSGGVLVLGWVQAFDDPGNPTCPEQITLQQHDNGMGIWGAQLNTDAASPSYTDWAAQATKTVT\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 13 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MFTGSIVAIVTPMDEKGNVCRASLKKLIDYHVASGTSAIVSVGTTGESATLNHDEHADVVMMTLDLADGRIPVIAGTGANATAEAISLTQRFNDSGIVGCLTVTPYWNRPSQEGLYQHFKAIAEHTDLPQILYNVPSRTGCDLLPETVGRLAKVKNIIGIKEATGNLTRVNQIKELVSDDFVLLSGDDASALDFMQLGGHGVISVTANVAARDMAQMCKLAAEGHFAEARVINQRLMPLHNKLFVEPNPIPVKWACKELGLVATDTLRLPMTPITDSGRETVRAALKHAGLL\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 14 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MQNDAGEFVDLYVPRKCSASNRIIGAKDHASIQMNVAEVDKVTGRFNGQFKTYAICGAIRRMGESDDSILRLAKADGIVSKNF\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 15 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GMRIPSAIQLHKASKTLTLRYGEDSYDLPAEFLRVHSPSAEVQGHGNPVLQYGKLNVGLVGVEPAGQYALKLSFDDGHDSGLFTWDYLYELATRKDQLWADYLAELASAGKSRDPDESVVKLML\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 16 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: ATSTKKLHKEPATLIKAIDGDTVKLMYKGQPMTFRLLLVDTPEFNEKYGPEASAFTKKMEENAKKIEVEFDKGQRTDKYGRGLAYIYADGKMVNEALVRQGLAKVAYVYKGNNTHEQLLRKAEAQAKKEKLNIWSEDNADSGQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 17 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: LTESSTSKFVKINEKGFSDFNIHYNEAGNGETVIMLHGGGPGAGGWSNYYRNVGPFVDAGYRVILKDSPGFNKSDAVVMDEQRGLVNARAVKGLMDALDIDRAHLVGNAMGGATALNFALEYPDRIGKLILMGPGGLGPSMFAPMPMEGIKLLFKLYAEPSYETLKQMLQVFLYDQSLITEELLQGRWEAIQRQPEHLKNFLISAQKAPLSTWDVTARLGEIKAKTFITWGRDDRFVPLDHGLKLLWNIDDARLHVFSKCGHWAQWEHADEFNRLVIDFLRHA\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 18 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: CGTSNGGQNTSPSTSSSSAKGEASALPKGQTITVWSWQTGPELQDVKQIAAQWAKAHGDKVIVVDQSSNPKGFQFYATAARTGKGPDVVFGMPHDNNGVFAEEGLMAPVPSGVLNTGLYAPNTIDAIKVNGTMYSVPVSVQVAAIYYNKKLVPQPPQTWAEFVKDANAHGFMYDQANLYFDYAIIGGYGGYVFKDNNGTLDPNNIGLDTPGAVQAYTLMRDMVSKYHWMTPSTNGSIAKAEFLAGKIGMYVSGPWDTADIEKAKIDFGVTPWPTLPNGKHATPFLGVITAFVNKESKTQAADWSLVQALTSAQAQQMYFRDSQQIPALLSVQRSSAVQSSPTFKAFVEQLRYAVPMPNIPQMQAVWQAMSILQNIIAGKVSPEQGAKDFVQNIQKGIMAQGS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 19 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MAVKVEYDLKRLRNIGIAAHIDAGKTTTTERILYYTGRIHKIGEVHEGAATMDFMEQERERGITITAAVTTCFWKDHRINIIDTPGHVDFTIEVERSMRVLDGAIVVFDSSQGVEPQSETVWRQAEKYKVPRIAFANKMDKTGADLWLVIRTMQERLGARPVVMQLPIGREDTFSGIIDVLRMKAYTYGNDLGTDIREIPIPEEYLDNAREYHEKLVEVAADFDENIMLKYLEGEEPTEEELVAAIRKGTIDLKITPVFLGSALKNKGVQLLLDAVVDYLPSPLDIPPIKGTTPEGEVVEIHPDPNGPLAALAFKIMADPYVGRLTFIRVYSGTLTSGSYVYNTTKGRKERVARLLRMHANHREEVEELKAGDLGAVVGLKETITGDTLVGEDAPRVILESIEVPEPVIDVAIEPKTKADQEKLSQALARLAEEDPTFRVSTHPETGQTIISGMGELHLEIIVDRLKREFKVDANVGKPQVAYRETITKPVDVEGKFIRQTGGRGQYGHVKIKVEPLPRGSGFEFVNAIVGGVIPKEYIPAVQKGIEEAMQSGPLIGFPVVDIKVTLYDGSYHEVDSSEMAFKIAGSMAIKEAVQKGDPVILEPIMRVEVTTPEEYMGDVIGDLNARRGQILGMEPRGNAQVIRAFVPLAEMFGYATDLRSKTQGRGSFVMFFDHYQEVPKQVQEKLIKGQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 20 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MSLSIPTIKQVRAFVLRGGGADYHDQGDGHWIDDHISTPMGKYPEYRQSRRSFGINVLGTLVVEIEASDGNVGFAVTTGGEPAAYIVEKHLARFLEGARVTDIERIWDQMYNSTLYYGRKGLVINTISGVDLALWDLLGKVRREPVHQLLGGAVRDELQFYATGARPDLAQKMGFIGGKMPLHHGPSEGEEGLKKNLEELATMRERVGPDFWLMFDCWMSLDLNYATRLARGAREYGLKWIEEALPPDDYWGYAELRRNAPTGMMVTTGEHEATRWGFRMLLEMGCCDIIQPDVGWCGGVTELLKISALADAHNALVVPHGSSVYSYHFVATRQNSPFAEFLMMAPKADQVVPMFHPQLLGEPVPENGRMRLSRLDQPGFGVTLNPECQLHRPYTHEGHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 21 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GSAKGNFCPLCDKCYDDDDYESKMMQCGKCDRWVHSKCENLSDEMYEILSNLPESVAYTCVNCTERH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 22 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: NEAISVEKSIQEQKLNGYGVGSLIKFPVSSTAPTLDAKSFYKYFQLRDTLDDRLTAVTATEVSLEGTTLDPTDYKVDTKGQTVTVTFTAEGLKRIKAAPGKKVSAVFQGKVTEARNGAITNRAQVISDTVYAEQPPTPEEPPANPENPPTSNEVTSRWGDLLIKKVDNHQQGQDKAGLQGAQFQLYKAKNAYAGTCTKDKEGDPIAINGETTLTTDAQGAINVKGLFISDSIDGANRDNQKDATARCYVLVETKAPAGYVLPAGDGAVTPVKIEVGAVTTDNVTIENTKQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 23 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MVTSPPTSSPSQRSYSPQDWLRGYQSQPQEWDYWVEDVEGSIPPDLQGTLYRNGPGLLEIGDRPLKHPFDGDGMVTAFKFPGDGRVHFQSKFVRTQGYVEEQKAGKMIYRGVFGSQPAGGWLKTIFDLRLKNIANTNITYWGDRLLALWQGGQPHRLEPSNLATIGLDDLGGILAEGQPLSAHPRIDPASTFDGGQPCYVTFSIKSSLSSTLTLLELDPQGKLLRQKTETFPGFAFIHDFAITPHYAIFLQNNVTLNGLPYLFGLRGAGECVQFHPDKPAQIILVPRDGGEIKRIPVQAGFVFHHANAFEENGKIILDSICYNSLPQVDTDGDFRSTNFDNLDPGQLWRFTIDPAAATVEKQLMVSRCCEFPVVHPQQVGRPYRYVYMGAAHHSTGNAPLQAILKVDLESGTETLRSFAPHGFAGEPIFVPRPGGVAEDDGWLLCLIYKADLHRSELVILDAQDITAPAIATLKLKHHIPYPLHGSWAQT\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 24 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MAAGSITTLPALPEDGGSGAFPPGHFKDPKRLYCKNGGFFLRIHPDGRVDGVREKSDPHIKLQLQAEERGVVSIKGVSANRYLAMKEDGRLLASKSVTDECFFFERLESNNYNTYRSRKYTSWYVALKRTGQYKLGSKTGPGQKAILFLPMSAKS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 25 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GPMGGFQRGKYGTMAEGRSEDNLSATPPALRIILVGKTGCGKSATGNSILGQPVFESKLRAQSVTRTCQVKTGTWNGRKVLVVDTPSIFESQADTQELYKNIGDCYLLSAPGPHVLLLVIQLGRFTAQDTVAIRKVKEVFGTGAMRHVVILFTHKEDLGGQALDDYVANTDNCSLKDLVRECERRYCAFNNWGSVEEQRQQQAELLAVIERLGREREGSFHSNDLFLDAQLLQRTGAGACQEDYRQYQAKVEWQVEKHKQELRENESNWAYKALLRVK\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 26 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: METVEMVAIATVAGLLSLATVTGNILLMLSIKVNRQLQTVNNYFAFSLACADLIIGAFSMNLYTVYIIMGHWALGALACDLALALDYVASNAAVMNLLLISFDRYFSVTRPLSYRAKRTPRRALLMIGLAWLVSFVLWAPAILFWQYLVGERTVLAGQCYIQFLSQPIITFGTAMATFYLPVTVMCTLYWRIYRETENRANIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSELDKAIGRNTNGVITKDEAEKLFNQDVDAAVRGILRNAKLKPVYDSLDAVRRAALINMVFQMGETGVAGFTNSLRMLQQKRWDEAAVNLAKSRWYNQTPNRAKRVITTFRTGTWDAYTFSLVKEKAALRTLSAILLAFILTWTPYNIMVLVSTFCKDCVPETLWELGYWLCYVNATINPMCYALCNKAFRDTFRLLLLARWDHHHHHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 27 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GPGSEFMADDDVLFEDVYELCEVIGKGPFSVVRRCINRETGQQFAVKIVDVAKFTSSPGLSTEDLKREASICHMLKHPHIVELLETYSSDGMLYMVFEFMDGADLCFEIVKRADAGFVYSEAVASHYMRQILEALRYCHDNNIIHRDVKPHCVLLASKENSAPVKLGGFGVAIQLGESGLVAGGRVGTPHFMAPEVVKREPYGKPVDVWGCGVILFILLSGCLPFYGTKERLFEGIIKGKYKMNPRQWSHISESAKDLVRRMLMLDPAERITVYEALNHPWLKERDRYAYKIHLPETVEQLRKFNARRKLKGAVLAAVSSHKFNS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 28 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: TPEAPYASLTEIEHLVQSVSKSYRETCQLRLEDLLRQRSNIFSREEVTGYQRKSMWEMWERCAHHLTEAIQYVVEFAKRLSGFMELSQNDQIVLLKAGAMEVVLVRMCRAYNADNRTVFFEGKYGGMELFRALGCSELISSIFDFSHSLSALHFSEDEIALYTALVLINAHRPGLQEKRKVEQLQYNLELAFHHHLCKTHRQSILAKLPPKGKLRSLCSQHVERLQIFQHLHPIVVQAAFPPLYKELFSTETESPVGLS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 29 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: QVQLVQSGAEVKKPGASVKVSCKASGYTFTGYYMHWVRQAPGQGLEWMGWINPNSGGTNYAQKFQGRVTMTRDTSISTAYMELSRLRSDDTAVYYCARDGSGDDTSWHLHPWGQGTLVTVSSASTKGPSVFPLAPSSKSTSGGTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKRVEPKSCDKTHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 30 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: QDGEALFKSKPCAACHSIDAKMVGPALKEVAAKYAGQEGAADLLAGHIKNGTQGNWGPIPMPPNPVTEEEAKTLAEWVLSLK\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 31 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: LTNSSLMPTLNPMIQQLALAIAASWQSLPLKPYQLPEDLGYVEGRLEGEKLVIENRCYQTPQFRKMQLELAKVGKGLDILHCVMFPEPLYGLPLFGCDIVAGPGGVSAAIADLSPTQSDRQLPAAYQKSLAELGQPEFEQQRELPPWGEIFSEYCLFIRPSNVTEEERFVQRVVDFLQIHCHQSIVAEPLSEAQTLEHRQGQIHYCQQQQKNDKTRRVLEKAFGEAWAERYMSQVLFDV\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 32 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GIQGLAKLIADVAPSAIRENDIKSYFGRKVAIDASMSIYQFLIAVRQGGDVLQNEEGETTSHLMGMFYRTIRMMENGIKPVYVFDGKPPQLKSGELAKRSERRAEAEKQLQQAQAAGAEQEVEKFTKRLVKVTKQHNDECKHLLSLMGIPYLDAPSEAEASCAALVKAGKVYAAATEDMDCLTFGSPVLMRHLTASEAKKLPIQEFHLSRILQELGLNQEQFVDLCILLGSNYCESIRGIGPKRAVDLIQKHKSIEEIVRRLDPNKYPVPENWLHKEAHQLFLEPEVLDPESVELKWSEPNEEELIKFMCGEKQFSEERIRSGVKRLSKSRQGSTLEVLFQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 33 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MHHHHHHSSGVDLGTENLYFQSMVDATRVPMDERFRTLKKKLEEGMVFTEYEQIPKKKANGIFSTAALPENAERSRIREVVPYEENRVELIPTKENNTGYINASHIKVVVGGAEWHYIATQGPLPHTCHDFWQMVWEQGVNVIAMVTAEEEGGRTKSHRYWPKLGSKHSSATYGKFKVTTKFRTDSVCYATTGLKVKHLLSGQERTVWHLQYTDWPDHGCPEDVQGFLSYLEEIQSVRRHTNSMLEGTKNRHPPIVVHCSAGVGRTGVLILSELMIYCLEHNEKVEVPMMLRLLREQRMFMIQTIAQYKFVYQVLIQFLQNSRLI\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 34 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: QAQITGRPEWIWLALGTALMGLGTLYFLVKGMGVSDPDAKKFYAITTLVAAIAFTMYLSMLLGYGLTMVPFGGEQNPIYWARYADWLFTTPLLLLDLALLVDADQGTILALVGADGIMIGTGLVGALTKVYSYRFVWWAISTAAMLYILYVLFFGFTSKAESMRPEVASTFKVLRNVTVVLWSAYPVVWLIGSEGAGIVPLNIETLLFMVLDVSAKVGFGLILLRSRAIFGEAEAPEPSAGDGAAATSD\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 35 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MEKDGLCRADQQYECVAEIGEGAYGKVFKARDLKNGGRFVALKRVRVQTGEEGMPLSTIREVAVLRHLETFEHPNVVRLFDVCTVSRTDRETKLTLVFEHVDQDLTTYLDKVPEPGVPTETIKDMMFQLLRGLDFLHSHRVVHRDLKPQNILVTSSGQIKLADFGLARIYSFQMALTSVVVTLWYRAPEVLLQSSYATPVDLWSVGCIFAEMFRRKPLFRGSSDVDQLGKILDVIGLPGEEDWPRDVALPRQAFHSKSAQPIEKFVTDIDELGKDLLLKCLTFNPAKRISAYSALSHPYFQDLERCKENLDSHLPPSQNTSELNTA\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 36 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: AFELPPLPYAHDALQPHISKETLEFHHDKHHNTYVVNLNNLVPGTEFEGKTLEEIVKTSSGGIFNNAAQVWNHTFYWNCLSPNAGGQPTGALADAINAAFGSFDKFKEEFTKTSVGTFGSGWGWLVKKADGSLALASTIGAGCPLTIGDTPLLTCDVWEHAYYIDYRNLRPKYVEAFWNLVNWAFVAEQFEGKTYKV\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 37 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: AVGIGAVFLGFLGAAGSTMGAASMTLTVQARNLLSGIVQQQSNLLRAIEAQQHLLKLTVWGIKQLQARVLAVERYLRDQQLLGIWGCSGKLICCTNVPWNSSWSNRNLSEIWDNMTWLQWDKEISNYTQIIYGLLEESQNQQEKNEQDLLALD\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 38 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: SNADLLVPKLTASVTDGAVGVTVDAPVSVTAADGVLAAVTMVNDNGRPVAGRLSPDGLRWSTTEQLGYNRRYTLNATALGLGGAATRQLTFQTSSPAHLTMPYVMPGDGEVVGVGEPVAIRFDENIADRGAAEKAIKITTNPPVEGAFYWLNNREVRWRPEHFWKPGTAVDVAVNTYGVDLGEGMFGEDNVQTHFTIGDEVIATADDNTKILTVRVNGEVVKSMPTSMGKDSTPTANGIYIVGSRYKHIIMDSSTYGVPVNSPNGYRTDVDWATQISYSGVFVHSAPWSVGAQGHTNTSHGCLNVSPSNAQWFYDHVKRGDIVEVVNTVGGTLPGIDGLGDWNIPWDQWRAGNAKA\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 39 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: QAPTAVLNGNEVISGVLEGKVDTFKGIPFADPPLNDLRFKHPQPFTGSYQGLKANDFSPACMQLDPGNSLTLLDKALGLAKVIPEEFRGPLYDMAKGTVSMNEDCLYLNVFRPAGTKPDAKLPVMVWIYGGAFVYGSSAAYPGNSYVKESINMGQPVVFVSINYRTGPFGFLGGDAITAEGNTNAGLHDQRKGLEWVSDNIANFGGDPDKVMIFGESAGAMSVAHQLIAYGGDNTYNGKKLFHSAILQSGGPLPYHDSSSVGPDISYNRFAQYAGCDTSASANDTLECLRSKSSSVLHDAQNSYDLKDLFGLLPQFLGFGPRPDGNIIPDAAYELFRSGRYAKVPYISGNQEDEGTAFAPVALNATTTPHVKKWLQYIFYDASEASIDRVLSLYPQTLSVGSPFRTGILNALTPQFKRVAAILSDMLFQSPRRVMLSATKDVNRWTYLSTHLHNLVPFLGTFHGNELIFQFNVNIGPANSYLRYFISFANHHDPNVGTNLLQWDQYTDEGKEMLEIHMTDNVMRTDDYRIEGISNFETDVNLYG\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 40 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MTGSLNRHSLLNGVKKMRIILCDTNEVVTNLWQESIPHAYIQNDKYLCIHHGHLQSLMDSMRKGDAIHHGHSYAIVSPGNSYGYLGGGFDKALYNYFGGKPFETWFRNQLGGRYHTVGSATVVDLQRCLEEKTIECRDGIRYIIHVPTVVAPSAPIFNPQNPLKTGFEPVFNAMWNALMHSPKDIDGLIIPGLCTGYAGVPPIISCKSMAFALRLYMAGDHISKELKNVLIMYYLQYPFEPFFPESCKIECQKLGIDIEMLKSFNVEKDAIELLIPRRILTLDL\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 41 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GSSGSSGQPPLKNLLSLLKAYYALNAQPSAEELSKIADSVNLPLDVVKKWFEKMQAGQISVQSS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 42 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: YKLICYYTSWSQYREGDGSCFPDAIDPFLCTHVIYSFANISNNEIDTWEWNDVTLYDTLNTLKNRNPNLKTLLSVGGWNFGSERFSKIASKTQSRRTFIKSVPPFLRTHGFDGLDLAWLYPGWRDKRHLTTLVKEMKAEFVREAQAGTEQLLLSAAVPAGKIAIDRGYDIAQISRHLDFISLLTYDFHGGWRGTVGHHSPLFRGNSDGSSRFSNADYAVSYMLRLGAPANKLVMGIPTFGRSYTLASSKTDVGAPISGPGIPGQFTKEKGTLAYYEICDFLHGATTHRFRDQQVPYATKGNQWVAYDDQESVKNKARYLKNRQLAGAMVWALDLDDFRGTFCGQNLTFPLTSAIKDVLARV\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 43 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MSIEPVFILVRPQMGENIGAAARAMLNFGLGRLRIVDPRDGWPNPKAVAMASGAGRLLDHAGLFPTVAEAIRDCDYVFATTARGRELTKPVMTPERAMAHGRALTGEGRRVGILFGPERTGLENEDVALANAIVTVPVNPEFFSLNLAQCVLLLAYEWRRQHDETPPEVIDMARVDFASGLEVEKLGDHFEEKLEAAGFFFPPEKAPGMKLNLRNMWARLPLTRADVQTLHGMLRQIAWKLKQENLYFQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 44 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GIKQYSQEELKEMALVEIAHELFEEHKKPVPFQELLNEIASLLGVKKEELGDRIAQFYTDLNIDGRFLALSDQTWGLRSWYPYDQLDEETQPTVKAKKKKAKKAVEEDLDLDEFEEIDEDDLDLDEVEEELDLEADDFDEEDLDEDDDDLEIEEDIIDEDDEDYDDEEEEIK\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 45 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MGSSHHHHHHSSGLVPRGSHMENLYFQSSPNYDKWEMERTDITMKHKLGGGQYGEVYEGVWKKYSLTVAVKTLKEDTMEVEEFLKEAAVMKEIKHPNLVQLLGVCTREPPFYIITEFMTYGNLLDYLRECNRQEVNAVVLLYMATQISSAMEYLEKKNFIHRDLAARNCLVGENHLVKVADFGLSRLMTGDTYTAHAGAKFPIKWTAPESLAYNKFSIKSDVWAFGVLLWEIATYGMSPYPGIDLSQVYELLEKDYRMERPEGCPEKVYELMRACWQWNPSDRPSFAEIHQAFETMFQESSISDEVEKEL\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 46 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MKIEEGKLVIWINGDKGYNGLAEVGKKFEKDTGIKVTVEHPDKLEEKFPQVAATGDGPDIIFWAHDRFGGYAQSGLLAEITPDKAFQDKLYPFTWDAVRYNGKLIAYPIAVEALSLIYNKDLLPNPPKTWEEIPALDKELKAKGKSALMFNLQEPYFTWPLIAADGGYAFKYENGKYDIKDVGVDNAGAKAGLTFLVDLIKNKHMNADTDYSIAEAAFNKGETAMTINGPWAWSNIDTSKVNYGVTVLPTFKGQPSKPFVGVLSAGINAASPNKELAKEFLENYLLTDEGLEAVNKDKPLGAVALKSYEEELAKDPRIAATMENAQKGEIMPNIPQMSAFWYAVRTAVINAASGRQTVDEALKDAQTNAAAAAGCKAA\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 47 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MRVIFSEDHKLRNAKTELYGGELVPPFEAPFRAEWILAAVKEAGFDDVVAPARHGLETVLKVHDAGYLNFLETAWDRWKAAGYKGEAIATSFPVRRTSPRIPTDIEGQIGYYCNAAETAISPGTWEAALSSMASAIDGADLIAAGHKAAFSLCRPPGHHAGIDMFGGYCFINNAAVAAQRLLDKGAKKIAILDVDFHHGNGTQDIFYERGDVFFASLHGDPAEAFPHFLGYAEETGKGAGAGTTANYPMGRGTPYSVWGEALTDSLKRIAAFGAEAIVVSLGVDTFEQDPISFFKLTSPDYITMGRTIAASGVPLLVVMEGGYGVPEIGLNVANVLKGVAG\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 48 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MAAECATQRAPGSVVELLGKSYPQDDHSNLTRKVLTRVGRNLHNQQHHPLWLIKERVKEHFYKQYVGRFGTPLFSVYDNLSPVVTTWQNFDSLLIPADHPSRKKGDNYYLNRTHMLRAHTSAHQWDLLHAGLDAFLVVGDVYRRDQIDSQHYPIFHQLEAVRLFSKHELFAGIKDGESLQLFEQSSRSAHKQETHTMEAVKLVEFDLKQTLTRLMAHLFGDELEIRWVDCYFPFTHPSFEMEINFHGEWLEVLGCGVMEQQLVNSAGAQDRIGWAFGLGLERLAMILYDIPDIRLFWCEDERFLKQFCVSNINQKVKFQPLSKYPAVINDISFWLPSENYAENDFYDLVRTIGGDLVEKVDLIDKFVHPKTHKTSHCYRITYRHMERTLSQREVRHIHQALQEAAVQLLGVEGRF\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 49 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: ATINNAGFESGFSNWNETDPAAISSDAYSGSKSLKIQGSPARVYQVVDIQPNTEYTLSAYVLGKGQIGVNDLNGLFKNQTFNVSSWTKVTKTFTSANTNSLQVFAKHYNNTSDVRFDNFSLIEGSGSNDGGSDGGSDNSNGSTIPSSITSGSIFDLEGDNPNPLVDDSTLVFVPLEAQHITPNGNGWRHEYKVKESLRVAMTQTYEVFEATVKVEMSDGGKTIISQHHASDTGTISKVYVSDTDESGFNDSVANNGIFDVYVRLRNTSGNEEKFALGTMTSGETFNLRVVNNYGDVEVTAFGNSFGIPVEDDSQSYFKFGNYLQSQDPYTLDKCGEAGNSNSFKNCFEDLGITESKVTMTNVSYTRETNHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 50 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: KIQVKCSAPNSVTITNASGGLYLVEYPEGYVAYSKATEVTGKLVHANFGTKKDFEDLDYAVNGSIVIVRAGKITIAEKVANAQSFNAIGVLIYKDRTKYPISRADEPLQGHSGLPSIPVQTISREAAEKLFQNMERDCPRSWNTDSSCKLELLQNRNVKLTVNNCLKEGTSGHHHHHH\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 51 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MDIKINDITLGNNSPFVLFGGINVLESLDSTLQTCAHYVEVTRKLGIPYIFKASFDKANRSSIHSYRGVGLEEGLKIFEKVKAEFGIPVITDVHEPHQCQPVAEVCDVIQLPAFLARQTDLVVAMAKTGNVVNIKKPQFLSPSQMKNIVEKFHEAGNGKLILCERGSSFGYDNLVVDMLGFGVMKQTCGNLPVIFDVTHSLATRDAGSAASGGRRAQALDLALAGMATRLAGLFLESHPDPKLAKCDGPSALPLHLLEDFLIRIKALDDLIKSQPILTIE\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 52 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MPPGKCLFSGVFCNMAEKQTAKRNRREEILQSLALMLESSDGSQRITTAKLAASVGVSEAALYRHFPSKTRMFDSLIEFIEDSLITRINLILKDEKDTTARLRLIVLLLLGFGERNPGLTRILTGHALMFEQDRLQGRINQLFERIEAQLRQVLREKRMREGEGYATDETLLASQILAFCEGMLSRFVRSEFKYRPTDDFDARWPLIAAQLQ\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 53 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: GQRQLLLASENPQQFMDYFSEEFRNDFLELLRRRFGTKRVHNNIVYNEYISHREHIHMNATQWETLTDFTKWLGREGLCKVDETPKGWYIQYIDRDPETIRRQLELEKKKK\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 54 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MSYYHHHHHHLESTSLYKKAGFMSLQRIVRVSLEHPTSAVCVAGVETLVDIYGSVPEGTEMFEVYGTPGVDIYISPNMERGRERADTRRWRFDATLEIIVVMNSPSNDLNDSHVQISYHSSHEPLPLAYAVLYLTCVDISLDCDLNCEGRQDRNFVDKRQWVWGPSGYGGILLVNCDRDDPSCDVQDNCDQHVHCLQDLEDMSVMVLRTQGPAALFDDHKLVLHTSSYDAKRAQVFHICGPEDVCEAYRHVLGQDKVSYEVPRLHGDEERFFVEGLSFPDAGFTGLISFHVTLLDDSNEDFSASPIFTDTVVFRVAPWIMTPSTLPPLEVYVCRVRNNTCFVDAVAELARKAGCKLTICPQAENRNDRWIQDEMELGYVQAPHKTLPVVFDSPRNGELQDFPYKRILGPDFGYVTREPRDRSVSGLDSFGNLEVSPPVVANGKEYPLGRILIGGNLPGSSGRRVTQVVRDFLHAQKVQPPVELFVDWLAVGHVDEFLSFVPAPDGKGFRMLLASPGACFKLFQEKQKCGHGRALLFQGVVDDEQVKTISINQVLSNKDLINYNKFVQSCIDWNREVLKRELGLAECDIIDIPQLFKTERKKATAFFPDLVNMLVLGKHLGIPKPFGPIINGCCCLEEKVRSLLEPLGLHCTFIDDFTPYHMLHGEVHCGTNVCRKPFSFKWWNMVP\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 55 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: MANGANRVDLDGKPIQPLTICMIGAGGFIGSHLCEKLLTETPHKVLALDVYNDKIKHLLEPDTVEWSGRIQFHRINIKHDSRLEGLVKMADLIINLAAICTPADYNTRPLDTIYSNFIDALPVVKYCSENNKRLIHFSTCEVYGKTIGSFLPKDHPLRDDPAFYVLKEDISPCIFGSIEKQRWSYACAKQLIERLVYAEGAENGLEFTIVRPFNWIGPRMDFIPGIDGPSEGVPRVLACFSNNLLRREPLKLVDGGESQRTFVYINDAIEAVLLMIENPERANGHIFNVGNPNNEVTVRQLAEMMTEVYAKVSGEGAIESPTVDVSSKEFYGEGYDDSDKRIPDMTIINRQLGWNPKTSLWDLLESTLTYQHRTYAEAVKKATSKPVAS\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
| 56 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: PISPIETVPVKLKPGMDGPKVKQWPLTEEKIKALVEICTELEKEGKISKIGPENPYNTPVFAIKKKDSTKWRKLVDFRELNKRTQDFWEVQLGIPHPAGLKKKKSVTVLDVGDAYFSVPLDEDFRKYTAFTIPSINNETPGIRYQYNVLPQGWKGSPAIFQSSMTKILEPFRKQNPDIVIYQYMDDLYVGSDLEIGQHRTKIEELRQHLLRWGLYTPDKKHQKEPPFLWMGYELHPDKWTVQPIVLPEKDSWTVNDIQKLVGKLNWASQIYPGIKVRQLCKLLRGTKALTEVIPLTEEAELELAENREILKEPVHGVYYDPSKDLIAEIQKQGQGQWTYQIYQEPFKNLKTGKYARMRGAHTNDVKQLTEAVQKITTESIVIWGKTPKFKLPIQKETWETWWTEYWQATWIPEWEFVNTPPLVKLWYQLEKEPIVGAETFYVDGAANRETKLGKAGYVTNRGRQKVVTLTDTTNQKTELQAIYLALQDSGLEVNIVTDSQYALGIIQAQPDQSESELVNQIIEQLIKKEKVYLAWVPAHKGIGGNEQVDKLVSAGIRKVL\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "1"}]}
|
| 57 |
+
{"messages": [{"role": "user", "content": "Based on the following protein's amino acid sequence, is the protein located on the membrane?.\nProtein amino acid sequence: HHHHHHFNLPPGNYKKGGTVDGTRDRSDTHIQFQISPEGNGEVLLKSTETGQYLRINPDGTVDGTRDRSDTHIQFQISPEGNGEVLLKSTETGQYLRINPDGTVDGTRDRSDTHIQFQISPEGNGEVLLKSTETGQYLRINP\nOptions:\n0.\"Yes\"\n1.\"No\"\n"}, {"role": "assistant", "content": "0"}]}
|
BIO/sft/qwen-production-08022302/v0-20250802-230250/images/train_train_steps_per_second.png
ADDED
|
BioReason/.gitignore
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
.idea/
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.py[cod]
|
| 5 |
+
*$py.class
|
| 6 |
+
wandb/
|
| 7 |
+
.DS_Store
|
| 8 |
+
.vscode/
|
| 9 |
+
.venv/
|
| 10 |
+
.env
|
| 11 |
+
.pytest_cache/
|
| 12 |
+
|
| 13 |
+
# C extensions
|
| 14 |
+
*.so
|
| 15 |
+
|
| 16 |
+
outputs/
|
| 17 |
+
|
| 18 |
+
# Distribution / packaging
|
| 19 |
+
.Python
|
| 20 |
+
build/
|
| 21 |
+
develop-eggs/
|
| 22 |
+
dist/
|
| 23 |
+
downloads/
|
| 24 |
+
eggs/
|
| 25 |
+
.eggs/
|
| 26 |
+
lib/
|
| 27 |
+
lib64/
|
| 28 |
+
parts/
|
| 29 |
+
sdist/
|
| 30 |
+
var/
|
| 31 |
+
wheels/
|
| 32 |
+
share/python-wheels/
|
| 33 |
+
*.egg-info/
|
| 34 |
+
.installed.cfg
|
| 35 |
+
*.egg
|
| 36 |
+
MANIFEST
|
| 37 |
+
|
| 38 |
+
# PyInstaller
|
| 39 |
+
# Usually these files are written by a python script from a template
|
| 40 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 41 |
+
*.manifest
|
| 42 |
+
*.spec
|
| 43 |
+
|
| 44 |
+
# Installer logs
|
| 45 |
+
pip-log.txt
|
| 46 |
+
pip-delete-this-directory.txt
|
| 47 |
+
|
| 48 |
+
# Unit test / coverage reports
|
| 49 |
+
htmlcov/
|
| 50 |
+
.tox/
|
| 51 |
+
.nox/
|
| 52 |
+
.coverage
|
| 53 |
+
.coverage.*
|
| 54 |
+
.cache
|
| 55 |
+
nosetests.xml
|
| 56 |
+
coverage.xml
|
| 57 |
+
*.cover
|
| 58 |
+
*.py,cover
|
| 59 |
+
.hypothesis/
|
| 60 |
+
.pytest_cache/
|
| 61 |
+
cover/
|
| 62 |
+
|
| 63 |
+
# Translations
|
| 64 |
+
*.mo
|
| 65 |
+
*.pot
|
| 66 |
+
|
| 67 |
+
# Django stuff:
|
| 68 |
+
*.log
|
| 69 |
+
local_settings.py
|
| 70 |
+
db.sqlite3
|
| 71 |
+
db.sqlite3-journal
|
| 72 |
+
|
| 73 |
+
# Flask stuff:
|
| 74 |
+
instance/
|
| 75 |
+
.webassets-cache
|
| 76 |
+
|
| 77 |
+
# Scrapy stuff:
|
| 78 |
+
.scrapy
|
| 79 |
+
|
| 80 |
+
# Sphinx documentation
|
| 81 |
+
docs/_build/
|
| 82 |
+
|
| 83 |
+
# PyBuilder
|
| 84 |
+
.pybuilder/
|
| 85 |
+
target/
|
| 86 |
+
|
| 87 |
+
# Jupyter Notebook
|
| 88 |
+
.ipynb_checkpoints
|
| 89 |
+
|
| 90 |
+
# IPython
|
| 91 |
+
profile_default/
|
| 92 |
+
ipython_config.py
|
| 93 |
+
|
| 94 |
+
# pyenv
|
| 95 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 96 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 97 |
+
# .python-version
|
| 98 |
+
|
| 99 |
+
# pipenv
|
| 100 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 101 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 102 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 103 |
+
# install all needed dependencies.
|
| 104 |
+
#Pipfile.lock
|
| 105 |
+
|
| 106 |
+
# UV
|
| 107 |
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
| 108 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 109 |
+
# commonly ignored for libraries.
|
| 110 |
+
#uv.lock
|
| 111 |
+
|
| 112 |
+
# poetry
|
| 113 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 114 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 115 |
+
# commonly ignored for libraries.
|
| 116 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 117 |
+
#poetry.lock
|
| 118 |
+
|
| 119 |
+
# pdm
|
| 120 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 121 |
+
#pdm.lock
|
| 122 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 123 |
+
# in version control.
|
| 124 |
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
| 125 |
+
.pdm.toml
|
| 126 |
+
.pdm-python
|
| 127 |
+
.pdm-build/
|
| 128 |
+
|
| 129 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 130 |
+
__pypackages__/
|
| 131 |
+
|
| 132 |
+
# Celery stuff
|
| 133 |
+
celerybeat-schedule
|
| 134 |
+
celerybeat.pid
|
| 135 |
+
|
| 136 |
+
# SageMath parsed files
|
| 137 |
+
*.sage.py
|
| 138 |
+
|
| 139 |
+
# Environments
|
| 140 |
+
.env
|
| 141 |
+
.venv
|
| 142 |
+
env/
|
| 143 |
+
venv/
|
| 144 |
+
ENV/
|
| 145 |
+
env.bak/
|
| 146 |
+
venv.bak/
|
| 147 |
+
|
| 148 |
+
# Spyder project settings
|
| 149 |
+
.spyderproject
|
| 150 |
+
.spyproject
|
| 151 |
+
|
| 152 |
+
# Rope project settings
|
| 153 |
+
.ropeproject
|
| 154 |
+
|
| 155 |
+
# mkdocs documentation
|
| 156 |
+
/site
|
| 157 |
+
|
| 158 |
+
# mypy
|
| 159 |
+
.mypy_cache/
|
| 160 |
+
.dmypy.json
|
| 161 |
+
dmypy.json
|
| 162 |
+
|
| 163 |
+
# Pyre type checker
|
| 164 |
+
.pyre/
|
| 165 |
+
|
| 166 |
+
# pytype static type analyzer
|
| 167 |
+
.pytype/
|
| 168 |
+
|
| 169 |
+
# Cython debug symbols
|
| 170 |
+
cython_debug/
|
| 171 |
+
|
| 172 |
+
# PyCharm
|
| 173 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 174 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 175 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 176 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 177 |
+
#.idea/
|
| 178 |
+
|
| 179 |
+
# PyPI configuration file
|
| 180 |
+
.pypirc
|
BioReason/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
BioReason/README.md
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">
|
| 2 |
+
🧬 BioReason<br>Incentivizing Multimodal Biological Reasoning<br>within a DNA-LLM Model
|
| 3 |
+
</h1>
|
| 4 |
+
|
| 5 |
+
<p align="center">
|
| 6 |
+
<a href="https://www.arxiv.org/abs/2505.23579" target="_blank"><img src="https://img.shields.io/badge/arXiv-2505.23579-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a>
|
| 7 |
+
<a href="https://github.com/bowang-lab/BioReason"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a>
|
| 8 |
+
<a href="https://bowang-lab.github.io/BioReason/"><img src="https://img.shields.io/badge/Website-Online-00B89E?style=for-the-badge&logo=internet-explorer&logoColor=white" alt="Website"></a>
|
| 9 |
+
<a href="https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a>
|
| 10 |
+
</p>
|
| 11 |
+
|
| 12 |
+
<br>
|
| 13 |
+
|
| 14 |
+
## Updates [Jun 10, 2025]
|
| 15 |
+
- We are integrating vLLM to improve the speed and efficiency of the GRPO pipeline. We expect this to be pushed by end of week.
|
| 16 |
+
- Checkpoints along with the custom DNA-LLM model class will be released on HuggingFace by end of week.
|
| 17 |
+
- More training results with GRPO will be shared soon.
|
| 18 |
+
|
| 19 |
+
<br>
|
| 20 |
+
|
| 21 |
+
## Abstract
|
| 22 |
+
|
| 23 |
+
Unlocking deep, interpretable biological reasoning from complex genomic data is a major AI challenge hindering scientific discovery. Current DNA foundation models, despite strong sequence representation, struggle with multi-step reasoning and lack inherent transparent, biologically intuitive explanations. We introduce BioReason, a pioneering architecture that, for the first time, deeply integrates a DNA foundation model with a large language model (LLM). This novel connection enables the LLM to directly process and reason with genomic information as a fundamental input, fostering a new form of multimodal biological understanding. BioReason's sophisticated multi-step reasoning is developed through supervised fine-tuning and targeted reinforcement learning, guiding the system to generate logical, biologically coherent deductions. On biological reasoning benchmarks including KEGG-based disease pathway prediction—where accuracy improves from 88% to 97%—and variant effect prediction, BioReason demonstrates an average 15% performance gain over strong single-modality baselines.
|
| 24 |
+
|
| 25 |
+
<br>
|
| 26 |
+
|
| 27 |
+
## Key Contributions
|
| 28 |
+
|
| 29 |
+
• **Novel multimodal architecture**: The first successful integration of a DNA foundation model with an LLM, establishing a new methodology for AI-driven biological studies.
|
| 30 |
+
|
| 31 |
+
• **Advanced reasoning methodology**: A systematic training approach combining supervised fine-tuning and reinforcement learning that incentivizes multi-step biological reasoning.
|
| 32 |
+
|
| 33 |
+
• **New biological reasoning benchmarks**: Development and curation of novel benchmarks for evaluating biological reasoning capabilities, including an annotated reasoning dataset for gene pathway and disease prediction from KEGG.
|
| 34 |
+
|
| 35 |
+
• **Empirical performance improvements**: Demonstration that BioReason outperforms both DNA foundation models and LLMs used independently or in simple combination, with average performance gains of 15%+ over baseline.
|
| 36 |
+
|
| 37 |
+
• **Interpretable reasoning traces**: A mechanism for generating step-by-step biological reasoning traces that provide interpretable predictions, enhancing scientific insight and hypothesis generation.
|
| 38 |
+
|
| 39 |
+
<br>
|
| 40 |
+
|
| 41 |
+
## Datasets
|
| 42 |
+
|
| 43 |
+
The datasets used to train and evaluate BioReason can be found on our [HuggingFace collection](https://huggingface.co/collections/wanglab/bioreason-683cd17172a037a31d208f70) with detailed download and usage instructions.
|
| 44 |
+
|
| 45 |
+
<br>
|
| 46 |
+
|
| 47 |
+
## Checkpoints
|
| 48 |
+
|
| 49 |
+
We will release the checkpoints soon!
|
| 50 |
+
|
| 51 |
+
<br>
|
| 52 |
+
|
| 53 |
+
## Installation
|
| 54 |
+
|
| 55 |
+
### Prerequisites
|
| 56 |
+
- Python 3.11+
|
| 57 |
+
- CUDA/GPU for best performance
|
| 58 |
+
|
| 59 |
+
### Installation Steps
|
| 60 |
+
```bash
|
| 61 |
+
# Clone the repository
|
| 62 |
+
git clone https://github.com/bowang-lab/BioReason.git
|
| 63 |
+
cd BioReason
|
| 64 |
+
|
| 65 |
+
# Install package
|
| 66 |
+
pip install -e .
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
<br>
|
| 70 |
+
|
| 71 |
+
## Results
|
| 72 |
+
|
| 73 |
+
### KEGG-Derived Biological Reasoning Task
|
| 74 |
+
Performance comparison on 290 test datapoints for multi-step mechanistic reasoning:
|
| 75 |
+
|
| 76 |
+
| Model | Accuracy | F1-Score | Precision | Recall |
|
| 77 |
+
|-------|----------|----------|-----------|---------|
|
| 78 |
+
| [DNA] NT - 500M | 86.55 | 69.76 | 73.23 | 66.61 |
|
| 79 |
+
| [DNA] Evo2 - 1B | 88.28 | 72.43 | 75.23 | 69.83 |
|
| 80 |
+
| [LLM] Qwen3 - 1B | 85.17 | 65.71 | 71.39 | 64.19 |
|
| 81 |
+
| [LLM] Qwen3 - 4B | 93.48 | 85.44 | 88.31 | 86.72 |
|
| 82 |
+
| [DNA-LLM] NT + Qwen3 - 1B | 88.42 | 72.13 | 75.42 | 71.91 |
|
| 83 |
+
| [DNA-LLM] NT + Qwen3 - 1B (+RL) | 89.66 | 74.11 | 78.82 | 72.96 |
|
| 84 |
+
| [DNA-LLM] NT + Qwen3 - 4B | 96.90 | **89.03** | **90.99** | **89.38** |
|
| 85 |
+
| [DNA-LLM] Evo2 + Qwen3 - 1B | 90.42 | 75.62 | 77.42 | 73.91 |
|
| 86 |
+
| [DNA-LLM] Evo2 + Qwen3 - 4B | **97.24** | 86.30 | 86.75 | 87.25 |
|
| 87 |
+
|
| 88 |
+
### Variant Effect Prediction Benchmarks
|
| 89 |
+
Performance on pathogenic/benign classification:
|
| 90 |
+
|
| 91 |
+
| Model | Variant Effect - Coding | | Variant Effect - Non-SNV | |
|
| 92 |
+
|-------|------------|----------|------------|----------|
|
| 93 |
+
| | Accuracy | F1-Score | Accuracy | F1-Score |
|
| 94 |
+
| [DNA] NT - 500M | 60.91 | 45.20 | 67.93 | 65.97 |
|
| 95 |
+
| [DNA] Evo2 - 1B | 70.07 | 49.19 | 76.17 | 66.51 |
|
| 96 |
+
| [LLM] Qwen3 - 1B | 46.55 | 34.82 | 70.67 | 76.21 |
|
| 97 |
+
| [LLM] Qwen3 - 4B | 48.99 | 39.58 | 61.86 | 67.60 |
|
| 98 |
+
| [DNA-LLM] NT + Qwen3 - 1B | 55.58 | 54.50 | 72.82 | 76.93 |
|
| 99 |
+
| [DNA-LLM] NT + Qwen3 - 4B | 60.94 | 55.66 | 65.59 | 73.00 |
|
| 100 |
+
| [DNA-LLM] Evo2 + Qwen3 - 1B | 72.83 | 68.90 | **88.20** | **89.91** |
|
| 101 |
+
| [DNA-LLM] Evo2 + Qwen3 - 4B | **80.21** | **80.00** | 83.85 | 85.02 |
|
| 102 |
+
|
| 103 |
+
<br>
|
| 104 |
+
|
| 105 |
+
## Citation
|
| 106 |
+
|
| 107 |
+
If you find this work useful, please cite our paper:
|
| 108 |
+
|
| 109 |
+
```bibtex
|
| 110 |
+
@misc{fallahpour2025bioreasonincentivizingmultimodalbiological,
|
| 111 |
+
title={BioReason: Incentivizing Multimodal Biological Reasoning within a DNA-LLM Model},
|
| 112 |
+
author={Adibvafa Fallahpour and Andrew Magnuson and Purav Gupta and Shihao Ma and Jack Naimer and Arnav Shah and Haonan Duan and Omar Ibrahim and Hani Goodarzi and Chris J. Maddison and Bo Wang},
|
| 113 |
+
year={2025},
|
| 114 |
+
eprint={2505.23579},
|
| 115 |
+
archivePrefix={arXiv},
|
| 116 |
+
primaryClass={cs.LG},
|
| 117 |
+
url={https://arxiv.org/abs/2505.23579},
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
<br>
|
| 122 |
+
|
| 123 |
+
## Authors
|
| 124 |
+
|
| 125 |
+
- **Adibvafa Fallahpour**¹²³⁵ * (adibvafa.fallahpour@mail.utoronto.ca)
|
| 126 |
+
- **Andrew Magnuson**¹² *
|
| 127 |
+
- **Purav Gupta**¹² *
|
| 128 |
+
- **Shihao Ma**¹²³
|
| 129 |
+
- **Jack Naimer**¹²³
|
| 130 |
+
- **Arnav Shah**¹²³
|
| 131 |
+
- **Haonan Duan**¹²
|
| 132 |
+
- **Omar Ibrahim**³
|
| 133 |
+
- **Hani Goodarzi**†⁴⁶
|
| 134 |
+
- **Chris J. Maddison**†¹²⁷
|
| 135 |
+
- **Bo Wang**†¹²³
|
| 136 |
+
|
| 137 |
+
¹ University of Toronto ² Vector Institute ³ University Health Network (UHN) <br>
|
| 138 |
+
⁴ Arc Institute ⁵ Cohere ⁶ University of California, San Francisco ⁷ Google DeepMind
|
| 139 |
+
|
| 140 |
+
<br>
|
| 141 |
+
* Equal contribution <br>
|
| 142 |
+
† Equal advising
|
| 143 |
+
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
<p align="center">
|
| 147 |
+
Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
|
| 148 |
+
</p>
|
BioReason/bioreason/__init__.py
ADDED
|
File without changes
|
BioReason/bioreason/dataset/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .kegg import KEGGDataset, split_kegg_dataset
|
| 2 |
+
from .utils import torch_to_hf_dataset, truncate_dna
|
| 3 |
+
from .variant_effect import get_format_variant_effect_function
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"KEGGDataset",
|
| 7 |
+
"split_kegg_dataset",
|
| 8 |
+
"torch_to_hf_dataset",
|
| 9 |
+
"truncate_dna",
|
| 10 |
+
"get_format_variant_effect_function",
|
| 11 |
+
]
|
BioReason/bioreason/dataset/kegg.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import sys
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from typing import Any, Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
from bioreason.dataset.utils import torch_to_hf_dataset
|
| 10 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 11 |
+
from trl.data_utils import maybe_apply_chat_template
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class KEGGDataset(Dataset):
|
| 15 |
+
"""Dataset for KEGG data."""
|
| 16 |
+
|
| 17 |
+
def __init__(self, data_dir: str):
|
| 18 |
+
"""
|
| 19 |
+
Initialize the dataset by loading all JSON files from the given directory.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
data_dir: Path to the directory containing JSON files
|
| 23 |
+
"""
|
| 24 |
+
self.data_dir = data_dir
|
| 25 |
+
self.data = []
|
| 26 |
+
|
| 27 |
+
# Load all JSON files
|
| 28 |
+
json_files = sorted([f for f in os.listdir(data_dir) if f.endswith(".json")])
|
| 29 |
+
|
| 30 |
+
# Process each file
|
| 31 |
+
for filename in json_files:
|
| 32 |
+
file_path = os.path.join(data_dir, filename)
|
| 33 |
+
kegg_id = filename.split("_")[1]
|
| 34 |
+
|
| 35 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 36 |
+
item = json.load(f)
|
| 37 |
+
item["kegg_id"] = kegg_id
|
| 38 |
+
processed_item = self._process_item(item)
|
| 39 |
+
self.data.append(processed_item)
|
| 40 |
+
|
| 41 |
+
def _process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
| 42 |
+
"""
|
| 43 |
+
Process a single data item to format fields as required.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
item: Original data item from JSON
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
Processed data item
|
| 50 |
+
"""
|
| 51 |
+
# Extract question as is
|
| 52 |
+
question = item.get("question", "")
|
| 53 |
+
|
| 54 |
+
# Convert answer to lowercase and strip whitespace
|
| 55 |
+
answer = item.get("answer", "").lower().strip()
|
| 56 |
+
|
| 57 |
+
# Combine reasoning steps into a single paragraph with newlines
|
| 58 |
+
reasoning_steps = item.get("reasoning", {}).get("reasoning_steps", [])
|
| 59 |
+
reasoning = "\n".join(reasoning_steps)
|
| 60 |
+
|
| 61 |
+
# Convert sequences to uppercase and strip whitespace
|
| 62 |
+
reference_sequence = item.get("reference_sequence", "").upper().strip()
|
| 63 |
+
variant_sequence = item.get("variant_sequence", "").upper().strip()
|
| 64 |
+
|
| 65 |
+
return {
|
| 66 |
+
"question": question,
|
| 67 |
+
"answer": answer,
|
| 68 |
+
"reasoning": reasoning,
|
| 69 |
+
"reference_sequence": reference_sequence,
|
| 70 |
+
"variant_sequence": variant_sequence,
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
def __len__(self) -> int:
|
| 74 |
+
"""Return the number of items in the dataset."""
|
| 75 |
+
return len(self.data)
|
| 76 |
+
|
| 77 |
+
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
| 78 |
+
"""Return a specific item from the dataset."""
|
| 79 |
+
return self.data[idx]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def split_kegg_dataset(
|
| 83 |
+
dataset: KEGGDataset,
|
| 84 |
+
train_ratio: float = 0.8,
|
| 85 |
+
val_ratio: float = 0.1,
|
| 86 |
+
test_ratio: float = 0.1,
|
| 87 |
+
seed: int = 42,
|
| 88 |
+
) -> Tuple[KEGGDataset, KEGGDataset, KEGGDataset]:
|
| 89 |
+
"""
|
| 90 |
+
Split a KEGG dataset into train, validation, and test sets.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
dataset: The dataset to split
|
| 94 |
+
train_ratio: Proportion of data for training
|
| 95 |
+
val_ratio: Proportion of data for validation
|
| 96 |
+
test_ratio: Proportion of data for testing
|
| 97 |
+
batch_size: Batch size for the dataloaders
|
| 98 |
+
seed: Random seed for reproducibility
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
Tuple of (train_dataset, val_dataset, test_dataset)
|
| 102 |
+
"""
|
| 103 |
+
# Calculate the size of each split
|
| 104 |
+
dataset_size = len(dataset)
|
| 105 |
+
train_size = int(train_ratio * dataset_size)
|
| 106 |
+
val_size = int(val_ratio * dataset_size)
|
| 107 |
+
test_size = dataset_size - train_size - val_size
|
| 108 |
+
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must sum to 1"
|
| 109 |
+
|
| 110 |
+
# Set the random seed
|
| 111 |
+
torch.manual_seed(seed)
|
| 112 |
+
random.seed(seed)
|
| 113 |
+
|
| 114 |
+
# Split the dataset
|
| 115 |
+
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
|
| 116 |
+
dataset, [train_size, val_size, test_size]
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
return train_dataset, val_dataset, test_dataset
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def create_kegg_dataloader(
|
| 123 |
+
data_dir: str,
|
| 124 |
+
batch_size: int = 2,
|
| 125 |
+
shuffle: bool = True,
|
| 126 |
+
num_workers: int = 2,
|
| 127 |
+
pin_memory: bool = True,
|
| 128 |
+
) -> DataLoader:
|
| 129 |
+
"""
|
| 130 |
+
Create a DataLoader for the KEGG dataset.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
data_dir: Path to the directory containing JSON files
|
| 134 |
+
batch_size: Batch size for the dataloader
|
| 135 |
+
shuffle: Whether to shuffle the data
|
| 136 |
+
num_workers: Number of worker processes for loading data
|
| 137 |
+
pin_memory: Whether to pin memory for faster data transfer
|
| 138 |
+
|
| 139 |
+
Returns:
|
| 140 |
+
DataLoader for the KEGG dataset
|
| 141 |
+
"""
|
| 142 |
+
dataset = KEGGDataset(data_dir)
|
| 143 |
+
return DataLoader(
|
| 144 |
+
dataset,
|
| 145 |
+
batch_size=batch_size,
|
| 146 |
+
shuffle=shuffle,
|
| 147 |
+
num_workers=num_workers,
|
| 148 |
+
pin_memory=pin_memory,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def get_format_kegg_function(model_name: str) -> Any:
|
| 153 |
+
"""
|
| 154 |
+
Get the appropriate format function for a given model name.
|
| 155 |
+
"""
|
| 156 |
+
if model_name.lower() == "llm":
|
| 157 |
+
return format_kegg_for_llm
|
| 158 |
+
elif model_name.lower() == "dna-llm":
|
| 159 |
+
return format_kegg_for_dna_llm
|
| 160 |
+
else:
|
| 161 |
+
raise ValueError(f"Unsupported model name: {model_name}")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def format_kegg_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 165 |
+
"""
|
| 166 |
+
Format a KEGG example into the required chat format for DNA-LLM.
|
| 167 |
+
"""
|
| 168 |
+
return {
|
| 169 |
+
"prompt": [
|
| 170 |
+
{
|
| 171 |
+
"role": "user",
|
| 172 |
+
"content": [
|
| 173 |
+
*({"type": "dna", "text": None} for _ in range(2)),
|
| 174 |
+
{"type": "text", "text": example["question"].strip()},
|
| 175 |
+
],
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"role": "assistant",
|
| 179 |
+
"reasoning_content": example["reasoning"].strip(),
|
| 180 |
+
"content": [
|
| 181 |
+
{"type": "text", "text": f"Answer: {example['answer'].strip()}"},
|
| 182 |
+
],
|
| 183 |
+
},
|
| 184 |
+
],
|
| 185 |
+
"dna_sequences": [
|
| 186 |
+
example["reference_sequence"],
|
| 187 |
+
example["variant_sequence"],
|
| 188 |
+
],
|
| 189 |
+
"answer": example["answer"],
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def format_kegg_for_llm(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 194 |
+
"""
|
| 195 |
+
Format a KEGG example into the required chat format for LLM.
|
| 196 |
+
"""
|
| 197 |
+
question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}"
|
| 198 |
+
return {
|
| 199 |
+
"prompt": [
|
| 200 |
+
{
|
| 201 |
+
"role": "user",
|
| 202 |
+
"content": [
|
| 203 |
+
*({"type": "dna", "text": None} for _ in range(2)),
|
| 204 |
+
{"type": "text", "text": question.strip()},
|
| 205 |
+
],
|
| 206 |
+
},
|
| 207 |
+
{
|
| 208 |
+
"role": "assistant",
|
| 209 |
+
"reasoning_content": example["reasoning"].strip(),
|
| 210 |
+
"content": [
|
| 211 |
+
{"type": "text", "text": f"Answer: {example['answer'].strip()}"},
|
| 212 |
+
],
|
| 213 |
+
},
|
| 214 |
+
],
|
| 215 |
+
"dna_sequences": [
|
| 216 |
+
"",
|
| 217 |
+
"",
|
| 218 |
+
],
|
| 219 |
+
"answer": example["answer"],
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def qwen_dna_collate_fn(
|
| 224 |
+
examples: List[Dict],
|
| 225 |
+
processor: DLProcessor,
|
| 226 |
+
max_length_text: int,
|
| 227 |
+
max_length_dna: int,
|
| 228 |
+
return_answer_in_batch: bool = False,
|
| 229 |
+
) -> Dict:
|
| 230 |
+
"""
|
| 231 |
+
Custom collate function for Qwen DNA models.
|
| 232 |
+
|
| 233 |
+
Creates a batch with proper labels for supervised fine-tuning where only
|
| 234 |
+
the assistant responses contribute to the loss calculation.
|
| 235 |
+
"""
|
| 236 |
+
prompts_text = [
|
| 237 |
+
maybe_apply_chat_template(example, processor)["prompt"] for example in examples
|
| 238 |
+
]
|
| 239 |
+
batch_dna_sequences = [example["dna_sequences"] for example in examples]
|
| 240 |
+
|
| 241 |
+
batch = processor(
|
| 242 |
+
text=prompts_text,
|
| 243 |
+
batch_dna_sequences=batch_dna_sequences,
|
| 244 |
+
return_tensors="pt",
|
| 245 |
+
padding=True,
|
| 246 |
+
padding_side="left",
|
| 247 |
+
add_special_tokens=False,
|
| 248 |
+
max_length_text=max_length_text,
|
| 249 |
+
max_length_dna=max_length_dna,
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
# Create labels tensor filled with -100 (ignored in loss calculation)
|
| 253 |
+
labels = torch.full_like(batch["input_ids"], -100)
|
| 254 |
+
|
| 255 |
+
# Get token IDs for special markers
|
| 256 |
+
assistant_start_marker = "<|im_start|>assistant\n"
|
| 257 |
+
im_end_marker = "<|im_end|>"
|
| 258 |
+
|
| 259 |
+
assistant_start_token_ids = processor.tokenizer.encode(
|
| 260 |
+
assistant_start_marker, add_special_tokens=False
|
| 261 |
+
)
|
| 262 |
+
im_end_token_ids = processor.tokenizer.encode(
|
| 263 |
+
im_end_marker, add_special_tokens=False
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Convert token arrays to tensors for faster comparison
|
| 267 |
+
assistant_marker_tensor = torch.tensor(
|
| 268 |
+
assistant_start_token_ids, device=batch["input_ids"].device
|
| 269 |
+
)
|
| 270 |
+
im_end_marker_tensor = torch.tensor(
|
| 271 |
+
im_end_token_ids, device=batch["input_ids"].device
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Get dimensions for easier reference
|
| 275 |
+
assistant_marker_len = len(assistant_start_token_ids)
|
| 276 |
+
im_end_marker_len = len(im_end_token_ids)
|
| 277 |
+
|
| 278 |
+
# For each sequence in the batch
|
| 279 |
+
for i in range(batch["input_ids"].shape[0]):
|
| 280 |
+
input_ids = batch["input_ids"][i]
|
| 281 |
+
seq_len = input_ids.size(0)
|
| 282 |
+
|
| 283 |
+
# Track assistant sections
|
| 284 |
+
assistant_sections = []
|
| 285 |
+
|
| 286 |
+
# Find all assistant start markers
|
| 287 |
+
start_positions = []
|
| 288 |
+
for pos in range(seq_len - assistant_marker_len + 1):
|
| 289 |
+
if torch.all(
|
| 290 |
+
input_ids[pos : pos + assistant_marker_len] == assistant_marker_tensor
|
| 291 |
+
):
|
| 292 |
+
start_positions.append(
|
| 293 |
+
pos + assistant_marker_len
|
| 294 |
+
) # Store position after marker
|
| 295 |
+
|
| 296 |
+
# Find all end markers
|
| 297 |
+
end_positions = []
|
| 298 |
+
for pos in range(seq_len - im_end_marker_len + 1):
|
| 299 |
+
if torch.all(
|
| 300 |
+
input_ids[pos : pos + im_end_marker_len] == im_end_marker_tensor
|
| 301 |
+
):
|
| 302 |
+
end_positions.append(pos) # Store position at start of end marker
|
| 303 |
+
|
| 304 |
+
# Match start and end markers to create sections
|
| 305 |
+
for start_pos in start_positions:
|
| 306 |
+
# Find the next end marker after this start position
|
| 307 |
+
valid_ends = [pos for pos in end_positions if pos > start_pos]
|
| 308 |
+
if valid_ends:
|
| 309 |
+
end_pos = min(valid_ends) # Take the first end marker after start
|
| 310 |
+
# Only include content between markers (not the markers themselves)
|
| 311 |
+
if start_pos < end_pos:
|
| 312 |
+
assistant_sections.append((start_pos, end_pos))
|
| 313 |
+
else:
|
| 314 |
+
# If no end marker, assume the section runs to the end of the sequence
|
| 315 |
+
assistant_sections.append((start_pos, seq_len))
|
| 316 |
+
|
| 317 |
+
# Set labels for all identified assistant sections
|
| 318 |
+
for start_pos, end_pos in assistant_sections:
|
| 319 |
+
if start_pos < end_pos and start_pos < seq_len:
|
| 320 |
+
end_pos = min(end_pos, seq_len) # Safety check
|
| 321 |
+
labels[i, start_pos:end_pos] = input_ids[start_pos:end_pos]
|
| 322 |
+
|
| 323 |
+
# Also mask padding tokens
|
| 324 |
+
labels[batch["input_ids"] == processor.tokenizer.pad_token_id] = -100
|
| 325 |
+
|
| 326 |
+
# Add labels to batch
|
| 327 |
+
batch["labels"] = labels
|
| 328 |
+
|
| 329 |
+
# Add answer to batch
|
| 330 |
+
if return_answer_in_batch:
|
| 331 |
+
batch["answer"] = [example["answer"].strip() for example in examples]
|
| 332 |
+
|
| 333 |
+
return batch
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def dna_collate_fn(
|
| 337 |
+
batch: List[Dict[str, Any]],
|
| 338 |
+
dna_tokenizer: Any,
|
| 339 |
+
label2id: Dict[str, int],
|
| 340 |
+
max_length: int = 2048,
|
| 341 |
+
) -> Dict[str, Any]:
|
| 342 |
+
"""
|
| 343 |
+
Custom collate function for DNA models.
|
| 344 |
+
"""
|
| 345 |
+
ref_sequences = [item["reference_sequence"] for item in batch]
|
| 346 |
+
alt_sequences = [item["variant_sequence"] for item in batch]
|
| 347 |
+
|
| 348 |
+
# Tokenize DNA sequences separately
|
| 349 |
+
tokenized_ref = dna_tokenizer(
|
| 350 |
+
ref_sequences,
|
| 351 |
+
padding=True,
|
| 352 |
+
truncation=True,
|
| 353 |
+
max_length=max_length,
|
| 354 |
+
return_tensors="pt",
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
tokenized_alt = dna_tokenizer(
|
| 358 |
+
alt_sequences,
|
| 359 |
+
padding=True,
|
| 360 |
+
truncation=True,
|
| 361 |
+
max_length=max_length,
|
| 362 |
+
return_tensors="pt",
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Get labels
|
| 366 |
+
labels = []
|
| 367 |
+
for item in batch:
|
| 368 |
+
label = label2id[item["answer"]]
|
| 369 |
+
labels.append(label)
|
| 370 |
+
|
| 371 |
+
# Create labels tensor
|
| 372 |
+
labels_tensor = torch.tensor(labels, dtype=torch.long)
|
| 373 |
+
|
| 374 |
+
tokenized_batch = {
|
| 375 |
+
"ref_ids": tokenized_ref.input_ids,
|
| 376 |
+
"ref_attention_mask": tokenized_ref.attention_mask,
|
| 377 |
+
"alt_ids": tokenized_alt.input_ids,
|
| 378 |
+
"alt_attention_mask": tokenized_alt.attention_mask,
|
| 379 |
+
"labels": labels_tensor,
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
return tokenized_batch
|
BioReason/bioreason/dataset/utils.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datasets import Dataset as HFDataset
|
| 2 |
+
from torch.utils.data import Dataset as TorchDataset
|
| 3 |
+
from typing import Dict, Any, Union, List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def truncate_dna(
|
| 7 |
+
example: Dict[str, Any], truncate_dna_per_side: int = 1024
|
| 8 |
+
) -> Dict[str, Any]:
|
| 9 |
+
"""
|
| 10 |
+
Truncate DNA sequences by removing a specified number of base pairs from both ends.
|
| 11 |
+
If the sequence is too short, it will return the middle portion.
|
| 12 |
+
"""
|
| 13 |
+
for key in ["reference_sequence", "variant_sequence"]:
|
| 14 |
+
sequence = example[key]
|
| 15 |
+
seq_len = len(sequence)
|
| 16 |
+
|
| 17 |
+
if seq_len > 2 * truncate_dna_per_side + 8:
|
| 18 |
+
example[key] = sequence[truncate_dna_per_side:-truncate_dna_per_side]
|
| 19 |
+
|
| 20 |
+
return example
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def torch_to_hf_dataset(torch_dataset: TorchDataset) -> HFDataset:
|
| 24 |
+
"""
|
| 25 |
+
Convert a PyTorch Dataset to a Hugging Face Dataset.
|
| 26 |
+
|
| 27 |
+
This function takes a PyTorch Dataset and converts it to a Hugging Face Dataset
|
| 28 |
+
by extracting all items and organizing them into a dictionary structure that
|
| 29 |
+
can be used to create a Hugging Face Dataset.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
torch_dataset: A PyTorch Dataset object to be converted
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
A Hugging Face Dataset containing the same data as the input PyTorch Dataset
|
| 36 |
+
"""
|
| 37 |
+
# Get first item to determine structure
|
| 38 |
+
if len(torch_dataset) == 0:
|
| 39 |
+
return HFDataset.from_dict({})
|
| 40 |
+
|
| 41 |
+
first_item = torch_dataset[0]
|
| 42 |
+
|
| 43 |
+
# Initialize dictionary based on first item's keys
|
| 44 |
+
data_dict = (
|
| 45 |
+
{k: [] for k in first_item.keys()}
|
| 46 |
+
if isinstance(first_item, dict)
|
| 47 |
+
else {"data": []}
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Populate dictionary
|
| 51 |
+
for i in range(len(torch_dataset)):
|
| 52 |
+
item = torch_dataset[i]
|
| 53 |
+
if isinstance(item, dict):
|
| 54 |
+
for k in data_dict:
|
| 55 |
+
data_dict[k].append(item[k])
|
| 56 |
+
else:
|
| 57 |
+
data_dict["data"].append(item)
|
| 58 |
+
|
| 59 |
+
return HFDataset.from_dict(data_dict)
|
BioReason/bioreason/dataset/variant_effect.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import sys
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset, DataLoader
|
| 7 |
+
from typing import Any, Dict, List, Tuple
|
| 8 |
+
|
| 9 |
+
from bioreason.dataset.utils import torch_to_hf_dataset
|
| 10 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 11 |
+
from trl.data_utils import maybe_apply_chat_template
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_format_variant_effect_function(model_name: str) -> Any:
|
| 15 |
+
"""
|
| 16 |
+
Get the appropriate format function for a given model name.
|
| 17 |
+
"""
|
| 18 |
+
if model_name.lower() == "llm":
|
| 19 |
+
return format_variant_effect_for_llm
|
| 20 |
+
elif model_name.lower() == "dna-llm":
|
| 21 |
+
return format_variant_effect_for_dna_llm
|
| 22 |
+
else:
|
| 23 |
+
raise ValueError(f"Unsupported model name: {model_name}")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def clean_variant_effect_example(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 27 |
+
"""
|
| 28 |
+
Clean a variant effect example.
|
| 29 |
+
"""
|
| 30 |
+
example['answer'] = example['answer'].split(";")[0].strip().lower()
|
| 31 |
+
return example
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def clean_variant_effect_non_snv_example(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 35 |
+
"""
|
| 36 |
+
Clean a variant effect non-SNV example.
|
| 37 |
+
"""
|
| 38 |
+
example['answer'] = example['answer'].replace("[", "").replace("]", "").replace("'", "").replace("_", " ").strip()
|
| 39 |
+
return example
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def format_variant_effect_for_dna_llm(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 43 |
+
"""
|
| 44 |
+
Format a VEP example into the required chat format for DNA-LLM.
|
| 45 |
+
"""
|
| 46 |
+
return {
|
| 47 |
+
"prompt": [
|
| 48 |
+
{
|
| 49 |
+
"role": "user",
|
| 50 |
+
"content": [
|
| 51 |
+
*({"type": "dna", "text": None} for _ in range(2)),
|
| 52 |
+
{"type": "text", "text": example["question"].strip()},
|
| 53 |
+
],
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"role": "assistant",
|
| 57 |
+
"reasoning_content": f"Answer: {example['answer'].strip()}",
|
| 58 |
+
"content": [
|
| 59 |
+
{"type": "text", "text": f"Answer: {example['answer'].strip()}"},
|
| 60 |
+
],
|
| 61 |
+
},
|
| 62 |
+
],
|
| 63 |
+
"dna_sequences": [
|
| 64 |
+
example["reference_sequence"],
|
| 65 |
+
example["variant_sequence"],
|
| 66 |
+
],
|
| 67 |
+
"answer": example["answer"].strip(),
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def format_variant_effect_for_llm(example: Dict[str, Any]) -> Dict[str, Any]:
|
| 72 |
+
"""
|
| 73 |
+
Format a VEP example into the required chat format for LLM.
|
| 74 |
+
"""
|
| 75 |
+
question = f"Reference sequence: {example['reference_sequence']}\nVariant sequence: {example['variant_sequence']}\nQuestion: {example['question']}"
|
| 76 |
+
return {
|
| 77 |
+
"prompt": [
|
| 78 |
+
{
|
| 79 |
+
"role": "user",
|
| 80 |
+
"content": [
|
| 81 |
+
*({"type": "dna", "text": None} for _ in range(2)),
|
| 82 |
+
{"type": "text", "text": question.strip()},
|
| 83 |
+
],
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"role": "assistant",
|
| 87 |
+
"reasoning_content": f"Answer: {example['answer'].strip()}",
|
| 88 |
+
"content": [
|
| 89 |
+
{"type": "text", "text": f"Answer: {example['answer'].strip()}"},
|
| 90 |
+
],
|
| 91 |
+
},
|
| 92 |
+
],
|
| 93 |
+
"dna_sequences": [
|
| 94 |
+
"",
|
| 95 |
+
"",
|
| 96 |
+
],
|
| 97 |
+
"answer": example["answer"].strip(),
|
| 98 |
+
}
|
BioReason/bioreason/dna_modules/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dna_module import DNABaseModule
|
| 2 |
+
from .nucleotide_module import NucleotideDNAModule
|
| 3 |
+
|
| 4 |
+
__all__ = ["DNABaseModule", "NucleotideDNAModule"]
|
BioReason/bioreason/dna_modules/dna_module.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Dict, Any, Union
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class DNABaseModule(ABC):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
@abstractmethod
|
| 10 |
+
def get_dnallm_key(self):
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
@abstractmethod
|
| 14 |
+
def get_model_class(self, model_id: str, model_init_kwargs: dict):
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def post_model_init(self, model, processing_class):
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
def is_embeds_input(self):
|
| 21 |
+
return False
|
| 22 |
+
|
| 23 |
+
@abstractmethod
|
| 24 |
+
def get_processing_class(self):
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
@abstractmethod
|
| 28 |
+
def get_dnallm_modules_keywords(self):
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
@abstractmethod
|
| 32 |
+
def get_custom_multimodal_keywords(self):
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
@abstractmethod
|
| 36 |
+
def get_non_generate_params(self):
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
@abstractmethod
|
| 40 |
+
def get_custom_processing_keywords(self):
|
| 41 |
+
pass
|
| 42 |
+
|
| 43 |
+
@abstractmethod
|
| 44 |
+
def prepare_prompt(self, processing_class, inputs: dict[str, Union[torch.Tensor, Any]]):
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
@abstractmethod
|
| 48 |
+
def prepare_model_inputs(self, processing_class, prompts_text, images, return_tensors, padding, padding_side, add_special_tokens):
|
| 49 |
+
pass
|
BioReason/bioreason/dna_modules/esm_protein_module.py
ADDED
|
@@ -0,0 +1,649 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, Union, List, Optional, Callable, Type
|
| 2 |
+
from trl.data_utils import maybe_apply_chat_template
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
from bioreason.protein_modules.protein_base_module import ProteinBaseModule
|
| 9 |
+
from bioreason.models.protein_llm.protein_llm import ProteinLLMModel
|
| 10 |
+
from bioreason.models.protein_llm.processing_protein import ProteinLLMProcessor
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ESMQFormerModule(ProteinBaseModule):
|
| 14 |
+
"""
|
| 15 |
+
Protein module implementation for ESM2 + Q-Former based models.
|
| 16 |
+
|
| 17 |
+
This module provides the interface between Protein-LLM models and the training
|
| 18 |
+
infrastructure, handling model loading, processing setup, and reward functions
|
| 19 |
+
specifically for the ESM2 + Q-Former + LLM architecture.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self):
|
| 23 |
+
"""Initialize the ESMQFormerModule."""
|
| 24 |
+
super().__init__()
|
| 25 |
+
|
| 26 |
+
def get_protein_llm_key(self) -> str:
|
| 27 |
+
"""
|
| 28 |
+
Get the key identifier for this protein-LLM implementation.
|
| 29 |
+
|
| 30 |
+
Returns:
|
| 31 |
+
String identifier for this module type
|
| 32 |
+
"""
|
| 33 |
+
return "esm_qformer"
|
| 34 |
+
|
| 35 |
+
def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
|
| 36 |
+
"""
|
| 37 |
+
Return the appropriate model class based on model ID.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model_id: Identifier for the model
|
| 41 |
+
model_init_kwargs: Initialization arguments for the model
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
The model class to instantiate
|
| 45 |
+
|
| 46 |
+
Raises:
|
| 47 |
+
ValueError: If the model is not supported
|
| 48 |
+
"""
|
| 49 |
+
if "ProteinLLM" in model_id or "ESM" in model_id:
|
| 50 |
+
return ProteinLLMModel
|
| 51 |
+
else:
|
| 52 |
+
raise ValueError(f"Unsupported model for ESM-QFormer module: {model_id}")
|
| 53 |
+
|
| 54 |
+
def post_model_init(self, model: Any, processing_class: Any) -> None:
|
| 55 |
+
"""
|
| 56 |
+
Perform any post-initialization setup on the model.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
model: The initialized model
|
| 60 |
+
processing_class: The processor for the model
|
| 61 |
+
"""
|
| 62 |
+
# Ensure special tokens are properly set
|
| 63 |
+
if hasattr(model, 'text_tokenizer') and hasattr(model.text_tokenizer, 'pad_token_id'):
|
| 64 |
+
if model.text_tokenizer.pad_token_id is None:
|
| 65 |
+
model.text_tokenizer.pad_token_id = model.text_tokenizer.eos_token_id
|
| 66 |
+
|
| 67 |
+
# Set up Q-Former query token count for batch processing
|
| 68 |
+
if hasattr(processing_class, 'set_query_tokens'):
|
| 69 |
+
processing_class.set_query_tokens(model.qformer_num_query_tokens)
|
| 70 |
+
|
| 71 |
+
def get_processing_class(self) -> Type:
|
| 72 |
+
"""
|
| 73 |
+
Get the processing class to use with this protein-LLM model.
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
The processing class
|
| 77 |
+
"""
|
| 78 |
+
return ProteinLLMProcessor
|
| 79 |
+
|
| 80 |
+
def get_protein_llm_modules_keywords(self) -> List[str]:
|
| 81 |
+
"""
|
| 82 |
+
Get keywords to identify protein-specific modules in the model.
|
| 83 |
+
|
| 84 |
+
Used to exclude protein modules from LoRA adaptation during training.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
List of keywords that identify protein modules
|
| 88 |
+
"""
|
| 89 |
+
return [
|
| 90 |
+
"protein",
|
| 91 |
+
"esm",
|
| 92 |
+
"qformer",
|
| 93 |
+
"protein_projection",
|
| 94 |
+
"query_tokens",
|
| 95 |
+
"protein_model",
|
| 96 |
+
]
|
| 97 |
+
|
| 98 |
+
def get_custom_multimodal_keywords(self) -> List[str]:
|
| 99 |
+
"""
|
| 100 |
+
Get keywords for multimodal inputs that should be passed to the model.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
List of input keywords for multimodal processing
|
| 104 |
+
"""
|
| 105 |
+
return ["protein_tokenized", "batch_idx_map"]
|
| 106 |
+
|
| 107 |
+
def get_non_generate_params(self) -> List[str]:
|
| 108 |
+
"""
|
| 109 |
+
Get parameter names that should be excluded from generation.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
List of parameter names to exclude from generation calls
|
| 113 |
+
"""
|
| 114 |
+
return ["labels"]
|
| 115 |
+
|
| 116 |
+
def get_custom_processing_keywords(self) -> List[tuple]:
|
| 117 |
+
"""
|
| 118 |
+
Get custom processing keywords for the processor.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
List of (component, parameter) tuples for custom processing
|
| 122 |
+
"""
|
| 123 |
+
return [
|
| 124 |
+
("protein_tokenizer", "max_length_protein"),
|
| 125 |
+
("text_tokenizer", "max_length_text"),
|
| 126 |
+
("qformer", "num_query_tokens"),
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
def prepare_prompt(
|
| 130 |
+
self,
|
| 131 |
+
processing_class: Any,
|
| 132 |
+
inputs: List[Dict[str, Union[torch.Tensor, Any]]]
|
| 133 |
+
) -> List[str]:
|
| 134 |
+
"""
|
| 135 |
+
Prepare prompts from input examples.
|
| 136 |
+
|
| 137 |
+
Args:
|
| 138 |
+
processing_class: The processor to use
|
| 139 |
+
inputs: List of input examples
|
| 140 |
+
|
| 141 |
+
Returns:
|
| 142 |
+
List of prepared prompts
|
| 143 |
+
"""
|
| 144 |
+
prompts_text = []
|
| 145 |
+
for example in inputs:
|
| 146 |
+
# Apply chat template if available
|
| 147 |
+
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'chat_template'):
|
| 148 |
+
formatted_prompt = maybe_apply_chat_template(example, processing_class)["prompt"]
|
| 149 |
+
else:
|
| 150 |
+
# Fallback to simple formatting
|
| 151 |
+
formatted_prompt = example.get("prompt", example.get("text", str(example)))
|
| 152 |
+
prompts_text.append(formatted_prompt)
|
| 153 |
+
|
| 154 |
+
return prompts_text
|
| 155 |
+
|
| 156 |
+
def prepare_model_inputs(
|
| 157 |
+
self,
|
| 158 |
+
processing_class: Any,
|
| 159 |
+
model: Any,
|
| 160 |
+
prompts_text: List[str],
|
| 161 |
+
batch_protein_sequences: List[List[str]],
|
| 162 |
+
return_tensors: str = "pt",
|
| 163 |
+
padding: bool = True,
|
| 164 |
+
padding_side: str = "left",
|
| 165 |
+
add_special_tokens: bool = False,
|
| 166 |
+
**kwargs
|
| 167 |
+
) -> Dict[str, Any]:
|
| 168 |
+
"""
|
| 169 |
+
Prepare inputs for the model.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
processing_class: The processor to use
|
| 173 |
+
model: The model to prepare inputs for
|
| 174 |
+
prompts_text: List of text prompts
|
| 175 |
+
batch_protein_sequences: List of lists of protein sequences
|
| 176 |
+
return_tensors: Return format for tensors
|
| 177 |
+
padding: Whether to pad inputs
|
| 178 |
+
padding_side: Side to pad on
|
| 179 |
+
add_special_tokens: Whether to add special tokens
|
| 180 |
+
**kwargs: Additional arguments
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Processed inputs for the model
|
| 184 |
+
"""
|
| 185 |
+
# Handle DataParallel wrapped models
|
| 186 |
+
actual_model = model.module if hasattr(model, 'module') else model
|
| 187 |
+
|
| 188 |
+
# Get model parameters
|
| 189 |
+
max_length_text = getattr(actual_model, 'max_length_text', 512)
|
| 190 |
+
max_length_protein = getattr(actual_model, 'max_length_protein', 1024)
|
| 191 |
+
|
| 192 |
+
# Process inputs using the processor
|
| 193 |
+
prompt_inputs = processing_class(
|
| 194 |
+
text=prompts_text,
|
| 195 |
+
proteins=batch_protein_sequences,
|
| 196 |
+
return_tensors=return_tensors,
|
| 197 |
+
padding=padding,
|
| 198 |
+
padding_side=padding_side,
|
| 199 |
+
add_special_tokens=add_special_tokens,
|
| 200 |
+
max_length_text=max_length_text,
|
| 201 |
+
max_length_protein=max_length_protein,
|
| 202 |
+
**kwargs
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
return prompt_inputs
|
| 206 |
+
|
| 207 |
+
@staticmethod
|
| 208 |
+
def get_question_template() -> str:
|
| 209 |
+
"""
|
| 210 |
+
Get the template for formatting questions.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
String template for questions
|
| 214 |
+
"""
|
| 215 |
+
return "Question: {Question}\nAnswer:"
|
| 216 |
+
|
| 217 |
+
@staticmethod
|
| 218 |
+
def accuracy_reward(
|
| 219 |
+
completions: List[Dict[str, Any]],
|
| 220 |
+
ground_truth: List[str],
|
| 221 |
+
**kwargs
|
| 222 |
+
) -> List[float]:
|
| 223 |
+
"""
|
| 224 |
+
Accuracy reward function for protein tasks.
|
| 225 |
+
|
| 226 |
+
Compares model outputs with ground truth using exact match and semantic similarity.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
completions: List of model completions
|
| 230 |
+
ground_truth: List of ground truth answers
|
| 231 |
+
**kwargs: Additional arguments including:
|
| 232 |
+
- match_type: 'exact', 'partial', 'semantic' (default: 'partial')
|
| 233 |
+
- case_sensitive: bool (default: False)
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
List of reward scores (0.0 to 1.0)
|
| 237 |
+
"""
|
| 238 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 239 |
+
match_type = kwargs.get('match_type', 'partial')
|
| 240 |
+
case_sensitive = kwargs.get('case_sensitive', False)
|
| 241 |
+
rewards = []
|
| 242 |
+
|
| 243 |
+
for content, truth in zip(completion_contents, ground_truth):
|
| 244 |
+
if not case_sensitive:
|
| 245 |
+
content = content.lower().strip()
|
| 246 |
+
truth = truth.lower().strip()
|
| 247 |
+
else:
|
| 248 |
+
content = content.strip()
|
| 249 |
+
truth = truth.strip()
|
| 250 |
+
|
| 251 |
+
if match_type == 'exact':
|
| 252 |
+
# Exact string match
|
| 253 |
+
reward = 1.0 if content == truth else 0.0
|
| 254 |
+
elif match_type == 'partial':
|
| 255 |
+
# Partial matching using token overlap
|
| 256 |
+
content_tokens = set(content.split())
|
| 257 |
+
truth_tokens = set(truth.split())
|
| 258 |
+
if len(truth_tokens) == 0:
|
| 259 |
+
reward = 0.0
|
| 260 |
+
else:
|
| 261 |
+
overlap = len(content_tokens & truth_tokens)
|
| 262 |
+
reward = overlap / len(truth_tokens)
|
| 263 |
+
elif match_type == 'semantic':
|
| 264 |
+
# Simple semantic matching using key biological terms
|
| 265 |
+
bio_terms_content = set(re.findall(r'\b(?:protein|enzyme|binding|domain|function|structure|fold|helix|sheet|active|site|residue|amino|acid)\b', content.lower()))
|
| 266 |
+
bio_terms_truth = set(re.findall(r'\b(?:protein|enzyme|binding|domain|function|structure|fold|helix|sheet|active|site|residue|amino|acid)\b', truth.lower()))
|
| 267 |
+
|
| 268 |
+
if len(bio_terms_truth) == 0:
|
| 269 |
+
# Fallback to partial matching
|
| 270 |
+
content_tokens = set(content.split())
|
| 271 |
+
truth_tokens = set(truth.split())
|
| 272 |
+
overlap = len(content_tokens & truth_tokens)
|
| 273 |
+
reward = overlap / len(truth_tokens) if len(truth_tokens) > 0 else 0.0
|
| 274 |
+
else:
|
| 275 |
+
overlap = len(bio_terms_content & bio_terms_truth)
|
| 276 |
+
reward = overlap / len(bio_terms_truth)
|
| 277 |
+
else:
|
| 278 |
+
reward = 0.0
|
| 279 |
+
|
| 280 |
+
rewards.append(min(1.0, max(0.0, reward))) # Clamp between 0 and 1
|
| 281 |
+
|
| 282 |
+
return rewards
|
| 283 |
+
|
| 284 |
+
@staticmethod
|
| 285 |
+
def repetition_reward(
|
| 286 |
+
completions: List[Dict[str, Any]],
|
| 287 |
+
**kwargs
|
| 288 |
+
) -> List[float]:
|
| 289 |
+
"""
|
| 290 |
+
Repetition penalty reward function.
|
| 291 |
+
|
| 292 |
+
Penalizes outputs with excessive repetition of words, phrases, or patterns.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
completions: List of model completions
|
| 296 |
+
**kwargs: Additional arguments including:
|
| 297 |
+
- window_size: Size of n-gram window to check (default: 3)
|
| 298 |
+
- repetition_threshold: Maximum allowed repetition ratio (default: 0.3)
|
| 299 |
+
- penalty_weight: Weight of penalty applied (default: 1.0)
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
List of reward scores (0.0 to 1.0, lower for more repetitive text)
|
| 303 |
+
"""
|
| 304 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 305 |
+
window_size = kwargs.get('window_size', 3)
|
| 306 |
+
repetition_threshold = kwargs.get('repetition_threshold', 0.3)
|
| 307 |
+
penalty_weight = kwargs.get('penalty_weight', 1.0)
|
| 308 |
+
rewards = []
|
| 309 |
+
|
| 310 |
+
for content in completion_contents:
|
| 311 |
+
if not content.strip():
|
| 312 |
+
rewards.append(0.0)
|
| 313 |
+
continue
|
| 314 |
+
|
| 315 |
+
# Tokenize content
|
| 316 |
+
tokens = content.lower().split()
|
| 317 |
+
if len(tokens) < window_size:
|
| 318 |
+
rewards.append(1.0) # Too short to have meaningful repetition
|
| 319 |
+
continue
|
| 320 |
+
|
| 321 |
+
# Calculate n-gram repetition
|
| 322 |
+
ngrams = []
|
| 323 |
+
for i in range(len(tokens) - window_size + 1):
|
| 324 |
+
ngram = ' '.join(tokens[i:i + window_size])
|
| 325 |
+
ngrams.append(ngram)
|
| 326 |
+
|
| 327 |
+
if len(ngrams) == 0:
|
| 328 |
+
rewards.append(1.0)
|
| 329 |
+
continue
|
| 330 |
+
|
| 331 |
+
# Count unique vs total n-grams
|
| 332 |
+
unique_ngrams = len(set(ngrams))
|
| 333 |
+
total_ngrams = len(ngrams)
|
| 334 |
+
repetition_ratio = 1.0 - (unique_ngrams / total_ngrams)
|
| 335 |
+
|
| 336 |
+
# Calculate word-level repetition
|
| 337 |
+
unique_words = len(set(tokens))
|
| 338 |
+
total_words = len(tokens)
|
| 339 |
+
word_repetition_ratio = 1.0 - (unique_words / total_words)
|
| 340 |
+
|
| 341 |
+
# Combine repetition metrics
|
| 342 |
+
combined_repetition = (repetition_ratio + word_repetition_ratio) / 2
|
| 343 |
+
|
| 344 |
+
# Apply penalty
|
| 345 |
+
if combined_repetition > repetition_threshold:
|
| 346 |
+
penalty = (combined_repetition - repetition_threshold) * penalty_weight
|
| 347 |
+
reward = max(0.0, 1.0 - penalty)
|
| 348 |
+
else:
|
| 349 |
+
reward = 1.0
|
| 350 |
+
|
| 351 |
+
rewards.append(reward)
|
| 352 |
+
|
| 353 |
+
# Log repetition analysis if in debug mode
|
| 354 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 355 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 356 |
+
log_path = os.getenv("LOG_PATH", "debug.log")
|
| 357 |
+
with open(
|
| 358 |
+
log_path.replace(".txt", "_repetition.txt"), "a", encoding="utf-8"
|
| 359 |
+
) as f:
|
| 360 |
+
f.write(f"------------- {current_time} Repetition Reward -------------\n")
|
| 361 |
+
for content, reward in zip(completion_contents, rewards):
|
| 362 |
+
f.write(f"Content: {content[:200]}...\n")
|
| 363 |
+
f.write(f"Repetition reward: {reward:.3f}\n")
|
| 364 |
+
|
| 365 |
+
return rewards
|
| 366 |
+
|
| 367 |
+
@staticmethod
|
| 368 |
+
def format_accuracy_reward(
|
| 369 |
+
completions: List[Dict[str, Any]],
|
| 370 |
+
**kwargs
|
| 371 |
+
) -> List[float]:
|
| 372 |
+
"""
|
| 373 |
+
Format accuracy reward function for protein analysis outputs.
|
| 374 |
+
|
| 375 |
+
Checks if the model output follows the expected format and structure.
|
| 376 |
+
|
| 377 |
+
Args:
|
| 378 |
+
completions: List of model completions
|
| 379 |
+
**kwargs: Additional arguments including:
|
| 380 |
+
- required_format: Format pattern to match (default: structured analysis)
|
| 381 |
+
- strict_mode: Whether to require exact format compliance (default: False)
|
| 382 |
+
- format_elements: List of required elements to check for
|
| 383 |
+
|
| 384 |
+
Returns:
|
| 385 |
+
List of reward scores (0.0 to 1.0 based on format compliance)
|
| 386 |
+
"""
|
| 387 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 388 |
+
required_format = kwargs.get('required_format', 'structured_analysis')
|
| 389 |
+
strict_mode = kwargs.get('strict_mode', False)
|
| 390 |
+
format_elements = kwargs.get('format_elements', ['analysis', 'answer'])
|
| 391 |
+
rewards = []
|
| 392 |
+
|
| 393 |
+
for content in completion_contents:
|
| 394 |
+
if not content.strip():
|
| 395 |
+
rewards.append(0.0)
|
| 396 |
+
continue
|
| 397 |
+
|
| 398 |
+
score = 0.0
|
| 399 |
+
total_checks = 0
|
| 400 |
+
|
| 401 |
+
if required_format == 'structured_analysis':
|
| 402 |
+
# Check for structured analysis format
|
| 403 |
+
patterns = {
|
| 404 |
+
'analysis_section': r'<analysis>.*?</analysis>',
|
| 405 |
+
'answer_section': r'<answer>.*?</answer>',
|
| 406 |
+
'proper_nesting': r'<analysis>.*?</analysis>\s*<answer>.*?</answer>'
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
for pattern_name, pattern in patterns.items():
|
| 410 |
+
total_checks += 1
|
| 411 |
+
if re.search(pattern, content, re.DOTALL | re.IGNORECASE):
|
| 412 |
+
score += 1
|
| 413 |
+
|
| 414 |
+
elif required_format == 'qa_format':
|
| 415 |
+
# Check for Q&A format
|
| 416 |
+
patterns = {
|
| 417 |
+
'question_indicator': r'(?:question|q):\s*',
|
| 418 |
+
'answer_indicator': r'(?:answer|a):\s*',
|
| 419 |
+
'proper_structure': r'(?:question|q):\s*.*?(?:answer|a):\s*'
|
| 420 |
+
}
|
| 421 |
+
|
| 422 |
+
for pattern_name, pattern in patterns.items():
|
| 423 |
+
total_checks += 1
|
| 424 |
+
if re.search(pattern, content, re.DOTALL | re.IGNORECASE):
|
| 425 |
+
score += 1
|
| 426 |
+
|
| 427 |
+
elif required_format == 'json_format':
|
| 428 |
+
# Check for JSON-like format
|
| 429 |
+
try:
|
| 430 |
+
import json
|
| 431 |
+
# Try to find JSON-like structures
|
| 432 |
+
json_pattern = r'\{[^{}]*\}'
|
| 433 |
+
json_matches = re.findall(json_pattern, content)
|
| 434 |
+
total_checks = 1
|
| 435 |
+
if json_matches:
|
| 436 |
+
score = 1
|
| 437 |
+
except:
|
| 438 |
+
total_checks = 1
|
| 439 |
+
score = 0
|
| 440 |
+
|
| 441 |
+
elif required_format == 'custom_elements':
|
| 442 |
+
# Check for custom format elements
|
| 443 |
+
total_checks = len(format_elements)
|
| 444 |
+
for element in format_elements:
|
| 445 |
+
if element.lower() in content.lower():
|
| 446 |
+
score += 1
|
| 447 |
+
|
| 448 |
+
# Additional general format checks
|
| 449 |
+
general_checks = {
|
| 450 |
+
'proper_capitalization': bool(re.search(r'[A-Z]', content)),
|
| 451 |
+
'proper_punctuation': bool(re.search(r'[.!?]', content)),
|
| 452 |
+
'reasonable_length': len(content.split()) >= 5,
|
| 453 |
+
'no_excessive_whitespace': not bool(re.search(r'\s{5,}', content))
|
| 454 |
+
}
|
| 455 |
+
|
| 456 |
+
for check_name, check_result in general_checks.items():
|
| 457 |
+
total_checks += 1
|
| 458 |
+
if check_result:
|
| 459 |
+
score += 1
|
| 460 |
+
|
| 461 |
+
# Calculate final reward
|
| 462 |
+
if total_checks > 0:
|
| 463 |
+
reward = score / total_checks
|
| 464 |
+
if strict_mode and reward < 1.0:
|
| 465 |
+
reward = 0.0 # All or nothing in strict mode
|
| 466 |
+
else:
|
| 467 |
+
reward = 0.0
|
| 468 |
+
|
| 469 |
+
rewards.append(reward)
|
| 470 |
+
|
| 471 |
+
# Log format analysis if in debug mode
|
| 472 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 473 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 474 |
+
log_path = os.getenv("LOG_PATH", "debug.log")
|
| 475 |
+
with open(
|
| 476 |
+
log_path.replace(".txt", "_format_accuracy.txt"), "a", encoding="utf-8"
|
| 477 |
+
) as f:
|
| 478 |
+
f.write(f"------------- {current_time} Format Accuracy Reward -------------\n")
|
| 479 |
+
for content, reward in zip(completion_contents, rewards):
|
| 480 |
+
f.write(f"Content: {content[:200]}...\n")
|
| 481 |
+
f.write(f"Format accuracy: {reward:.3f}\n")
|
| 482 |
+
f.write(f"Required format: {required_format}\n")
|
| 483 |
+
|
| 484 |
+
return rewards
|
| 485 |
+
|
| 486 |
+
def get_reward_functions(self) -> Dict[str, callable]:
|
| 487 |
+
"""
|
| 488 |
+
Get available reward functions for protein tasks.
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
Dictionary mapping function names to callables
|
| 492 |
+
"""
|
| 493 |
+
return {
|
| 494 |
+
"accuracy": self.accuracy_reward,
|
| 495 |
+
"repetition": self.repetition_reward,
|
| 496 |
+
"format_accuracy": self.format_accuracy_reward,
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
@staticmethod
|
| 500 |
+
def select_reward_func(func: str, task_type: str = None, **kwargs) -> Callable:
|
| 501 |
+
"""
|
| 502 |
+
Select the appropriate reward function based on function name.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
func: The type of reward function ('accuracy', 'repetition', 'format_accuracy')
|
| 506 |
+
task_type: The type of task (optional, for compatibility)
|
| 507 |
+
**kwargs: Additional arguments to pass to the reward function
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
The reward function to use
|
| 511 |
+
|
| 512 |
+
Raises:
|
| 513 |
+
ValueError: If the function is not supported
|
| 514 |
+
"""
|
| 515 |
+
module = ESMQFormerModule()
|
| 516 |
+
reward_funcs = module.get_reward_functions()
|
| 517 |
+
|
| 518 |
+
if func in reward_funcs:
|
| 519 |
+
reward_func = reward_funcs[func]
|
| 520 |
+
|
| 521 |
+
# Create a wrapper that includes the kwargs
|
| 522 |
+
def wrapped_reward_func(completions, ground_truth=None, **additional_kwargs):
|
| 523 |
+
# Merge kwargs with additional_kwargs
|
| 524 |
+
merged_kwargs = {**kwargs, **additional_kwargs}
|
| 525 |
+
|
| 526 |
+
# Handle functions that don't need ground_truth
|
| 527 |
+
if func == 'repetition':
|
| 528 |
+
return reward_func(completions, **merged_kwargs)
|
| 529 |
+
elif func == 'format_accuracy':
|
| 530 |
+
return reward_func(completions, **merged_kwargs)
|
| 531 |
+
else: # accuracy
|
| 532 |
+
if ground_truth is None:
|
| 533 |
+
raise ValueError(f"Ground truth required for {func} reward function")
|
| 534 |
+
return reward_func(completions, ground_truth, **merged_kwargs)
|
| 535 |
+
|
| 536 |
+
return wrapped_reward_func
|
| 537 |
+
else:
|
| 538 |
+
raise ValueError(f"Unsupported reward function: {func}. Available functions: {list(reward_funcs.keys())}")
|
| 539 |
+
|
| 540 |
+
@staticmethod
|
| 541 |
+
def combine_rewards(
|
| 542 |
+
completions: List[Dict[str, Any]],
|
| 543 |
+
ground_truth: List[str] = None,
|
| 544 |
+
reward_weights: Dict[str, float] = None,
|
| 545 |
+
**kwargs
|
| 546 |
+
) -> List[float]:
|
| 547 |
+
"""
|
| 548 |
+
Combine multiple reward functions with specified weights.
|
| 549 |
+
|
| 550 |
+
Args:
|
| 551 |
+
completions: List of model completions
|
| 552 |
+
ground_truth: List of ground truth answers (optional)
|
| 553 |
+
reward_weights: Dictionary of reward function names and their weights
|
| 554 |
+
Default: {'accuracy': 0.5, 'repetition': 0.3, 'format_accuracy': 0.2}
|
| 555 |
+
**kwargs: Additional arguments for individual reward functions
|
| 556 |
+
|
| 557 |
+
Returns:
|
| 558 |
+
List of combined reward scores
|
| 559 |
+
"""
|
| 560 |
+
if reward_weights is None:
|
| 561 |
+
reward_weights = {
|
| 562 |
+
'accuracy': 0.5,
|
| 563 |
+
'repetition': 0.3,
|
| 564 |
+
'format_accuracy': 0.2
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
# Normalize weights
|
| 568 |
+
total_weight = sum(reward_weights.values())
|
| 569 |
+
if total_weight > 0:
|
| 570 |
+
reward_weights = {k: v / total_weight for k, v in reward_weights.items()}
|
| 571 |
+
|
| 572 |
+
combined_rewards = [0.0] * len(completions)
|
| 573 |
+
module = ESMQFormerModule()
|
| 574 |
+
|
| 575 |
+
for func_name, weight in reward_weights.items():
|
| 576 |
+
if weight <= 0:
|
| 577 |
+
continue
|
| 578 |
+
|
| 579 |
+
try:
|
| 580 |
+
if func_name == 'accuracy' and ground_truth is not None:
|
| 581 |
+
rewards = module.accuracy_reward(completions, ground_truth, **kwargs)
|
| 582 |
+
elif func_name == 'repetition':
|
| 583 |
+
rewards = module.repetition_reward(completions, **kwargs)
|
| 584 |
+
elif func_name == 'format_accuracy':
|
| 585 |
+
rewards = module.format_accuracy_reward(completions, **kwargs)
|
| 586 |
+
else:
|
| 587 |
+
continue # Skip unknown or incompatible functions
|
| 588 |
+
|
| 589 |
+
# Add weighted rewards
|
| 590 |
+
for i, reward in enumerate(rewards):
|
| 591 |
+
combined_rewards[i] += weight * reward
|
| 592 |
+
|
| 593 |
+
except Exception as e:
|
| 594 |
+
print(f"Warning: Error calculating {func_name} reward: {e}")
|
| 595 |
+
continue
|
| 596 |
+
|
| 597 |
+
return combined_rewards
|
| 598 |
+
|
| 599 |
+
def validate_model_config(self, config: Dict[str, Any]) -> bool:
|
| 600 |
+
"""
|
| 601 |
+
Validate model configuration parameters for ESM+Q-Former architecture.
|
| 602 |
+
|
| 603 |
+
Args:
|
| 604 |
+
config: Configuration dictionary
|
| 605 |
+
|
| 606 |
+
Returns:
|
| 607 |
+
True if valid, False otherwise
|
| 608 |
+
"""
|
| 609 |
+
required_keys = [
|
| 610 |
+
"text_model_name",
|
| 611 |
+
"protein_model_name",
|
| 612 |
+
]
|
| 613 |
+
|
| 614 |
+
for key in required_keys:
|
| 615 |
+
if key not in config:
|
| 616 |
+
print(f"Missing required config key: {key}")
|
| 617 |
+
return False
|
| 618 |
+
|
| 619 |
+
# Validate Q-Former configuration
|
| 620 |
+
if "qformer_num_query_tokens" in config:
|
| 621 |
+
if not isinstance(config["qformer_num_query_tokens"], int) or config["qformer_num_query_tokens"] <= 0:
|
| 622 |
+
print("qformer_num_query_tokens must be a positive integer")
|
| 623 |
+
return False
|
| 624 |
+
|
| 625 |
+
# Validate max lengths
|
| 626 |
+
for length_key in ["max_length_protein", "max_length_text"]:
|
| 627 |
+
if length_key in config:
|
| 628 |
+
if not isinstance(config[length_key], int) or config[length_key] <= 0:
|
| 629 |
+
print(f"{length_key} must be a positive integer")
|
| 630 |
+
return False
|
| 631 |
+
|
| 632 |
+
return True
|
| 633 |
+
|
| 634 |
+
def get_default_generation_config(self) -> Dict[str, Any]:
|
| 635 |
+
"""
|
| 636 |
+
Get default generation configuration for protein-LLM models.
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
Dictionary of default generation parameters
|
| 640 |
+
"""
|
| 641 |
+
return {
|
| 642 |
+
"max_new_tokens": 512,
|
| 643 |
+
"temperature": 0.7,
|
| 644 |
+
"do_sample": True,
|
| 645 |
+
"top_p": 0.9,
|
| 646 |
+
"repetition_penalty": 1.1,
|
| 647 |
+
"pad_token_id": None, # Will be set by processor
|
| 648 |
+
"eos_token_id": None, # Will be set by processor
|
| 649 |
+
}
|
BioReason/bioreason/dna_modules/nucleotide_module.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import (
|
| 2 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 3 |
+
Qwen2VLForConditionalGeneration,
|
| 4 |
+
AutoProcessor,
|
| 5 |
+
)
|
| 6 |
+
from typing import Dict, Any, Union, List, Optional, Callable, Type
|
| 7 |
+
from trl.data_utils import maybe_apply_chat_template
|
| 8 |
+
from trl import SFTTrainer
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from bioreason.dna_modules.dna_module import DNABaseModule
|
| 12 |
+
from bioreason.models.dna_llm import DNALLMModel
|
| 13 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NucleotideDNAModule(DNABaseModule):
|
| 17 |
+
"""
|
| 18 |
+
DNA module implementation for NucleotideTransformer-based models.
|
| 19 |
+
|
| 20 |
+
This module provides the interface between DNA-LLM models and the training
|
| 21 |
+
infrastructure, handling model loading, processing setup, and reward functions.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self):
|
| 25 |
+
"""Initialize the NucleotideDNAModule."""
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
def get_dnallm_key(self) -> str:
|
| 29 |
+
"""
|
| 30 |
+
Get the key identifier for this DNA-LLM implementation.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
String identifier for this module type
|
| 34 |
+
"""
|
| 35 |
+
return "qwen"
|
| 36 |
+
|
| 37 |
+
def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
|
| 38 |
+
"""
|
| 39 |
+
Return the appropriate model class based on model ID.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
model_id: Identifier for the model
|
| 43 |
+
model_init_kwargs: Initialization arguments for the model
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
The model class to instantiate
|
| 47 |
+
|
| 48 |
+
Raises:
|
| 49 |
+
ValueError: If the model is not supported
|
| 50 |
+
"""
|
| 51 |
+
if "DNALLM" in model_id:
|
| 52 |
+
model_cls = DNALLMModel
|
| 53 |
+
else:
|
| 54 |
+
raise ValueError(f"Unsupported model: {model_id}")
|
| 55 |
+
return model_cls
|
| 56 |
+
|
| 57 |
+
def post_model_init(self, model: Any, processing_class: Any) -> None:
|
| 58 |
+
"""
|
| 59 |
+
Perform any post-initialization setup on the model.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
model: The initialized model
|
| 63 |
+
processing_class: The processor for the model
|
| 64 |
+
"""
|
| 65 |
+
# No post-init needed for this implementation
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
def get_processing_class(self) -> Type:
|
| 69 |
+
"""
|
| 70 |
+
Get the processing class to use with this DNA-LLM model.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
The processing class
|
| 74 |
+
"""
|
| 75 |
+
return DLProcessor
|
| 76 |
+
|
| 77 |
+
def get_dnallm_modules_keywords(self) -> List[str]:
|
| 78 |
+
"""
|
| 79 |
+
Get keywords to identify DNA-specific modules in the model.
|
| 80 |
+
|
| 81 |
+
Used to exclude DNA modules from LoRA adaptation during training.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
List of keywords that identify DNA modules
|
| 85 |
+
"""
|
| 86 |
+
return ["dna"]
|
| 87 |
+
|
| 88 |
+
def get_custom_multimodal_keywords(self) -> List[str]:
|
| 89 |
+
"""
|
| 90 |
+
Get keywords for multimodal inputs that should be passed to the model.
|
| 91 |
+
|
| 92 |
+
Returns:
|
| 93 |
+
List of input keywords for multimodal processing
|
| 94 |
+
"""
|
| 95 |
+
return ["dna_tokenized", "batch_idx_map"]
|
| 96 |
+
|
| 97 |
+
def get_non_generate_params(self) -> List[str]:
|
| 98 |
+
"""
|
| 99 |
+
Get parameter names that should be excluded from generation.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of parameter names to exclude from generation calls
|
| 103 |
+
"""
|
| 104 |
+
return []
|
| 105 |
+
|
| 106 |
+
def get_custom_processing_keywords(self) -> List[tuple]:
|
| 107 |
+
"""
|
| 108 |
+
Get custom processing keywords for the processor.
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
List of (component, parameter) tuples for custom processing
|
| 112 |
+
"""
|
| 113 |
+
return [("dna_tokenizer", "max_length")]
|
| 114 |
+
|
| 115 |
+
def prepare_prompt(
|
| 116 |
+
self, processing_class: Any, inputs: List[Dict[str, Union[torch.Tensor, Any]]]
|
| 117 |
+
) -> List[str]:
|
| 118 |
+
"""
|
| 119 |
+
Prepare prompts from input examples.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
processing_class: The processor to use
|
| 123 |
+
inputs: List of input examples
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
List of prepared prompts
|
| 127 |
+
"""
|
| 128 |
+
prompts_text = [
|
| 129 |
+
maybe_apply_chat_template(example, processing_class)["prompt"]
|
| 130 |
+
for example in inputs
|
| 131 |
+
]
|
| 132 |
+
return prompts_text
|
| 133 |
+
|
| 134 |
+
def prepare_model_inputs(
|
| 135 |
+
self,
|
| 136 |
+
processing_class: Any,
|
| 137 |
+
model: Any,
|
| 138 |
+
prompts_text: List[str],
|
| 139 |
+
batch_dna_sequences: List[List[str]],
|
| 140 |
+
return_tensors: str = "pt",
|
| 141 |
+
padding: bool = True,
|
| 142 |
+
padding_side: str = "left",
|
| 143 |
+
add_special_tokens: bool = False,
|
| 144 |
+
) -> Dict[str, Any]:
|
| 145 |
+
"""
|
| 146 |
+
Prepare inputs for the model.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
processing_class: The processor to use
|
| 150 |
+
model: The model to prepare inputs for
|
| 151 |
+
prompts_text: List of text prompts
|
| 152 |
+
batch_dna_sequences: List of lists of DNA sequences
|
| 153 |
+
return_tensors: Return format for tensors
|
| 154 |
+
padding: Whether to pad inputs
|
| 155 |
+
padding_side: Side to pad on
|
| 156 |
+
add_special_tokens: Whether to add special tokens
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Processed inputs for the model
|
| 160 |
+
"""
|
| 161 |
+
# Handle DataParallel wrapped models by accessing the module attribute if needed
|
| 162 |
+
max_length_text = model.max_length_text if not hasattr(model, 'module') else model.module.max_length_text
|
| 163 |
+
max_length_dna = model.max_length_dna if not hasattr(model, 'module') else model.module.max_length_dna
|
| 164 |
+
|
| 165 |
+
prompt_inputs = processing_class(
|
| 166 |
+
text=prompts_text,
|
| 167 |
+
batch_dna_sequences=batch_dna_sequences,
|
| 168 |
+
return_tensors=return_tensors,
|
| 169 |
+
padding=padding,
|
| 170 |
+
padding_side=padding_side,
|
| 171 |
+
add_special_tokens=add_special_tokens,
|
| 172 |
+
max_length_text=max_length_text,
|
| 173 |
+
max_length_dna=max_length_dna,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
return prompt_inputs
|
| 177 |
+
|
| 178 |
+
def is_embeds_input(self) -> bool:
|
| 179 |
+
"""
|
| 180 |
+
Whether the model uses embeddings as input (instead of token IDs).
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
Boolean indicating if the model takes embedding inputs
|
| 184 |
+
"""
|
| 185 |
+
return True
|
| 186 |
+
|
| 187 |
+
@staticmethod
|
| 188 |
+
def get_question_template() -> str:
|
| 189 |
+
"""
|
| 190 |
+
Get the template for formatting questions.
|
| 191 |
+
|
| 192 |
+
Returns:
|
| 193 |
+
String template for questions
|
| 194 |
+
"""
|
| 195 |
+
return "{Question}"
|
| 196 |
+
|
| 197 |
+
@staticmethod
|
| 198 |
+
def format_reward_rec(completions: List[Dict[str, Any]], **kwargs) -> List[float]:
|
| 199 |
+
"""
|
| 200 |
+
Check if the Qwen model output matches a specific format.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
completions: List of model completions
|
| 204 |
+
**kwargs: Additional arguments
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
List of reward scores (1.0 for match, 0.0 for no match)
|
| 208 |
+
"""
|
| 209 |
+
import re
|
| 210 |
+
import os
|
| 211 |
+
from datetime import datetime
|
| 212 |
+
|
| 213 |
+
# Pattern to match the expected output format
|
| 214 |
+
pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
|
| 215 |
+
completion_contents = [completion[0]["content"] for completion in completions]
|
| 216 |
+
matches = [
|
| 217 |
+
re.search(pattern, content, re.DOTALL) is not None
|
| 218 |
+
for content in completion_contents
|
| 219 |
+
]
|
| 220 |
+
|
| 221 |
+
# Log format results if in debug mode
|
| 222 |
+
current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
|
| 223 |
+
if os.getenv("DEBUG_MODE") == "true":
|
| 224 |
+
log_path = os.getenv("LOG_PATH")
|
| 225 |
+
with open(
|
| 226 |
+
log_path.replace(".txt", "_format.txt"), "a", encoding="utf-8"
|
| 227 |
+
) as f:
|
| 228 |
+
f.write(f"------------- {current_time} Format reward -------------\n")
|
| 229 |
+
for content, match in zip(completion_contents, matches):
|
| 230 |
+
f.write(f"Content: {content}\n")
|
| 231 |
+
f.write(f"Has format: {bool(match)}\n")
|
| 232 |
+
|
| 233 |
+
return [1.0 if match else 0.0 for match in matches]
|
| 234 |
+
|
| 235 |
+
@staticmethod
|
| 236 |
+
def select_reward_func(func: str, task_type: str) -> Callable:
|
| 237 |
+
"""
|
| 238 |
+
Select the appropriate reward function based on function name and task type.
|
| 239 |
+
|
| 240 |
+
Args:
|
| 241 |
+
func: The type of reward function ('accuracy', 'format', etc.)
|
| 242 |
+
task_type: The type of task ('rec', etc.)
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
The reward function to use
|
| 246 |
+
|
| 247 |
+
Raises:
|
| 248 |
+
ValueError: If the function or task type is not supported
|
| 249 |
+
"""
|
| 250 |
+
if func == "accuracy":
|
| 251 |
+
match task_type:
|
| 252 |
+
case "rec":
|
| 253 |
+
return NucleotideDNAModule.iou_reward
|
| 254 |
+
case _:
|
| 255 |
+
raise ValueError(f"Unsupported reward function: {func}")
|
| 256 |
+
elif func == "format":
|
| 257 |
+
match task_type:
|
| 258 |
+
case "rec":
|
| 259 |
+
return NucleotideDNAModule.format_reward_rec
|
| 260 |
+
case _:
|
| 261 |
+
raise ValueError(f"Unsupported reward function: {func}")
|
| 262 |
+
else:
|
| 263 |
+
raise ValueError(f"Unsupported reward function: {func}")
|
BioReason/bioreason/dna_modules/protein_module.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Dict, Any, Union, List, Optional, Type
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
class ProteinBaseModule(ABC):
|
| 6 |
+
"""
|
| 7 |
+
Abstract base class for protein-language model modules.
|
| 8 |
+
|
| 9 |
+
This class defines the interface that all protein-LLM implementations
|
| 10 |
+
must follow, providing standardized methods for model loading,
|
| 11 |
+
processing, and training integration.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self):
|
| 15 |
+
"""Initialize the protein module."""
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def get_protein_llm_key(self) -> str:
|
| 20 |
+
"""
|
| 21 |
+
Get the unique identifier for this protein-LLM implementation.
|
| 22 |
+
|
| 23 |
+
Returns:
|
| 24 |
+
String identifier for this module type
|
| 25 |
+
"""
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
@abstractmethod
|
| 29 |
+
def get_model_class(self, model_id: str, model_init_kwargs: Dict[str, Any]) -> Type:
|
| 30 |
+
"""
|
| 31 |
+
Return the appropriate model class based on model ID.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
model_id: Identifier for the model
|
| 35 |
+
model_init_kwargs: Initialization arguments for the model
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
The model class to instantiate
|
| 39 |
+
"""
|
| 40 |
+
pass
|
| 41 |
+
|
| 42 |
+
def post_model_init(self, model: Any, processing_class: Any) -> None:
|
| 43 |
+
"""
|
| 44 |
+
Perform any post-initialization setup on the model.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
model: The initialized model
|
| 48 |
+
processing_class: The processor for the model
|
| 49 |
+
"""
|
| 50 |
+
# Default implementation does nothing
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
def is_embeds_input(self) -> bool:
|
| 54 |
+
"""
|
| 55 |
+
Whether the model uses embeddings as input (instead of token IDs).
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
Boolean indicating if the model takes embedding inputs
|
| 59 |
+
"""
|
| 60 |
+
# Default for protein-LLM models is True due to Q-Former integration
|
| 61 |
+
return True
|
| 62 |
+
|
| 63 |
+
@abstractmethod
|
| 64 |
+
def get_processing_class(self) -> Type:
|
| 65 |
+
"""
|
| 66 |
+
Get the processing class to use with this protein-LLM model.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
The processing class
|
| 70 |
+
"""
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
@abstractmethod
|
| 74 |
+
def get_protein_llm_modules_keywords(self) -> List[str]:
|
| 75 |
+
"""
|
| 76 |
+
Get keywords to identify protein-specific modules in the model.
|
| 77 |
+
|
| 78 |
+
Used to exclude protein modules from LoRA adaptation during training
|
| 79 |
+
or to identify components for specific training strategies.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
List of keywords that identify protein modules
|
| 83 |
+
"""
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
@abstractmethod
|
| 87 |
+
def get_custom_multimodal_keywords(self) -> List[str]:
|
| 88 |
+
"""
|
| 89 |
+
Get keywords for multimodal inputs that should be passed to the model.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
List of input keywords for multimodal processing
|
| 93 |
+
"""
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
@abstractmethod
|
| 97 |
+
def get_non_generate_params(self) -> List[str]:
|
| 98 |
+
"""
|
| 99 |
+
Get parameter names that should be excluded from generation.
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
List of parameter names to exclude from generation calls
|
| 103 |
+
"""
|
| 104 |
+
pass
|
| 105 |
+
|
| 106 |
+
@abstractmethod
|
| 107 |
+
def get_custom_processing_keywords(self) -> List[tuple]:
|
| 108 |
+
"""
|
| 109 |
+
Get custom processing keywords for the processor.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
List of (component, parameter) tuples for custom processing
|
| 113 |
+
"""
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
@abstractmethod
|
| 117 |
+
def prepare_prompt(
|
| 118 |
+
self,
|
| 119 |
+
processing_class: Any,
|
| 120 |
+
inputs: List[Dict[str, Union[torch.Tensor, Any]]]
|
| 121 |
+
) -> List[str]:
|
| 122 |
+
"""
|
| 123 |
+
Prepare prompts from input examples.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
processing_class: The processor to use
|
| 127 |
+
inputs: List of input examples
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
List of prepared prompts
|
| 131 |
+
"""
|
| 132 |
+
pass
|
| 133 |
+
|
| 134 |
+
@abstractmethod
|
| 135 |
+
def prepare_model_inputs(
|
| 136 |
+
self,
|
| 137 |
+
processing_class: Any,
|
| 138 |
+
model: Any,
|
| 139 |
+
prompts_text: List[str],
|
| 140 |
+
batch_protein_sequences: List[List[str]],
|
| 141 |
+
return_tensors: str = "pt",
|
| 142 |
+
padding: bool = True,
|
| 143 |
+
padding_side: str = "left",
|
| 144 |
+
add_special_tokens: bool = False,
|
| 145 |
+
**kwargs
|
| 146 |
+
) -> Dict[str, Any]:
|
| 147 |
+
"""
|
| 148 |
+
Prepare inputs for the model.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
processing_class: The processor to use
|
| 152 |
+
model: The model to prepare inputs for
|
| 153 |
+
prompts_text: List of text prompts
|
| 154 |
+
batch_protein_sequences: List of lists of protein sequences
|
| 155 |
+
return_tensors: Return format for tensors
|
| 156 |
+
padding: Whether to pad inputs
|
| 157 |
+
padding_side: Side to pad on
|
| 158 |
+
add_special_tokens: Whether to add special tokens
|
| 159 |
+
**kwargs: Additional arguments
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
Processed inputs for the model
|
| 163 |
+
"""
|
| 164 |
+
pass
|
| 165 |
+
|
| 166 |
+
def get_reward_functions(self) -> Dict[str, callable]:
|
| 167 |
+
"""
|
| 168 |
+
Get available reward functions for this module.
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
Dictionary mapping function names to callables
|
| 172 |
+
"""
|
| 173 |
+
return {}
|
| 174 |
+
|
| 175 |
+
def validate_model_config(self, config: Dict[str, Any]) -> bool:
|
| 176 |
+
"""
|
| 177 |
+
Validate model configuration parameters.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
config: Configuration dictionary
|
| 181 |
+
|
| 182 |
+
Returns:
|
| 183 |
+
True if valid, False otherwise
|
| 184 |
+
"""
|
| 185 |
+
return True
|
| 186 |
+
|
| 187 |
+
def get_default_generation_config(self) -> Dict[str, Any]:
|
| 188 |
+
"""
|
| 189 |
+
Get default generation configuration for this model type.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
Dictionary of default generation parameters
|
| 193 |
+
"""
|
| 194 |
+
return {
|
| 195 |
+
"max_new_tokens": 512,
|
| 196 |
+
"temperature": 0.7,
|
| 197 |
+
"do_sample": True,
|
| 198 |
+
"top_p": 0.9,
|
| 199 |
+
"pad_token_id": None, # Will be set by processor
|
| 200 |
+
}
|
BioReason/bioreason/models/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dna_only import DNAClassifierModel
|
| 2 |
+
from .dna_llm import DNALLMModel
|
| 3 |
+
from .evo2_tokenizer import Evo2Tokenizer
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"DNAClassifierModel",
|
| 7 |
+
"DNALLMModel",
|
| 8 |
+
"Evo2Tokenizer",
|
| 9 |
+
]
|
BioReason/bioreason/models/dl/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
BioReason/bioreason/models/dl/chat_template_dl.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
CHAT_TEMPLATE = "{%- set dna_count = namespace(value=0) %}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content is string and message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' }} {%- if message.content is string %}{{- message.content + '<|im_end|>' + '\\n' }}{%- else %}{%- for content in message.content %}{%- if content.type == 'dna' or 'dna' in content %}{%- set dna_count.value = dna_count.value + 1 %}{%- if add_dna_id %}DNA Sequence {{- dna_count.value }}: {%- endif %}<|dna_start|><|dna_pad|><|dna_end|>{%- elif 'text' in content %}{{- content.text }}{%- endif %}{%- endfor %}{{- '<|im_end|>' + '\\n' }}{%- endif %}{%- elif message.role == \"assistant\" %}\n {%- set content = message.content[0].text %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content[0].text.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content[0].text.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
BioReason/bioreason/models/dl/configuration_dl.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class DLDNAConfig(PretrainedConfig):
|
| 4 |
+
model_type = "dl"
|
| 5 |
+
base_config_key = "dna_config"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
depth=32,
|
| 10 |
+
hidden_size=3584,
|
| 11 |
+
hidden_act="silu",
|
| 12 |
+
intermediate_size=3420,
|
| 13 |
+
num_heads=16,
|
| 14 |
+
in_channels=3,
|
| 15 |
+
patch_size=14,
|
| 16 |
+
spatial_merge_size=2,
|
| 17 |
+
temporal_patch_size=2,
|
| 18 |
+
tokens_per_second=4,
|
| 19 |
+
window_size=112,
|
| 20 |
+
out_hidden_size=3584,
|
| 21 |
+
fullatt_block_indexes=[7, 15, 23, 31],
|
| 22 |
+
**kwargs,
|
| 23 |
+
):
|
| 24 |
+
super().__init__(**kwargs)
|
| 25 |
+
|
| 26 |
+
self.depth = depth
|
| 27 |
+
self.hidden_size = hidden_size
|
| 28 |
+
self.hidden_act = hidden_act
|
| 29 |
+
self.intermediate_size = intermediate_size
|
| 30 |
+
self.num_heads = num_heads
|
| 31 |
+
self.in_channels = in_channels
|
| 32 |
+
self.patch_size = patch_size
|
| 33 |
+
self.spatial_merge_size = spatial_merge_size
|
| 34 |
+
self.temporal_patch_size = temporal_patch_size
|
| 35 |
+
self.tokens_per_second = tokens_per_second
|
| 36 |
+
self.window_size = window_size
|
| 37 |
+
self.fullatt_block_indexes = fullatt_block_indexes
|
| 38 |
+
self.out_hidden_size = out_hidden_size
|
| 39 |
+
|
| 40 |
+
class DLConfig(PretrainedConfig):
|
| 41 |
+
r"""
|
| 42 |
+
This is the configuration class to store the configuration of a [`Qwen2_5_VLModel`]. It is used to instantiate a
|
| 43 |
+
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 44 |
+
with the defaults will yield a similar configuration to that of
|
| 45 |
+
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
| 46 |
+
|
| 47 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 48 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
vocab_size (`int`, *optional*, defaults to 152064):
|
| 53 |
+
Vocabulary size of the Qwen2_5_VL model. Defines the number of different tokens that can be represented by the
|
| 54 |
+
`inputs_ids` passed when calling [`Qwen2_5_VLModel`]
|
| 55 |
+
hidden_size (`int`, *optional*, defaults to 8192):
|
| 56 |
+
Dimension of the hidden representations.
|
| 57 |
+
intermediate_size (`int`, *optional*, defaults to 29568):
|
| 58 |
+
Dimension of the MLP representations.
|
| 59 |
+
num_hidden_layers (`int`, *optional*, defaults to 80):
|
| 60 |
+
Number of hidden layers in the Transformer encoder.
|
| 61 |
+
num_attention_heads (`int`, *optional*, defaults to 64):
|
| 62 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 63 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 64 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 65 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 66 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 67 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 68 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 69 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
| 70 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 71 |
+
The non-linear activation function (function or string) in the decoder.
|
| 72 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 73 |
+
The maximum sequence length that this model might ever be used with.
|
| 74 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 75 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 76 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 77 |
+
The epsilon used by the rms normalization layers.
|
| 78 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 79 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 80 |
+
relevant if `config.is_decoder=True`.
|
| 81 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 82 |
+
Whether the model's input and output word embeddings should be tied.
|
| 83 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 84 |
+
The base period of the RoPE embeddings.
|
| 85 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 86 |
+
Whether to use sliding window attention.
|
| 87 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 88 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 89 |
+
max_window_layers (`int`, *optional*, defaults to 80):
|
| 90 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
| 91 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 92 |
+
The dropout ratio for the attention probabilities.
|
| 93 |
+
vision_config (`Dict`, *optional*):
|
| 94 |
+
The config for the visual encoder initialization.
|
| 95 |
+
rope_scaling (`Dict`, *optional*):
|
| 96 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 97 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 98 |
+
accordingly.
|
| 99 |
+
Expected contents:
|
| 100 |
+
`rope_type` (`str`):
|
| 101 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 102 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 103 |
+
`factor` (`float`, *optional*):
|
| 104 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 105 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 106 |
+
original maximum pre-trained length.
|
| 107 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 108 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 109 |
+
pretraining.
|
| 110 |
+
`attention_factor` (`float`, *optional*):
|
| 111 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 112 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 113 |
+
`factor` field to infer the suggested value.
|
| 114 |
+
`beta_fast` (`float`, *optional*):
|
| 115 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 116 |
+
ramp function. If unspecified, it defaults to 32.
|
| 117 |
+
`beta_slow` (`float`, *optional*):
|
| 118 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 119 |
+
ramp function. If unspecified, it defaults to 1.
|
| 120 |
+
`short_factor` (`List[float]`, *optional*):
|
| 121 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 122 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 123 |
+
size divided by the number of attention heads divided by 2
|
| 124 |
+
`long_factor` (`List[float]`, *optional*):
|
| 125 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 126 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 127 |
+
size divided by the number of attention heads divided by 2
|
| 128 |
+
`low_freq_factor` (`float`, *optional*):
|
| 129 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 130 |
+
`high_freq_factor` (`float`, *optional*):
|
| 131 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 132 |
+
|
| 133 |
+
```python
|
| 134 |
+
>>> from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLConfig
|
| 135 |
+
|
| 136 |
+
>>> # Initializing a Qwen2_5_VL style configuration
|
| 137 |
+
>>> configuration = Qwen2_5_VLConfig()
|
| 138 |
+
|
| 139 |
+
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
| 140 |
+
>>> model = Qwen2_5_VLForConditionalGeneration(configuration)
|
| 141 |
+
|
| 142 |
+
>>> # Accessing the model configuration
|
| 143 |
+
>>> configuration = model.config
|
| 144 |
+
```"""
|
| 145 |
+
|
| 146 |
+
model_type = "dl"
|
| 147 |
+
sub_configs = {"dna_config": DLDNAConfig}
|
| 148 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 149 |
+
# Default tensor parallel plan for base model `Qwen2_5_VL`
|
| 150 |
+
base_model_tp_plan = {
|
| 151 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 152 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 153 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 154 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 155 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 156 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 157 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 158 |
+
}
|
| 159 |
+
base_model_pp_plan = {
|
| 160 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 161 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 162 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
vocab_size=152064,
|
| 168 |
+
hidden_size=8192,
|
| 169 |
+
intermediate_size=29568,
|
| 170 |
+
num_hidden_layers=80,
|
| 171 |
+
num_attention_heads=64,
|
| 172 |
+
num_key_value_heads=8,
|
| 173 |
+
hidden_act="silu",
|
| 174 |
+
max_position_embeddings=32768,
|
| 175 |
+
initializer_range=0.02,
|
| 176 |
+
rms_norm_eps=1e-05,
|
| 177 |
+
use_cache=True,
|
| 178 |
+
tie_word_embeddings=False,
|
| 179 |
+
rope_theta=1000000.0,
|
| 180 |
+
use_sliding_window=False,
|
| 181 |
+
sliding_window=4096,
|
| 182 |
+
max_window_layers=80,
|
| 183 |
+
attention_dropout=0.0,
|
| 184 |
+
vision_config=None,
|
| 185 |
+
rope_scaling=None,
|
| 186 |
+
image_token_id=None,
|
| 187 |
+
**kwargs,
|
| 188 |
+
):
|
| 189 |
+
if isinstance(vision_config, dict):
|
| 190 |
+
self.vision_config = self.sub_configs["vision_config"](**vision_config)
|
| 191 |
+
elif vision_config is None:
|
| 192 |
+
self.vision_config = self.sub_configs["vision_config"]()
|
| 193 |
+
|
| 194 |
+
self.vocab_size = vocab_size
|
| 195 |
+
self.max_position_embeddings = max_position_embeddings
|
| 196 |
+
self.hidden_size = hidden_size
|
| 197 |
+
self.intermediate_size = intermediate_size
|
| 198 |
+
self.num_hidden_layers = num_hidden_layers
|
| 199 |
+
self.num_attention_heads = num_attention_heads
|
| 200 |
+
self.use_sliding_window = use_sliding_window
|
| 201 |
+
self.sliding_window = sliding_window
|
| 202 |
+
self.max_window_layers = max_window_layers
|
| 203 |
+
|
| 204 |
+
# for backward compatibility
|
| 205 |
+
if num_key_value_heads is None:
|
| 206 |
+
num_key_value_heads = num_attention_heads
|
| 207 |
+
|
| 208 |
+
self.num_key_value_heads = num_key_value_heads
|
| 209 |
+
self.hidden_act = hidden_act
|
| 210 |
+
self.initializer_range = initializer_range
|
| 211 |
+
self.rms_norm_eps = rms_norm_eps
|
| 212 |
+
self.use_cache = use_cache
|
| 213 |
+
self.rope_theta = rope_theta
|
| 214 |
+
self.attention_dropout = attention_dropout
|
| 215 |
+
self.rope_scaling = rope_scaling
|
| 216 |
+
|
| 217 |
+
self.dna_token_id = image_token_id
|
| 218 |
+
|
| 219 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 220 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 221 |
+
# and change type from 'mrope' to 'default' because `mrope` does default RoPE calculations
|
| 222 |
+
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
| 223 |
+
# TODO: @raushan update config in the hub
|
| 224 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 225 |
+
if self.rope_scaling["type"] == "mrope":
|
| 226 |
+
self.rope_scaling["type"] = "default"
|
| 227 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 228 |
+
rope_config_validation(self, ignore_keys={"mrope_section"})
|
| 229 |
+
|
| 230 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 231 |
+
|
| 232 |
+
__all__ = ["DLConfig"]
|
BioReason/bioreason/models/dl/processing_dl.py
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Dict, Any, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from transformers.processing_utils import (
|
| 9 |
+
CommonKwargs,
|
| 10 |
+
ProcessingKwargs,
|
| 11 |
+
ProcessorMixin,
|
| 12 |
+
Unpack,
|
| 13 |
+
)
|
| 14 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 15 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
+
from transformers.utils import logging
|
| 17 |
+
|
| 18 |
+
from bioreason.utils.dna_utils import DNAInput
|
| 19 |
+
|
| 20 |
+
class DLDNAKwargs(CommonKwargs):
|
| 21 |
+
"""Keyword arguments specific to DNA processing"""
|
| 22 |
+
max_length_text: Optional[int]
|
| 23 |
+
max_length_dna: Optional[int]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class DLProcessorKwargs(ProcessingKwargs, total=False):
|
| 27 |
+
"""Processing keyword arguments for the DL processor"""
|
| 28 |
+
dna_kwargs: DLDNAKwargs
|
| 29 |
+
_defaults = {
|
| 30 |
+
"text_kwargs": {
|
| 31 |
+
"padding": False,
|
| 32 |
+
},
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
class DLProcessor(ProcessorMixin):
|
| 36 |
+
r"""
|
| 37 |
+
Constructs a DL processor which wraps a NucleotideTransformer DNA processor and a Qwen2_5 tokenizer into a single processor.
|
| 38 |
+
This processor handles both text and DNA sequence processing to prepare inputs for the DNALLMModel.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
tokenizer (PreTrainedTokenizerBase, *optional*):
|
| 42 |
+
The text tokenizer used for processing text inputs.
|
| 43 |
+
dna_tokenizer (PreTrainedTokenizerBase, *optional*):
|
| 44 |
+
The DNA tokenizer used for processing DNA sequences.
|
| 45 |
+
chat_template (`str`, *optional*):
|
| 46 |
+
A Jinja template for chat formatting. If None, will use the tokenizer's template.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
attributes = ["tokenizer", "dna_tokenizer"]
|
| 50 |
+
valid_kwargs = ["model", "chat_template"]
|
| 51 |
+
tokenizer_class = (
|
| 52 |
+
"Qwen2Tokenizer", "Qwen2TokenizerFast",
|
| 53 |
+
"GPT2TokenizerFast",
|
| 54 |
+
)
|
| 55 |
+
dna_tokenizer_class = ("EsmTokenizer", "Evo2Tokenizer")
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self, tokenizer=None, dna_tokenizer=None, chat_template=None, **kwargs
|
| 59 |
+
):
|
| 60 |
+
"""
|
| 61 |
+
Initialize the processor with text and DNA tokenizers.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
tokenizer: Text tokenizer (usually from a language model)
|
| 65 |
+
dna_tokenizer: DNA tokenizer (usually from a DNA model)
|
| 66 |
+
chat_template: Template for formatting chat conversations
|
| 67 |
+
**kwargs: Additional arguments
|
| 68 |
+
"""
|
| 69 |
+
self.tokenizer = tokenizer
|
| 70 |
+
self.dna_tokenizer = dna_tokenizer
|
| 71 |
+
|
| 72 |
+
self.dna_token = (
|
| 73 |
+
"<|dna_pad|>"
|
| 74 |
+
if not hasattr(self.tokenizer, "dna_token")
|
| 75 |
+
else self.tokenizer.dna_token
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Get chat template from tokenizer if not provided
|
| 79 |
+
if chat_template is None and hasattr(self.tokenizer, "chat_template"):
|
| 80 |
+
chat_template = self.tokenizer.chat_template
|
| 81 |
+
super().__init__(tokenizer, dna_tokenizer, chat_template=chat_template)
|
| 82 |
+
|
| 83 |
+
# The GRPO trainer might expect this to be set
|
| 84 |
+
if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
|
| 85 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 86 |
+
|
| 87 |
+
def tokenize_dna_sequences(
|
| 88 |
+
self,
|
| 89 |
+
batch_dna_sequences: List[List[str]],
|
| 90 |
+
max_length: int = 2048,
|
| 91 |
+
return_tensors: str = "pt",
|
| 92 |
+
device: str = "cuda",
|
| 93 |
+
) -> Dict[str, Any]:
|
| 94 |
+
"""
|
| 95 |
+
Tokenize a batch of DNA sequences.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
batch_dna_sequences: List of lists of DNA sequences per batch item
|
| 99 |
+
max_length: Maximum allowed length for DNA sequences
|
| 100 |
+
return_tensors: Return format for tensors ("pt" for PyTorch)
|
| 101 |
+
device: Device to place tensors on
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Dict containing:
|
| 105 |
+
- dna_tokenized: The tokenized DNA sequences
|
| 106 |
+
- batch_idx_map: Mapping of which sequences belong to which batch item
|
| 107 |
+
"""
|
| 108 |
+
# Create a mapping to track which sequences belong to which batch item
|
| 109 |
+
batch_idx_map = []
|
| 110 |
+
all_sequences = []
|
| 111 |
+
|
| 112 |
+
# Flatten all sequences with batch tracking
|
| 113 |
+
for batch_idx, dna_sequences in enumerate(batch_dna_sequences):
|
| 114 |
+
for seq in dna_sequences:
|
| 115 |
+
all_sequences.append(seq)
|
| 116 |
+
batch_idx_map.append(batch_idx)
|
| 117 |
+
|
| 118 |
+
# If no sequences in the entire batch, return empty dict
|
| 119 |
+
if not all_sequences:
|
| 120 |
+
return {"dna_tokenized": None, "batch_idx_map": []}
|
| 121 |
+
|
| 122 |
+
# Tokenize all sequences at once
|
| 123 |
+
dna_tokenized = self.dna_tokenizer(
|
| 124 |
+
all_sequences,
|
| 125 |
+
padding=True,
|
| 126 |
+
truncation=True,
|
| 127 |
+
max_length=max_length,
|
| 128 |
+
return_tensors=return_tensors,
|
| 129 |
+
return_attention_mask=True,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
return {"dna_tokenized": dna_tokenized, "batch_idx_map": batch_idx_map}
|
| 133 |
+
|
| 134 |
+
def __call__(
|
| 135 |
+
self,
|
| 136 |
+
batch_dna_sequences: Optional[List[List[str]]] = None,
|
| 137 |
+
text: Optional[
|
| 138 |
+
Union[
|
| 139 |
+
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
| 140 |
+
]
|
| 141 |
+
] = None,
|
| 142 |
+
max_length_text: int = 512,
|
| 143 |
+
max_length_dna: int = 2048,
|
| 144 |
+
return_tensors: str = "pt",
|
| 145 |
+
device: str = "cuda",
|
| 146 |
+
**kwargs: Unpack[DLProcessorKwargs],
|
| 147 |
+
) -> BatchFeature:
|
| 148 |
+
"""
|
| 149 |
+
Process text and DNA sequences for model input.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
batch_dna_sequences: List of lists of DNA sequences per batch item
|
| 153 |
+
text: Input text or list of texts
|
| 154 |
+
max_length_text: Maximum length for text sequences
|
| 155 |
+
max_length_dna: Maximum length for DNA sequences
|
| 156 |
+
return_tensors: Return format for tensors
|
| 157 |
+
device: Device to place tensors on
|
| 158 |
+
**kwargs: Additional processor keyword arguments
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
BatchFeature with tokenized inputs for the model
|
| 162 |
+
"""
|
| 163 |
+
output_kwargs = self._merge_kwargs(
|
| 164 |
+
DLProcessorKwargs,
|
| 165 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 166 |
+
**kwargs,
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
# Ensure text is a list
|
| 170 |
+
if not isinstance(text, list):
|
| 171 |
+
text = [text]
|
| 172 |
+
|
| 173 |
+
# flattened_dna_sequences = [dna_sequence for dna_sequences in batch_dna_sequences for dna_sequence in dna_sequences]
|
| 174 |
+
dna_inputs = {}
|
| 175 |
+
if batch_dna_sequences is not None:
|
| 176 |
+
# Tokenize DNA sequences
|
| 177 |
+
dna_processing_result = self.tokenize_dna_sequences(
|
| 178 |
+
batch_dna_sequences,
|
| 179 |
+
max_length=max_length_dna,
|
| 180 |
+
return_tensors=return_tensors,
|
| 181 |
+
device=device,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Replace DNA tokens in text if needed
|
| 185 |
+
index = 0
|
| 186 |
+
for i in range(len(text)):
|
| 187 |
+
while self.dna_token in text[i]:
|
| 188 |
+
num_dna_tokens = (dna_processing_result['dna_tokenized']['input_ids'][index] != 1).sum().item()
|
| 189 |
+
text[i] = text[i].replace(
|
| 190 |
+
self.dna_token, "<|placeholder|>" * num_dna_tokens, 1
|
| 191 |
+
)
|
| 192 |
+
index += 1
|
| 193 |
+
text[i] = text[i].replace("<|placeholder|>", self.dna_token)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Add batch info to the output
|
| 198 |
+
dna_inputs = {
|
| 199 |
+
# "batch_dna_sequences": batch_dna_sequences,
|
| 200 |
+
"dna_tokenized": dna_processing_result["dna_tokenized"],
|
| 201 |
+
"batch_idx_map": dna_processing_result["batch_idx_map"],
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
# Tokenize text
|
| 205 |
+
text_kwargs = output_kwargs.get("text_kwargs", {})
|
| 206 |
+
|
| 207 |
+
if 'padding' in text_kwargs:
|
| 208 |
+
del text_kwargs['padding']
|
| 209 |
+
|
| 210 |
+
# print("__call__ (processor):", text)
|
| 211 |
+
text_inputs = self.tokenizer(
|
| 212 |
+
text,
|
| 213 |
+
max_length=max_length_text + 2 * max_length_dna,
|
| 214 |
+
return_tensors=return_tensors,
|
| 215 |
+
padding=True,
|
| 216 |
+
truncation=True,
|
| 217 |
+
**text_kwargs,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
# The BatchFeature should have all required fields for the model's forward pass
|
| 221 |
+
return BatchFeature(data={**text_inputs, **dna_inputs})
|
| 222 |
+
|
| 223 |
+
def batch_decode(self, *args, **kwargs) -> List[str]:
|
| 224 |
+
"""
|
| 225 |
+
This method forwards all its arguments to the tokenizer's batch_decode.
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
List of decoded strings
|
| 229 |
+
"""
|
| 230 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 231 |
+
|
| 232 |
+
def decode(self, *args, **kwargs) -> str:
|
| 233 |
+
"""
|
| 234 |
+
This method forwards all its arguments to the tokenizer's decode.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
Decoded string
|
| 238 |
+
"""
|
| 239 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 240 |
+
|
| 241 |
+
def post_process_dna_to_text(
|
| 242 |
+
self,
|
| 243 |
+
generated_outputs: torch.Tensor,
|
| 244 |
+
skip_special_tokens: bool = True,
|
| 245 |
+
**kwargs,
|
| 246 |
+
) -> List[str]:
|
| 247 |
+
"""
|
| 248 |
+
Post-process the model output to decode the text.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
generated_outputs: The token IDs generated by the model
|
| 252 |
+
skip_special_tokens: Whether to skip special tokens in the output
|
| 253 |
+
**kwargs: Additional arguments for the decoder
|
| 254 |
+
|
| 255 |
+
Returns:
|
| 256 |
+
List of decoded strings
|
| 257 |
+
"""
|
| 258 |
+
return self.tokenizer.batch_decode(
|
| 259 |
+
generated_outputs,
|
| 260 |
+
skip_special_tokens=skip_special_tokens,
|
| 261 |
+
**kwargs,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
@property
|
| 265 |
+
def model_input_names(self) -> List[str]:
|
| 266 |
+
"""
|
| 267 |
+
Get the input names expected by the model.
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
List of input names
|
| 271 |
+
"""
|
| 272 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 273 |
+
dna_input_names = ["dna_tokenized", "batch_idx_map"]
|
| 274 |
+
|
| 275 |
+
return list(dict.fromkeys(tokenizer_input_names + dna_input_names))
|
BioReason/bioreason/models/dna_llm.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoTokenizer,
|
| 7 |
+
AutoModelForCausalLM,
|
| 8 |
+
AutoModelForMaskedLM,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
from typing import Optional, List, Dict, Any, Union, Tuple
|
| 12 |
+
|
| 13 |
+
from bioreason.utils.dna_utils import DNAInput
|
| 14 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 15 |
+
from bioreason.models.dl.chat_template_dl import CHAT_TEMPLATE
|
| 16 |
+
from bioreason.models.evo2_tokenizer import Evo2Tokenizer
|
| 17 |
+
|
| 18 |
+
class DNALLMModel(nn.Module):
|
| 19 |
+
"""
|
| 20 |
+
A combined model that processes both DNA sequences and text inputs.
|
| 21 |
+
|
| 22 |
+
The model uses a DNA encoder (like NucleotideTransformer) to extract features from DNA sequences
|
| 23 |
+
and a text model (LLM) to process text inputs and generate responses. The DNA features are
|
| 24 |
+
projected to the text model's embedding space and prepended to the text embeddings.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
text_model_name: str,
|
| 30 |
+
dna_model_name: str,
|
| 31 |
+
cache_dir: Optional[str] = None,
|
| 32 |
+
max_length_dna: int = 2048,
|
| 33 |
+
max_length_text: int = 512,
|
| 34 |
+
text_model_finetune: bool = True,
|
| 35 |
+
dna_model_finetune: bool = True,
|
| 36 |
+
dna_is_evo2: bool = False,
|
| 37 |
+
dna_embedding_layer: str = None
|
| 38 |
+
):
|
| 39 |
+
"""
|
| 40 |
+
Initialize the DNALLMModel.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
text_model_name: Name of the text model to be used.
|
| 44 |
+
dna_model_name: Name of the DNA model to be used.
|
| 45 |
+
cache_dir: Directory to cache the models.
|
| 46 |
+
max_length_dna: Maximum length of DNA sequences. Defaults to 2048.
|
| 47 |
+
max_length_text: Maximum length of text sequences. Defaults to 512.
|
| 48 |
+
text_model_finetune: Whether to finetune the text model. Defaults to True.
|
| 49 |
+
dna_model_finetune: Whether to finetune the DNA model. Defaults to True.
|
| 50 |
+
dna_is_evo2: Whether the DNA model is Evo2. Defaults to False.
|
| 51 |
+
dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None.
|
| 52 |
+
"""
|
| 53 |
+
super().__init__()
|
| 54 |
+
|
| 55 |
+
self.text_model_finetune = text_model_finetune
|
| 56 |
+
self.dna_model_finetune = dna_model_finetune
|
| 57 |
+
self.max_length_dna = max_length_dna
|
| 58 |
+
self.max_length_text = max_length_text
|
| 59 |
+
self.dna_is_evo2 = dna_is_evo2
|
| 60 |
+
self.dna_embedding_layer = dna_embedding_layer
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Load the text model and tokenizer
|
| 64 |
+
self.text_model = AutoModelForCausalLM.from_pretrained(
|
| 65 |
+
text_model_name, cache_dir=cache_dir, trust_remote_code=True
|
| 66 |
+
)
|
| 67 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
|
| 68 |
+
self.text_config = self.text_model.config
|
| 69 |
+
self.text_tokenizer.chat_template = CHAT_TEMPLATE
|
| 70 |
+
self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
|
| 71 |
+
|
| 72 |
+
new_tokens = ["<|dna_start|>", "<|dna_pad|>", "<|dna_end|>"]
|
| 73 |
+
self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
|
| 74 |
+
self.dna_token_id = self.text_tokenizer.convert_tokens_to_ids("<|dna_pad|>")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Load the DNA model and tokenizer
|
| 78 |
+
if not self.dna_is_evo2:
|
| 79 |
+
self.dna_model = AutoModelForMaskedLM.from_pretrained(
|
| 80 |
+
dna_model_name, cache_dir=cache_dir, trust_remote_code=True
|
| 81 |
+
)
|
| 82 |
+
self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True)
|
| 83 |
+
self.dna_config = self.dna_model.config
|
| 84 |
+
|
| 85 |
+
else:
|
| 86 |
+
from evo2 import Evo2
|
| 87 |
+
self.dna_model = Evo2(dna_model_name)
|
| 88 |
+
self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer)
|
| 89 |
+
self.dna_config = self.dna_model.model.config
|
| 90 |
+
self.dna_embedding_layer = self.dna_embedding_layer
|
| 91 |
+
|
| 92 |
+
# Get model dimensions
|
| 93 |
+
self.text_hidden_size = self.text_config.hidden_size
|
| 94 |
+
self.dna_hidden_size = self.dna_config.hidden_size
|
| 95 |
+
|
| 96 |
+
# Create projection layer to map DNA embeddings to text model's embedding space
|
| 97 |
+
self.dna_projection = nn.Linear(self.dna_hidden_size, self.text_hidden_size)
|
| 98 |
+
|
| 99 |
+
# Create processor for handling inputs
|
| 100 |
+
self.processor = DLProcessor(tokenizer=self.text_tokenizer, dna_tokenizer=self.dna_tokenizer)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def process_dna_embeddings(
|
| 104 |
+
self,
|
| 105 |
+
dna_tokenized: Dict[str, torch.Tensor],
|
| 106 |
+
batch_idx_map: List[int],
|
| 107 |
+
batch_size: int,
|
| 108 |
+
) -> List[torch.Tensor]:
|
| 109 |
+
"""
|
| 110 |
+
Process DNA sequences to obtain embeddings.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
dna_tokenized: Tokenized DNA sequences
|
| 114 |
+
batch_idx_map: Mapping of each sequence to its batch item
|
| 115 |
+
batch_size: Number of items in the batch
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
List of tensor embeddings for each batch item
|
| 119 |
+
"""
|
| 120 |
+
# Process all sequences to get DNA representations
|
| 121 |
+
with torch.no_grad():
|
| 122 |
+
# Handle different model types based on dna_is_evo2 attribute
|
| 123 |
+
if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model
|
| 124 |
+
# Get embeddings from the specific layer in Evo2
|
| 125 |
+
hidden_states_list = []
|
| 126 |
+
|
| 127 |
+
for seq_idx in range(len(dna_tokenized["input_ids"])):
|
| 128 |
+
# Extract single sequence
|
| 129 |
+
input_ids = dna_tokenized["input_ids"][seq_idx:seq_idx+1]
|
| 130 |
+
|
| 131 |
+
# Call Evo2 with return_embeddings=True
|
| 132 |
+
_, embeddings = self.dna_model(
|
| 133 |
+
input_ids,
|
| 134 |
+
return_embeddings=True,
|
| 135 |
+
layer_names=[self.dna_embedding_layer]
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Get embeddings for the specified layer
|
| 139 |
+
seq_embeddings = embeddings[self.dna_embedding_layer].squeeze(0)
|
| 140 |
+
hidden_states_list.append(seq_embeddings)
|
| 141 |
+
|
| 142 |
+
# Stack to get same format as non-Evo2 output
|
| 143 |
+
if hidden_states_list:
|
| 144 |
+
hidden_states = torch.stack(hidden_states_list)
|
| 145 |
+
else:
|
| 146 |
+
return [torch.zeros((0, self.text_hidden_size)) for _ in range(batch_size)]
|
| 147 |
+
|
| 148 |
+
else: # Standard HuggingFace model
|
| 149 |
+
# Use existing code path for HF models
|
| 150 |
+
outputs = self.dna_model(
|
| 151 |
+
input_ids=dna_tokenized["input_ids"],
|
| 152 |
+
attention_mask=dna_tokenized["attention_mask"],
|
| 153 |
+
output_hidden_states=True,
|
| 154 |
+
)
|
| 155 |
+
# Get the last hidden state
|
| 156 |
+
hidden_states = outputs.hidden_states[-1] # shape: [n_seqs, seq_len, hidden_dim]
|
| 157 |
+
|
| 158 |
+
# Project all embeddings at once
|
| 159 |
+
hidden_states = hidden_states.to(device=self.dna_projection.weight.device, dtype=self.dna_projection.weight.dtype)
|
| 160 |
+
projected_states = self.dna_projection(hidden_states)
|
| 161 |
+
|
| 162 |
+
# Group embeddings by batch item
|
| 163 |
+
result = [[] for _ in range(batch_size)]
|
| 164 |
+
|
| 165 |
+
# For each sequence, get its embeddings and add to appropriate batch result
|
| 166 |
+
for seq_idx, batch_idx in enumerate(batch_idx_map):
|
| 167 |
+
# Get only the valid (non-padding) tokens
|
| 168 |
+
valid_length = dna_tokenized["attention_mask"][seq_idx].sum().item()
|
| 169 |
+
seq_embedding = projected_states[seq_idx, :valid_length]
|
| 170 |
+
result[batch_idx].append(seq_embedding)
|
| 171 |
+
|
| 172 |
+
# Concatenate embeddings for each batch item
|
| 173 |
+
for i in range(batch_size):
|
| 174 |
+
if result[i]:
|
| 175 |
+
result[i] = torch.cat(result[i], dim=0)
|
| 176 |
+
else:
|
| 177 |
+
result[i] = torch.zeros((0, self.text_hidden_size))
|
| 178 |
+
|
| 179 |
+
return result
|
| 180 |
+
|
| 181 |
+
def forward(
|
| 182 |
+
self,
|
| 183 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 184 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 185 |
+
dna_tokenized: Optional[Dict[str, torch.Tensor]] = None,
|
| 186 |
+
batch_idx_map: Optional[List[int]] = None,
|
| 187 |
+
labels: Optional[torch.Tensor] = None,
|
| 188 |
+
**kwargs,
|
| 189 |
+
) -> torch.Tensor:
|
| 190 |
+
"""
|
| 191 |
+
Generate text based on DNA and text inputs.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
input_ids: Input IDs (used if provided directly)
|
| 195 |
+
attention_mask: Attention mask (used if provided directly)
|
| 196 |
+
dna_tokenized: Tokenized DNA sequences (used if provided directly)
|
| 197 |
+
batch_idx_map: Batch mapping for DNA sequences (used if provided directly)
|
| 198 |
+
labels: Labels for supervised fine-tuning (used if provided directly)
|
| 199 |
+
**kwargs: Additional arguments for generation
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Outputs from the text model
|
| 203 |
+
"""
|
| 204 |
+
# Ensure required inputs are available
|
| 205 |
+
if input_ids is None or attention_mask is None:
|
| 206 |
+
raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
|
| 207 |
+
|
| 208 |
+
batch_size = input_ids.shape[0]
|
| 209 |
+
|
| 210 |
+
# Get text embeddings from the model's embedding layer
|
| 211 |
+
text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
|
| 212 |
+
|
| 213 |
+
if dna_tokenized is not None and batch_idx_map:
|
| 214 |
+
batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size)
|
| 215 |
+
|
| 216 |
+
mask = input_ids == self.dna_token_id
|
| 217 |
+
|
| 218 |
+
n_dna_tokens = mask.sum().item()
|
| 219 |
+
dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0)
|
| 220 |
+
n_dna_features = dna_embeds_flat.shape[0]
|
| 221 |
+
|
| 222 |
+
if n_dna_features != n_dna_tokens:
|
| 223 |
+
raise ValueError(
|
| 224 |
+
f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}"
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Ensure DNA embeddings have the same dtype as the text embeddings
|
| 228 |
+
dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype)
|
| 229 |
+
text_inputs_embeds[mask] = dna_embeds_flat
|
| 230 |
+
|
| 231 |
+
# Handle labels if provided (for training)
|
| 232 |
+
if labels is not None:
|
| 233 |
+
# TODO: Implement this
|
| 234 |
+
pass
|
| 235 |
+
|
| 236 |
+
# Forward pass through the text model (loss is computed if labels is provided)
|
| 237 |
+
outputs = self.text_model(
|
| 238 |
+
inputs_embeds=text_inputs_embeds,
|
| 239 |
+
attention_mask=attention_mask,
|
| 240 |
+
labels=labels,
|
| 241 |
+
**kwargs,
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
return outputs
|
| 245 |
+
|
| 246 |
+
def generate(
|
| 247 |
+
self,
|
| 248 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 249 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 250 |
+
dna_tokenized: Optional[Dict[str, torch.Tensor]] = None,
|
| 251 |
+
batch_idx_map: Optional[List[int]] = None,
|
| 252 |
+
**generation_kwargs,
|
| 253 |
+
) -> Union[torch.Tensor, List[str]]:
|
| 254 |
+
"""
|
| 255 |
+
Generate text based on DNA and text inputs.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
inputs: The preprocessed inputs from the processor (preferred method)
|
| 259 |
+
batch_dna_sequences: List of lists of DNA sequences per batch item (legacy method)
|
| 260 |
+
input_texts: List of input texts (legacy method)
|
| 261 |
+
input_ids: Input IDs (used if provided directly)
|
| 262 |
+
attention_mask: Attention mask (used if provided directly)
|
| 263 |
+
dna_tokenized: Tokenized DNA sequences (used if provided directly)
|
| 264 |
+
batch_idx_map: Batch mapping for DNA sequences (used if provided directly)
|
| 265 |
+
**generation_kwargs: Additional arguments for generation
|
| 266 |
+
|
| 267 |
+
Returns:
|
| 268 |
+
Generated token IDs which can be decoded using the processor
|
| 269 |
+
"""
|
| 270 |
+
# Ensure required inputs are available
|
| 271 |
+
if input_ids is None or attention_mask is None:
|
| 272 |
+
raise ValueError("Either 'inputs' or 'input_ids'/'attention_mask' must be provided")
|
| 273 |
+
|
| 274 |
+
batch_size = input_ids.shape[0]
|
| 275 |
+
|
| 276 |
+
# Get text embeddings from the model's embedding layer
|
| 277 |
+
text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
|
| 278 |
+
|
| 279 |
+
if dna_tokenized is not None and batch_idx_map:
|
| 280 |
+
batch_dna_embeds = self.process_dna_embeddings(dna_tokenized, batch_idx_map, batch_size)
|
| 281 |
+
|
| 282 |
+
mask = input_ids == self.dna_token_id
|
| 283 |
+
|
| 284 |
+
n_dna_tokens = mask.sum().item()
|
| 285 |
+
dna_embeds_flat = torch.cat(batch_dna_embeds, dim=0)
|
| 286 |
+
n_dna_features = dna_embeds_flat.shape[0]
|
| 287 |
+
|
| 288 |
+
if n_dna_features != n_dna_tokens:
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"DNA features and DNA tokens do not match: features {n_dna_features}, tokens: {n_dna_tokens}"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
# Ensure DNA embeddings have the same dtype as the text embeddings
|
| 294 |
+
dna_embeds_flat = dna_embeds_flat.to(dtype=text_inputs_embeds.dtype)
|
| 295 |
+
text_inputs_embeds[mask] = dna_embeds_flat
|
| 296 |
+
|
| 297 |
+
# Generation parameters may need adjustment based on model type
|
| 298 |
+
with torch.no_grad():
|
| 299 |
+
outputs = self.text_model.generate(
|
| 300 |
+
inputs_embeds=text_inputs_embeds,
|
| 301 |
+
attention_mask=attention_mask,
|
| 302 |
+
use_cache=True,
|
| 303 |
+
**generation_kwargs,
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return outputs
|
BioReason/bioreason/models/dna_only.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from typing import Dict
|
| 5 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SelfAttentionPooling(nn.Module):
|
| 9 |
+
def __init__(self, hidden_size, num_heads=8):
|
| 10 |
+
super().__init__()
|
| 11 |
+
# Use PyTorch's built-in multi-head attention
|
| 12 |
+
self.attention = nn.MultiheadAttention(
|
| 13 |
+
embed_dim=hidden_size,
|
| 14 |
+
num_heads=num_heads,
|
| 15 |
+
batch_first=True
|
| 16 |
+
)
|
| 17 |
+
# Learnable query vector
|
| 18 |
+
self.query = nn.Parameter(torch.randn(1, 1, hidden_size))
|
| 19 |
+
|
| 20 |
+
def forward(self, embeddings, attention_mask=None):
|
| 21 |
+
# Expand query to batch size
|
| 22 |
+
batch_size = embeddings.size(0)
|
| 23 |
+
query = self.query.expand(batch_size, -1, -1)
|
| 24 |
+
|
| 25 |
+
# Create key padding mask from attention mask if provided
|
| 26 |
+
key_padding_mask = None
|
| 27 |
+
if attention_mask is not None:
|
| 28 |
+
key_padding_mask = attention_mask == 0 # Convert to boolean mask where True means ignore
|
| 29 |
+
|
| 30 |
+
# Apply attention: query attends to embeddings
|
| 31 |
+
context, _ = self.attention(
|
| 32 |
+
query=query, # [batch_size, 1, hidden_size]
|
| 33 |
+
key=embeddings, # [batch_size, seq_len, hidden_size]
|
| 34 |
+
value=embeddings, # [batch_size, seq_len, hidden_size]
|
| 35 |
+
key_padding_mask=key_padding_mask
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Squeeze out the singleton dimension
|
| 39 |
+
return context.squeeze(1) # [batch_size, hidden_size]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class DNAClassifierModel(nn.Module):
|
| 43 |
+
"""
|
| 44 |
+
A simple classifier that uses a DNA model with a classification head.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
dna_model_name: str,
|
| 50 |
+
cache_dir: str = None,
|
| 51 |
+
max_length_dna: int = 4096,
|
| 52 |
+
num_classes: int = 2, # Binary classification by default
|
| 53 |
+
dna_is_evo2: bool = False,
|
| 54 |
+
dna_embedding_layer: str = None,
|
| 55 |
+
train_just_classifier: bool = True
|
| 56 |
+
):
|
| 57 |
+
"""
|
| 58 |
+
Initialize the DNAClassifierModel.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
dna_model_name (str): Name of the DNA model to use
|
| 62 |
+
cache_dir (str): Directory to cache models
|
| 63 |
+
max_length_dna (int): Maximum sequence length
|
| 64 |
+
num_classes (int): Number of output classes
|
| 65 |
+
dna_is_evo2: Whether the DNA model is Evo2. Defaults to False
|
| 66 |
+
dna_embedding_layer: Name of the layer to use for the Evo2 model. Defaults to None
|
| 67 |
+
train_just_classifier: Whether to train just the classifier. Defaults to True
|
| 68 |
+
"""
|
| 69 |
+
super().__init__()
|
| 70 |
+
|
| 71 |
+
self.dna_model_name = dna_model_name
|
| 72 |
+
self.cache_dir = cache_dir
|
| 73 |
+
self.max_length_dna = max_length_dna
|
| 74 |
+
self.num_classes = num_classes
|
| 75 |
+
self.dna_is_evo2 = dna_is_evo2
|
| 76 |
+
self.dna_embedding_layer = dna_embedding_layer
|
| 77 |
+
self.train_just_classifier = train_just_classifier
|
| 78 |
+
|
| 79 |
+
# Load the DNA model and tokenizer
|
| 80 |
+
if not self.dna_is_evo2:
|
| 81 |
+
self.dna_model = AutoModelForMaskedLM.from_pretrained(
|
| 82 |
+
dna_model_name, cache_dir=cache_dir, trust_remote_code=True
|
| 83 |
+
)
|
| 84 |
+
self.dna_tokenizer = AutoTokenizer.from_pretrained(dna_model_name, trust_remote_code=True)
|
| 85 |
+
self.dna_config = self.dna_model.config
|
| 86 |
+
|
| 87 |
+
else:
|
| 88 |
+
from evo2 import Evo2
|
| 89 |
+
from bioreason.models.evo2_tokenizer import Evo2Tokenizer
|
| 90 |
+
self.dna_model = Evo2(dna_model_name)
|
| 91 |
+
self.dna_tokenizer = Evo2Tokenizer(self.dna_model.tokenizer)
|
| 92 |
+
self.dna_config = self.dna_model.model.config
|
| 93 |
+
self.dna_embedding_layer = self.dna_embedding_layer
|
| 94 |
+
|
| 95 |
+
# Get hidden size from model config
|
| 96 |
+
self.hidden_size = self.dna_config.hidden_size
|
| 97 |
+
|
| 98 |
+
# Add the self-attention pooling module
|
| 99 |
+
self.pooler = SelfAttentionPooling(self.hidden_size)
|
| 100 |
+
|
| 101 |
+
# Create classification head that takes concatenated embeddings from both sequences
|
| 102 |
+
self.classifier = nn.Sequential(
|
| 103 |
+
nn.Linear(self.hidden_size * 2, self.hidden_size),
|
| 104 |
+
nn.ReLU(),
|
| 105 |
+
nn.Dropout(0.1),
|
| 106 |
+
nn.Linear(self.hidden_size, num_classes),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
self.max_length_dna = max_length_dna
|
| 110 |
+
|
| 111 |
+
def get_dna_embedding(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
|
| 112 |
+
"""
|
| 113 |
+
Get DNA embedding for a single DNA sequence using self-attention pooling.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
input_ids: DNA tokenized sequence
|
| 117 |
+
attention_mask: DNA tokenized sequence attention mask
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
torch.Tensor: Tensor containing the self-attention pooled DNA embedding
|
| 121 |
+
"""
|
| 122 |
+
# Add batch dimension if not present
|
| 123 |
+
if input_ids.dim() == 1:
|
| 124 |
+
input_ids = input_ids.unsqueeze(0) # [1, seq_len]
|
| 125 |
+
|
| 126 |
+
# Handle attention mask - create if not provided or add batch dimension
|
| 127 |
+
if attention_mask is None:
|
| 128 |
+
attention_mask = torch.ones_like(input_ids)
|
| 129 |
+
elif attention_mask.dim() == 1:
|
| 130 |
+
attention_mask = attention_mask.unsqueeze(0) # [1, seq_len]
|
| 131 |
+
|
| 132 |
+
# Get embeddings from DNA model
|
| 133 |
+
with torch.set_grad_enabled(not self.train_just_classifier): # Enable gradients for fine-tuning
|
| 134 |
+
|
| 135 |
+
if self.dna_is_evo2 and self.dna_embedding_layer is not None: # Evo2 model
|
| 136 |
+
# Get embeddings from the specific layer in Evo2
|
| 137 |
+
_, embeddings = self.dna_model(
|
| 138 |
+
input_ids,
|
| 139 |
+
return_embeddings=True,
|
| 140 |
+
layer_names=[self.dna_embedding_layer]
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
# Get embeddings for the specified layer
|
| 144 |
+
hidden_states = embeddings[self.dna_embedding_layer]
|
| 145 |
+
|
| 146 |
+
else:
|
| 147 |
+
# Get embeddings from the last hidden state
|
| 148 |
+
outputs = self.dna_model(
|
| 149 |
+
input_ids,
|
| 150 |
+
attention_mask=attention_mask,
|
| 151 |
+
output_hidden_states=True,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Get the last hidden state
|
| 155 |
+
hidden_states = outputs.hidden_states[-1]
|
| 156 |
+
|
| 157 |
+
# Apply self-attention pooling to get a weighted representation
|
| 158 |
+
sequence_embedding = self.pooler(hidden_states, attention_mask)
|
| 159 |
+
return sequence_embedding.squeeze(0)
|
| 160 |
+
|
| 161 |
+
def forward(
|
| 162 |
+
self, ref_ids=None, alt_ids=None, ref_attention_mask=None, alt_attention_mask=None
|
| 163 |
+
):
|
| 164 |
+
"""
|
| 165 |
+
Forward pass of the model.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
ref_ids: Reference sequence token IDsself.dna_model
|
| 169 |
+
alt_ids: Alternate sequence token IDsself.dna_model
|
| 170 |
+
ref_attention_mask: Reference sequence attention maskself.dna_model
|
| 171 |
+
alt_attention_mask: Alternate sequence attention maskself.dna_model
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor: Classification logits
|
| 175 |
+
"""
|
| 176 |
+
batch_size = ref_ids.shape[0] if ref_ids is not None else alt_ids.shape[0]
|
| 177 |
+
|
| 178 |
+
if batch_size is None:
|
| 179 |
+
raise ValueError("Either token IDs must be provided")
|
| 180 |
+
|
| 181 |
+
ref_embeddings = []
|
| 182 |
+
alt_embeddings = []
|
| 183 |
+
|
| 184 |
+
# Process each example in the batch
|
| 185 |
+
for i in range(batch_size):
|
| 186 |
+
|
| 187 |
+
# Get sequence embeddings
|
| 188 |
+
ref_embed = self.get_dna_embedding(ref_ids[i], ref_attention_mask[i])
|
| 189 |
+
alt_embed = self.get_dna_embedding(alt_ids[i], alt_attention_mask[i])
|
| 190 |
+
ref_embeddings.append(ref_embed)
|
| 191 |
+
alt_embeddings.append(alt_embed)
|
| 192 |
+
|
| 193 |
+
# Stack embeddings
|
| 194 |
+
ref_embeddings = torch.stack(ref_embeddings)
|
| 195 |
+
alt_embeddings = torch.stack(alt_embeddings)
|
| 196 |
+
|
| 197 |
+
# Concatenate ref and alt embeddings
|
| 198 |
+
combined_embeddings = torch.cat([ref_embeddings, alt_embeddings], dim=1)
|
| 199 |
+
|
| 200 |
+
# Pass through classifier
|
| 201 |
+
logits = self.classifier(combined_embeddings)
|
| 202 |
+
|
| 203 |
+
return logits
|
BioReason/bioreason/models/esm_tokenizer.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 2 |
+
from transformers.utils import logging
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Dict, Optional, Union, Tuple
|
| 8 |
+
|
| 9 |
+
logger = logging.get_logger(__name__)
|
| 10 |
+
|
| 11 |
+
class EsmTokenizer(PreTrainedTokenizer):
|
| 12 |
+
"""
|
| 13 |
+
Tokenizer for ESM models - wraps the ESM tokenizer to be compatible with HuggingFace interfaces.
|
| 14 |
+
This tokenizer handles protein sequences (amino acid sequences).
|
| 15 |
+
"""
|
| 16 |
+
vocab_files_names = {} # ESM tokenizer doesn't require vocab files
|
| 17 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 18 |
+
|
| 19 |
+
# Standard amino acid alphabet used by ESM
|
| 20 |
+
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
|
| 21 |
+
|
| 22 |
+
def __init__(
|
| 23 |
+
self,
|
| 24 |
+
esm_model_name: str = "facebook/esm2_t33_650M_UR50D",
|
| 25 |
+
bos_token="<cls>",
|
| 26 |
+
eos_token="<eos>",
|
| 27 |
+
pad_token="<pad>",
|
| 28 |
+
unk_token="<unk>",
|
| 29 |
+
mask_token="<mask>",
|
| 30 |
+
**kwargs
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Initialize the ESM Tokenizer.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
esm_model_name: Name of the ESM model to load the tokenizer from
|
| 37 |
+
bos_token: Beginning of sequence token (CLS token in ESM)
|
| 38 |
+
eos_token: End of sequence token
|
| 39 |
+
pad_token: Padding token
|
| 40 |
+
unk_token: Unknown token
|
| 41 |
+
mask_token: Mask token for masked language modeling
|
| 42 |
+
"""
|
| 43 |
+
# Load the actual ESM tokenizer
|
| 44 |
+
try:
|
| 45 |
+
self.esm_tokenizer = AutoTokenizer.from_pretrained(esm_model_name, trust_remote_code=True)
|
| 46 |
+
except:
|
| 47 |
+
# Fallback to manual tokenizer if auto loading fails
|
| 48 |
+
self.esm_tokenizer = None
|
| 49 |
+
self._create_manual_tokenizer()
|
| 50 |
+
|
| 51 |
+
# Set special tokens
|
| 52 |
+
self._pad_token = pad_token
|
| 53 |
+
self._eos_token = eos_token
|
| 54 |
+
self._bos_token = bos_token
|
| 55 |
+
self._unk_token = unk_token
|
| 56 |
+
self._mask_token = mask_token
|
| 57 |
+
|
| 58 |
+
# Initialize with special tokens
|
| 59 |
+
super().__init__(
|
| 60 |
+
bos_token=bos_token,
|
| 61 |
+
eos_token=eos_token,
|
| 62 |
+
pad_token=pad_token,
|
| 63 |
+
unk_token=unk_token,
|
| 64 |
+
mask_token=mask_token,
|
| 65 |
+
**kwargs
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
# Set token IDs
|
| 69 |
+
if self.esm_tokenizer is not None:
|
| 70 |
+
self.pad_token_id = self.esm_tokenizer.pad_token_id
|
| 71 |
+
self.eos_token_id = self.esm_tokenizer.eos_token_id
|
| 72 |
+
self.bos_token_id = getattr(self.esm_tokenizer, 'cls_token_id', 0)
|
| 73 |
+
self.unk_token_id = self.esm_tokenizer.unk_token_id
|
| 74 |
+
self.mask_token_id = getattr(self.esm_tokenizer, 'mask_token_id', 32)
|
| 75 |
+
else:
|
| 76 |
+
# Manual token IDs for fallback
|
| 77 |
+
self.pad_token_id = 1
|
| 78 |
+
self.eos_token_id = 2
|
| 79 |
+
self.bos_token_id = 0 # CLS token
|
| 80 |
+
self.unk_token_id = 3
|
| 81 |
+
self.mask_token_id = 32
|
| 82 |
+
|
| 83 |
+
def _create_manual_tokenizer(self):
|
| 84 |
+
"""Create a manual tokenizer mapping if ESM tokenizer loading fails."""
|
| 85 |
+
# Create vocabulary mapping
|
| 86 |
+
special_tokens = ["<cls>", "<pad>", "<eos>", "<unk>"]
|
| 87 |
+
amino_acids = list(self.AMINO_ACIDS)
|
| 88 |
+
|
| 89 |
+
self.token_to_id = {}
|
| 90 |
+
self.id_to_token = {}
|
| 91 |
+
|
| 92 |
+
# Add special tokens first
|
| 93 |
+
for i, token in enumerate(special_tokens):
|
| 94 |
+
self.token_to_id[token] = i
|
| 95 |
+
self.id_to_token[i] = token
|
| 96 |
+
|
| 97 |
+
# Add amino acids
|
| 98 |
+
for i, aa in enumerate(amino_acids):
|
| 99 |
+
token_id = i + len(special_tokens)
|
| 100 |
+
self.token_to_id[aa] = token_id
|
| 101 |
+
self.id_to_token[token_id] = aa
|
| 102 |
+
|
| 103 |
+
# Add mask token
|
| 104 |
+
mask_id = 32
|
| 105 |
+
self.token_to_id["<mask>"] = mask_id
|
| 106 |
+
self.id_to_token[mask_id] = "<mask>"
|
| 107 |
+
|
| 108 |
+
self._vocab_size = max(self.id_to_token.keys()) + 1
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def vocab_size(self) -> int:
|
| 112 |
+
"""Return the vocab size of the tokenizer."""
|
| 113 |
+
if self.esm_tokenizer is not None:
|
| 114 |
+
return self.esm_tokenizer.vocab_size
|
| 115 |
+
else:
|
| 116 |
+
return self._vocab_size
|
| 117 |
+
|
| 118 |
+
def get_vocab(self) -> Dict:
|
| 119 |
+
"""Return vocab as a dictionary."""
|
| 120 |
+
if self.esm_tokenizer is not None:
|
| 121 |
+
return self.esm_tokenizer.get_vocab()
|
| 122 |
+
else:
|
| 123 |
+
return self.token_to_id.copy()
|
| 124 |
+
|
| 125 |
+
def _tokenize(self, text: str) -> List[str]:
|
| 126 |
+
"""Tokenize a protein sequence string."""
|
| 127 |
+
if self.esm_tokenizer is not None:
|
| 128 |
+
return self.esm_tokenizer.tokenize(text)
|
| 129 |
+
else:
|
| 130 |
+
# Manual tokenization - split into individual amino acids
|
| 131 |
+
tokens = []
|
| 132 |
+
for char in text.upper():
|
| 133 |
+
if char in self.AMINO_ACIDS:
|
| 134 |
+
tokens.append(char)
|
| 135 |
+
else:
|
| 136 |
+
tokens.append(self._unk_token)
|
| 137 |
+
return tokens
|
| 138 |
+
|
| 139 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 140 |
+
"""Convert a token to an id."""
|
| 141 |
+
if self.esm_tokenizer is not None:
|
| 142 |
+
return self.esm_tokenizer.convert_tokens_to_ids(token)
|
| 143 |
+
else:
|
| 144 |
+
return self.token_to_id.get(token, self.unk_token_id)
|
| 145 |
+
|
| 146 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 147 |
+
"""Convert an id to a token."""
|
| 148 |
+
if self.esm_tokenizer is not None:
|
| 149 |
+
return self.esm_tokenizer.convert_ids_to_tokens(index)
|
| 150 |
+
else:
|
| 151 |
+
return self.id_to_token.get(index, self._unk_token)
|
| 152 |
+
|
| 153 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 154 |
+
"""Convert a sequence of tokens to a single string."""
|
| 155 |
+
# Filter out special tokens and join
|
| 156 |
+
filtered_tokens = []
|
| 157 |
+
for token in tokens:
|
| 158 |
+
if token not in [self._bos_token, self._eos_token, self._pad_token]:
|
| 159 |
+
filtered_tokens.append(token)
|
| 160 |
+
return "".join(filtered_tokens)
|
| 161 |
+
|
| 162 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 163 |
+
"""ESM tokenizer doesn't need vocabulary saving, return empty tuple."""
|
| 164 |
+
return ()
|
| 165 |
+
|
| 166 |
+
def __call__(
|
| 167 |
+
self,
|
| 168 |
+
text: Union[str, List[str]],
|
| 169 |
+
text_pair: Optional[Union[str, List[str]]] = None,
|
| 170 |
+
padding: Union[bool, str] = False,
|
| 171 |
+
truncation: Union[bool, str] = False,
|
| 172 |
+
max_length: Optional[int] = None,
|
| 173 |
+
return_tensors: Optional[str] = None,
|
| 174 |
+
return_token_type_ids: Optional[bool] = None,
|
| 175 |
+
return_attention_mask: Optional[bool] = True,
|
| 176 |
+
add_special_tokens: bool = True,
|
| 177 |
+
**kwargs
|
| 178 |
+
) -> BatchEncoding:
|
| 179 |
+
"""
|
| 180 |
+
Main tokenization method that handles batching and converts to tensors.
|
| 181 |
+
"""
|
| 182 |
+
# Use ESM tokenizer if available
|
| 183 |
+
if self.esm_tokenizer is not None:
|
| 184 |
+
return self.esm_tokenizer(
|
| 185 |
+
text=text,
|
| 186 |
+
text_pair=text_pair,
|
| 187 |
+
padding=padding,
|
| 188 |
+
truncation=truncation,
|
| 189 |
+
max_length=max_length,
|
| 190 |
+
return_tensors=return_tensors,
|
| 191 |
+
return_token_type_ids=return_token_type_ids,
|
| 192 |
+
return_attention_mask=return_attention_mask,
|
| 193 |
+
add_special_tokens=add_special_tokens,
|
| 194 |
+
**kwargs
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Manual tokenization fallback
|
| 198 |
+
# Handle single string vs list of strings
|
| 199 |
+
if isinstance(text, str):
|
| 200 |
+
text = [text]
|
| 201 |
+
|
| 202 |
+
# Tokenize all sequences
|
| 203 |
+
input_ids_list = []
|
| 204 |
+
for seq in text:
|
| 205 |
+
# Clean sequence (remove spaces, convert to uppercase)
|
| 206 |
+
seq = seq.replace(" ", "").upper()
|
| 207 |
+
|
| 208 |
+
# Tokenize sequence
|
| 209 |
+
tokens = self._tokenize(seq)
|
| 210 |
+
token_ids = [self._convert_token_to_id(token) for token in tokens]
|
| 211 |
+
|
| 212 |
+
# Add special tokens if requested
|
| 213 |
+
if add_special_tokens:
|
| 214 |
+
token_ids = [self.bos_token_id] + token_ids + [self.eos_token_id]
|
| 215 |
+
|
| 216 |
+
# Truncate if needed
|
| 217 |
+
if truncation and max_length and len(token_ids) > max_length:
|
| 218 |
+
if add_special_tokens:
|
| 219 |
+
# Keep BOS, truncate middle, keep EOS
|
| 220 |
+
token_ids = [token_ids[0]] + token_ids[1:max_length-1] + [token_ids[-1]]
|
| 221 |
+
else:
|
| 222 |
+
token_ids = token_ids[:max_length]
|
| 223 |
+
|
| 224 |
+
input_ids_list.append(token_ids)
|
| 225 |
+
|
| 226 |
+
# Apply padding if needed
|
| 227 |
+
if padding:
|
| 228 |
+
if max_length:
|
| 229 |
+
max_len = max_length
|
| 230 |
+
else:
|
| 231 |
+
max_len = max(len(ids) for ids in input_ids_list)
|
| 232 |
+
|
| 233 |
+
# Create padded sequences and attention masks
|
| 234 |
+
padded_input_ids = []
|
| 235 |
+
attention_mask = []
|
| 236 |
+
|
| 237 |
+
for ids in input_ids_list:
|
| 238 |
+
# Apply right padding (pad on the right for protein sequences)
|
| 239 |
+
padding_length = max_len - len(ids)
|
| 240 |
+
padded_ids = ids + [self.pad_token_id] * padding_length
|
| 241 |
+
mask = [1] * len(ids) + [0] * padding_length
|
| 242 |
+
|
| 243 |
+
padded_input_ids.append(padded_ids)
|
| 244 |
+
attention_mask.append(mask)
|
| 245 |
+
|
| 246 |
+
input_ids_list = padded_input_ids
|
| 247 |
+
else:
|
| 248 |
+
# Create attention mask without padding
|
| 249 |
+
attention_mask = [[1] * len(ids) for ids in input_ids_list]
|
| 250 |
+
|
| 251 |
+
# Create result dictionary
|
| 252 |
+
result = {"input_ids": input_ids_list}
|
| 253 |
+
if return_attention_mask:
|
| 254 |
+
result["attention_mask"] = attention_mask
|
| 255 |
+
|
| 256 |
+
# Convert to tensors if requested
|
| 257 |
+
if return_tensors == "pt":
|
| 258 |
+
result = {k: torch.tensor(v) for k, v in result.items()}
|
| 259 |
+
|
| 260 |
+
# Return a BatchEncoding object
|
| 261 |
+
return BatchEncoding(
|
| 262 |
+
data=result,
|
| 263 |
+
tensor_type=return_tensors,
|
| 264 |
+
prepend_batch_axis=False,
|
| 265 |
+
encoding=None
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
def batch_decode(
|
| 269 |
+
self,
|
| 270 |
+
sequences: Union[List[int], List[List[int]], torch.Tensor],
|
| 271 |
+
skip_special_tokens: bool = True,
|
| 272 |
+
**kwargs
|
| 273 |
+
) -> List[str]:
|
| 274 |
+
"""
|
| 275 |
+
Decode a batch of token ids to strings.
|
| 276 |
+
"""
|
| 277 |
+
if self.esm_tokenizer is not None:
|
| 278 |
+
return self.esm_tokenizer.batch_decode(sequences, skip_special_tokens=skip_special_tokens, **kwargs)
|
| 279 |
+
|
| 280 |
+
if isinstance(sequences, torch.Tensor):
|
| 281 |
+
sequences = sequences.tolist()
|
| 282 |
+
|
| 283 |
+
results = []
|
| 284 |
+
for seq in sequences:
|
| 285 |
+
tokens = [self._convert_id_to_token(token_id) for token_id in seq]
|
| 286 |
+
if skip_special_tokens:
|
| 287 |
+
tokens = [token for token in tokens if token not in [
|
| 288 |
+
self._bos_token, self._eos_token, self._pad_token, self._unk_token
|
| 289 |
+
]]
|
| 290 |
+
results.append("".join(tokens))
|
| 291 |
+
|
| 292 |
+
return results
|
| 293 |
+
|
| 294 |
+
def decode(
|
| 295 |
+
self,
|
| 296 |
+
token_ids: Union[int, List[int], torch.Tensor],
|
| 297 |
+
skip_special_tokens: bool = True,
|
| 298 |
+
**kwargs
|
| 299 |
+
) -> str:
|
| 300 |
+
"""
|
| 301 |
+
Decode a single sequence of token ids to a string.
|
| 302 |
+
"""
|
| 303 |
+
if self.esm_tokenizer is not None:
|
| 304 |
+
return self.esm_tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens, **kwargs)
|
| 305 |
+
|
| 306 |
+
if isinstance(token_ids, torch.Tensor):
|
| 307 |
+
token_ids = token_ids.tolist()
|
| 308 |
+
|
| 309 |
+
# Handle both single sequence and batch with one item
|
| 310 |
+
if not isinstance(token_ids, list) or not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)):
|
| 311 |
+
# Single sequence
|
| 312 |
+
tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]
|
| 313 |
+
if skip_special_tokens:
|
| 314 |
+
tokens = [token for token in tokens if token not in [
|
| 315 |
+
self._bos_token, self._eos_token, self._pad_token, self._unk_token
|
| 316 |
+
]]
|
| 317 |
+
return "".join(tokens)
|
| 318 |
+
|
| 319 |
+
# Batch with one item
|
| 320 |
+
return self.batch_decode(token_ids, skip_special_tokens, **kwargs)[0]
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def register_esm_tokenizer():
|
| 324 |
+
"""Register the EsmTokenizer with HuggingFace's AutoTokenizer."""
|
| 325 |
+
AutoTokenizer.register("esm", EsmTokenizer)
|
| 326 |
+
print("EsmTokenizer registered with AutoTokenizer")
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
if __name__ == "__main__":
|
| 330 |
+
register_esm_tokenizer()
|
BioReason/bioreason/models/evo2_tokenizer.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
| 2 |
+
from transformers.utils import logging
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
from typing import List, Dict, Optional, Union, Tuple
|
| 8 |
+
|
| 9 |
+
# Register the tokenizer with AutoTokenizer
|
| 10 |
+
from transformers.models.auto import AutoTokenizer
|
| 11 |
+
from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING
|
| 12 |
+
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
|
| 13 |
+
|
| 14 |
+
logger = logging.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
class Evo2Tokenizer(PreTrainedTokenizer):
|
| 17 |
+
"""
|
| 18 |
+
Tokenizer for Evo2 models - wraps the CharLevelTokenizer to be compatible with HuggingFace.
|
| 19 |
+
"""
|
| 20 |
+
vocab_files_names = {} # No vocab files needed
|
| 21 |
+
model_input_names = ["input_ids", "attention_mask"]
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
evo2_tokenizer,
|
| 26 |
+
bos_token="<s>",
|
| 27 |
+
eos_token="</s>",
|
| 28 |
+
pad_token="<pad>",
|
| 29 |
+
unk_token="<unk>",
|
| 30 |
+
**kwargs
|
| 31 |
+
):
|
| 32 |
+
"""
|
| 33 |
+
Initialize the Evo2Tokenizer.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
evo2_tokenizer: The Evo2 CharLevelTokenizer to wrap
|
| 37 |
+
bos_token: Beginning of sequence token
|
| 38 |
+
eos_token: End of sequence token
|
| 39 |
+
pad_token: Padding token
|
| 40 |
+
unk_token: Unknown token
|
| 41 |
+
"""
|
| 42 |
+
self.evo2_tokenizer = evo2_tokenizer
|
| 43 |
+
|
| 44 |
+
# Map special tokens to Evo2 tokenizer's special token IDs
|
| 45 |
+
self._pad_token = pad_token
|
| 46 |
+
self._eos_token = eos_token
|
| 47 |
+
self._bos_token = bos_token
|
| 48 |
+
self._unk_token = unk_token
|
| 49 |
+
|
| 50 |
+
# Initialize with special tokens
|
| 51 |
+
super().__init__(
|
| 52 |
+
bos_token=bos_token,
|
| 53 |
+
eos_token=eos_token,
|
| 54 |
+
pad_token=pad_token,
|
| 55 |
+
unk_token=unk_token,
|
| 56 |
+
**kwargs
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Set token IDs from Evo2 tokenizer
|
| 60 |
+
self.pad_token_id = self.evo2_tokenizer.pad_id
|
| 61 |
+
self.eos_token_id = self.evo2_tokenizer.eos_id
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def vocab_size(self) -> int:
|
| 65 |
+
"""Return the vocab size of the tokenizer."""
|
| 66 |
+
return self.evo2_tokenizer.vocab_size
|
| 67 |
+
|
| 68 |
+
def get_vocab(self) -> Dict:
|
| 69 |
+
"""Return vocab as a dictionary."""
|
| 70 |
+
# Evo2 CharLevelTokenizer doesn't have a traditional vocab dict
|
| 71 |
+
# Create a simple mapping of ASCII codes to tokens
|
| 72 |
+
return {chr(i): i for i in range(self.vocab_size)}
|
| 73 |
+
|
| 74 |
+
def _tokenize(self, text: str) -> List[int]:
|
| 75 |
+
"""Tokenize a string using the Evo2 tokenizer."""
|
| 76 |
+
return [chr(int(token)) for token in self.evo2_tokenizer.tokenize(text)]
|
| 77 |
+
|
| 78 |
+
def _convert_token_to_id(self, token: str) -> int:
|
| 79 |
+
"""Convert a token to an id using the Evo2 tokenizer."""
|
| 80 |
+
# Since tokens are just characters, convert to their ASCII value
|
| 81 |
+
return ord(token)
|
| 82 |
+
|
| 83 |
+
def _convert_id_to_token(self, index: int) -> str:
|
| 84 |
+
"""Convert an id to a token using the Evo2 tokenizer."""
|
| 85 |
+
# Convert ASCII value back to character
|
| 86 |
+
return chr(index)
|
| 87 |
+
|
| 88 |
+
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
| 89 |
+
"""Convert a sequence of tokens to a single string."""
|
| 90 |
+
return "".join(tokens)
|
| 91 |
+
|
| 92 |
+
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
| 93 |
+
"""No vocabulary to save for Evo2Tokenizer, so just return an empty tuple."""
|
| 94 |
+
return ()
|
| 95 |
+
|
| 96 |
+
def __call__(
|
| 97 |
+
self,
|
| 98 |
+
text: Union[str, List[str]],
|
| 99 |
+
text_pair: Optional[Union[str, List[str]]] = None,
|
| 100 |
+
padding: Union[bool, str] = False,
|
| 101 |
+
truncation: Union[bool, str] = False,
|
| 102 |
+
max_length: Optional[int] = None,
|
| 103 |
+
return_tensors: Optional[str] = None,
|
| 104 |
+
return_token_type_ids: Optional[bool] = None,
|
| 105 |
+
return_attention_mask: Optional[bool] = True,
|
| 106 |
+
**kwargs
|
| 107 |
+
) -> Dict[str, torch.Tensor]:
|
| 108 |
+
"""
|
| 109 |
+
Main tokenization method that handles batching and converts to tensors.
|
| 110 |
+
"""
|
| 111 |
+
# Handle single string vs list of strings
|
| 112 |
+
if isinstance(text, str):
|
| 113 |
+
text = [text]
|
| 114 |
+
|
| 115 |
+
# Tokenize all sequences - note: tokenizer only accepts strings, not lists
|
| 116 |
+
input_ids_list = []
|
| 117 |
+
for seq in text:
|
| 118 |
+
# Tokenize and convert numpy.uint8 to Python integers
|
| 119 |
+
tokens = [int(token) for token in self.evo2_tokenizer.tokenize(seq)]
|
| 120 |
+
|
| 121 |
+
# Truncate if needed
|
| 122 |
+
if truncation and max_length and len(tokens) > max_length:
|
| 123 |
+
tokens = tokens[:max_length]
|
| 124 |
+
|
| 125 |
+
input_ids_list.append(tokens)
|
| 126 |
+
|
| 127 |
+
# Apply padding if needed
|
| 128 |
+
if padding:
|
| 129 |
+
if False:#max_length:
|
| 130 |
+
max_len = max_length
|
| 131 |
+
else:
|
| 132 |
+
max_len = max(len(ids) for ids in input_ids_list)
|
| 133 |
+
|
| 134 |
+
# Create padded sequences and attention masks
|
| 135 |
+
padded_input_ids = []
|
| 136 |
+
attention_mask = []
|
| 137 |
+
|
| 138 |
+
for ids in input_ids_list:
|
| 139 |
+
# Apply left padding (pad on the left)
|
| 140 |
+
padding_length = max_len - len(ids)
|
| 141 |
+
padded_ids = [self.pad_token_id] * padding_length + ids
|
| 142 |
+
mask = [0] * padding_length + [1] * len(ids)
|
| 143 |
+
|
| 144 |
+
padded_input_ids.append(padded_ids)
|
| 145 |
+
attention_mask.append(mask)
|
| 146 |
+
|
| 147 |
+
input_ids_list = padded_input_ids
|
| 148 |
+
else:
|
| 149 |
+
# Create attention mask without padding
|
| 150 |
+
attention_mask = [[1] * len(ids) for ids in input_ids_list]
|
| 151 |
+
|
| 152 |
+
# Create result dictionary
|
| 153 |
+
result = {"input_ids": input_ids_list}
|
| 154 |
+
if return_attention_mask:
|
| 155 |
+
result["attention_mask"] = attention_mask
|
| 156 |
+
|
| 157 |
+
# Convert to tensors if requested
|
| 158 |
+
if return_tensors == "pt":
|
| 159 |
+
result = {k: torch.tensor(v) for k, v in result.items()}
|
| 160 |
+
|
| 161 |
+
# Return a BatchEncoding object rather than a plain dictionary
|
| 162 |
+
return BatchEncoding(
|
| 163 |
+
data=result,
|
| 164 |
+
tensor_type=return_tensors,
|
| 165 |
+
prepend_batch_axis=False, # Already handled in our tensor creation
|
| 166 |
+
encoding=None # No encoding info from Evo2's tokenizer
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
def batch_decode(
|
| 170 |
+
self,
|
| 171 |
+
sequences: Union[List[int], List[List[int]], torch.Tensor],
|
| 172 |
+
skip_special_tokens: bool = False,
|
| 173 |
+
**kwargs
|
| 174 |
+
) -> List[str]:
|
| 175 |
+
"""
|
| 176 |
+
Decode a batch of token ids to strings.
|
| 177 |
+
"""
|
| 178 |
+
if isinstance(sequences, torch.Tensor):
|
| 179 |
+
sequences = sequences.tolist()
|
| 180 |
+
|
| 181 |
+
return self.evo2_tokenizer.detokenize_batch(sequences)
|
| 182 |
+
|
| 183 |
+
def decode(
|
| 184 |
+
self,
|
| 185 |
+
token_ids: Union[int, List[int], torch.Tensor],
|
| 186 |
+
skip_special_tokens: bool = False,
|
| 187 |
+
**kwargs
|
| 188 |
+
) -> str:
|
| 189 |
+
"""
|
| 190 |
+
Decode a single sequence of token ids to a string.
|
| 191 |
+
"""
|
| 192 |
+
if isinstance(token_ids, torch.Tensor):
|
| 193 |
+
token_ids = token_ids.tolist()
|
| 194 |
+
|
| 195 |
+
# Single sequence
|
| 196 |
+
if not isinstance(token_ids, list) or not token_ids or not isinstance(token_ids[0], (list, torch.Tensor)):
|
| 197 |
+
return self.evo2_tokenizer.detokenize(token_ids)
|
| 198 |
+
|
| 199 |
+
# Batch with one item
|
| 200 |
+
return self.batch_decode(token_ids, skip_special_tokens, **kwargs)[0]
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Register the tokenizer - you'll need to do this when your script loads
|
| 204 |
+
# You might want to put this in your __init__.py file
|
| 205 |
+
def register_evo2_tokenizer():
|
| 206 |
+
"""Register the Evo2Tokenizer with HuggingFace's AutoTokenizer."""
|
| 207 |
+
|
| 208 |
+
# This will register the tokenizer so AutoTokenizer.from_pretrained knows about it
|
| 209 |
+
AutoTokenizer.register("evo2", Evo2Tokenizer)
|
| 210 |
+
|
| 211 |
+
# If you have a config class, you would also register that
|
| 212 |
+
# from transformers.models.auto import AutoConfig
|
| 213 |
+
# AutoConfig.register("evo2", Evo2Config)
|
| 214 |
+
|
| 215 |
+
print("Evo2Tokenizer registered with AutoTokenizer")
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
|
| 219 |
+
register_evo2_tokenizer()
|
BioReason/bioreason/models/pl/chat_template_pl.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
CHAT_TEMPLATE = "{%- set protein_count = namespace(value=0) %}{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content is string and message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' }} {%- if message.content is string %}{{- message.content + '<|im_end|>' + '\\n' }}{%- else %}{%- for content in message.content %}{%- if content.type == 'protein' or 'protein' in content %}{%- set protein_count.value = protein_count.value + 1 %}{%- if add_protein_id %}Protein Sequence {{- protein_count.value }}: {%- endif %}<|protein_start|><|protein_pad|><|protein_end|>{%- elif 'text' in content %}{{- content.text }}{%- endif %}{%- endfor %}{{- '<|im_end|>' + '\\n' }}{%- endif %}{%- elif message.role == \"assistant\" %}\n {%- set content = message.content[0].text %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in message.content %}\n {%- set content = message.content[0].text.split('</think>')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content[0].text.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}"
|
BioReason/bioreason/models/pl/configuration_pl.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class ProteinLLMESMConfig(PretrainedConfig):
|
| 4 |
+
model_type = "protein_llm"
|
| 5 |
+
base_config_key = "esm_config"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
# ESM2 related configurations
|
| 10 |
+
esm_hidden_size=1280,
|
| 11 |
+
esm_num_layers=33,
|
| 12 |
+
esm_num_attention_heads=20,
|
| 13 |
+
esm_vocab_size=33,
|
| 14 |
+
esm_max_position_embeddings=1026,
|
| 15 |
+
esm_layer_norm_eps=1e-5,
|
| 16 |
+
esm_hidden_dropout_prob=0.1,
|
| 17 |
+
esm_attention_probs_dropout_prob=0.1,
|
| 18 |
+
esm_intermediate_size=5120,
|
| 19 |
+
esm_hidden_act="gelu",
|
| 20 |
+
esm_initializer_range=0.02,
|
| 21 |
+
esm_layer_norm_eps=1e-5,
|
| 22 |
+
**kwargs,
|
| 23 |
+
):
|
| 24 |
+
super().__init__(**kwargs)
|
| 25 |
+
|
| 26 |
+
# ESM2 configurations
|
| 27 |
+
self.esm_hidden_size = esm_hidden_size
|
| 28 |
+
self.esm_num_layers = esm_num_layers
|
| 29 |
+
self.esm_num_attention_heads = esm_num_attention_heads
|
| 30 |
+
self.esm_vocab_size = esm_vocab_size
|
| 31 |
+
self.esm_max_position_embeddings = esm_max_position_embeddings
|
| 32 |
+
self.esm_layer_norm_eps = esm_layer_norm_eps
|
| 33 |
+
self.esm_hidden_dropout_prob = esm_hidden_dropout_prob
|
| 34 |
+
self.esm_attention_probs_dropout_prob = esm_attention_probs_dropout_prob
|
| 35 |
+
self.esm_intermediate_size = esm_intermediate_size
|
| 36 |
+
self.esm_hidden_act = esm_hidden_act
|
| 37 |
+
self.esm_initializer_range = esm_initializer_range
|
| 38 |
+
|
| 39 |
+
class ProteinLLMQFormerConfig(PretrainedConfig):
|
| 40 |
+
model_type = "protein_llm"
|
| 41 |
+
base_config_key = "qformer_config"
|
| 42 |
+
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
# Q-Former configurations
|
| 46 |
+
qformer_hidden_size=768,
|
| 47 |
+
qformer_num_hidden_layers=12,
|
| 48 |
+
qformer_num_attention_heads=12,
|
| 49 |
+
qformer_intermediate_size=3072,
|
| 50 |
+
qformer_hidden_act="gelu",
|
| 51 |
+
qformer_hidden_dropout_prob=0.1,
|
| 52 |
+
qformer_attention_probs_dropout_prob=0.1,
|
| 53 |
+
qformer_max_position_embeddings=512,
|
| 54 |
+
qformer_layer_norm_eps=1e-12,
|
| 55 |
+
qformer_initializer_range=0.02,
|
| 56 |
+
qformer_vocab_size=30522,
|
| 57 |
+
qformer_pad_token_id=0,
|
| 58 |
+
qformer_position_embedding_type="absolute",
|
| 59 |
+
qformer_use_cache=True,
|
| 60 |
+
# Query tokens
|
| 61 |
+
num_query_tokens=32,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
super().__init__(**kwargs)
|
| 65 |
+
|
| 66 |
+
# Q-Former configurations
|
| 67 |
+
self.qformer_hidden_size = qformer_hidden_size
|
| 68 |
+
self.qformer_num_hidden_layers = qformer_num_hidden_layers
|
| 69 |
+
self.qformer_num_attention_heads = qformer_num_attention_heads
|
| 70 |
+
self.qformer_intermediate_size = qformer_intermediate_size
|
| 71 |
+
self.qformer_hidden_act = qformer_hidden_act
|
| 72 |
+
self.qformer_hidden_dropout_prob = qformer_hidden_dropout_prob
|
| 73 |
+
self.qformer_attention_probs_dropout_prob = qformer_attention_probs_dropout_prob
|
| 74 |
+
self.qformer_max_position_embeddings = qformer_max_position_embeddings
|
| 75 |
+
self.qformer_layer_norm_eps = qformer_layer_norm_eps
|
| 76 |
+
self.qformer_initializer_range = qformer_initializer_range
|
| 77 |
+
self.qformer_vocab_size = qformer_vocab_size
|
| 78 |
+
self.qformer_pad_token_id = qformer_pad_token_id
|
| 79 |
+
self.qformer_position_embedding_type = qformer_position_embedding_type
|
| 80 |
+
self.qformer_use_cache = qformer_use_cache
|
| 81 |
+
self.num_query_tokens = num_query_tokens
|
| 82 |
+
|
| 83 |
+
class ProteinLLMConfig(PretrainedConfig):
|
| 84 |
+
r"""
|
| 85 |
+
This is the configuration class to store the configuration of a [`ProteinLLMModel`]. It is used to instantiate a
|
| 86 |
+
Protein-LLM model according to the specified arguments, defining the model architecture. The model combines
|
| 87 |
+
ESM2 protein encoder, Q-Former, and a language model for protein understanding and generation.
|
| 88 |
+
|
| 89 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 90 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
vocab_size (`int`, *optional*, defaults to 152064):
|
| 94 |
+
Vocabulary size of the language model. Defines the number of different tokens that can be represented by the
|
| 95 |
+
`inputs_ids` passed when calling the model.
|
| 96 |
+
hidden_size (`int`, *optional*, defaults to 8192):
|
| 97 |
+
Dimension of the hidden representations in the language model.
|
| 98 |
+
intermediate_size (`int`, *optional*, defaults to 29568):
|
| 99 |
+
Dimension of the MLP representations in the language model.
|
| 100 |
+
num_hidden_layers (`int`, *optional*, defaults to 80):
|
| 101 |
+
Number of hidden layers in the Transformer encoder of the language model.
|
| 102 |
+
num_attention_heads (`int`, *optional*, defaults to 64):
|
| 103 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 104 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 105 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention.
|
| 106 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 107 |
+
The non-linear activation function (function or string) in the decoder.
|
| 108 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 109 |
+
The maximum sequence length that this model might ever be used with.
|
| 110 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 111 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 112 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 113 |
+
The epsilon used by the rms normalization layers.
|
| 114 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 115 |
+
Whether or not the model should return the last key/values attentions.
|
| 116 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 117 |
+
Whether the model's input and output word embeddings should be tied.
|
| 118 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 119 |
+
The base period of the RoPE embeddings.
|
| 120 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 121 |
+
Whether to use sliding window attention.
|
| 122 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 123 |
+
Sliding window attention (SWA) window size.
|
| 124 |
+
max_window_layers (`int`, *optional*, defaults to 80):
|
| 125 |
+
The number of layers that use SWA.
|
| 126 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 127 |
+
The dropout ratio for the attention probabilities.
|
| 128 |
+
esm_config (`Dict`, *optional*):
|
| 129 |
+
The config for the ESM2 protein encoder initialization.
|
| 130 |
+
qformer_config (`Dict`, *optional*):
|
| 131 |
+
The config for the Q-Former initialization.
|
| 132 |
+
rope_scaling (`Dict`, *optional*):
|
| 133 |
+
Dictionary containing the scaling configuration for the RoPE embeddings.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
model_type = "protein_llm"
|
| 137 |
+
sub_configs = {
|
| 138 |
+
"esm_config": ProteinLLMESMConfig,
|
| 139 |
+
"qformer_config": ProteinLLMQFormerConfig
|
| 140 |
+
}
|
| 141 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 142 |
+
|
| 143 |
+
# Default tensor parallel plan for base model
|
| 144 |
+
base_model_tp_plan = {
|
| 145 |
+
"layers.*.self_attn.q_proj": "colwise",
|
| 146 |
+
"layers.*.self_attn.k_proj": "colwise",
|
| 147 |
+
"layers.*.self_attn.v_proj": "colwise",
|
| 148 |
+
"layers.*.self_attn.o_proj": "rowwise",
|
| 149 |
+
"layers.*.mlp.gate_proj": "colwise",
|
| 150 |
+
"layers.*.mlp.up_proj": "colwise",
|
| 151 |
+
"layers.*.mlp.down_proj": "rowwise",
|
| 152 |
+
}
|
| 153 |
+
base_model_pp_plan = {
|
| 154 |
+
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
| 155 |
+
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
| 156 |
+
"norm": (["hidden_states"], ["hidden_states"]),
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
def __init__(
|
| 160 |
+
self,
|
| 161 |
+
vocab_size=152064,
|
| 162 |
+
hidden_size=8192,
|
| 163 |
+
intermediate_size=29568,
|
| 164 |
+
num_hidden_layers=80,
|
| 165 |
+
num_attention_heads=64,
|
| 166 |
+
num_key_value_heads=8,
|
| 167 |
+
hidden_act="silu",
|
| 168 |
+
max_position_embeddings=32768,
|
| 169 |
+
initializer_range=0.02,
|
| 170 |
+
rms_norm_eps=1e-05,
|
| 171 |
+
use_cache=True,
|
| 172 |
+
tie_word_embeddings=False,
|
| 173 |
+
rope_theta=1000000.0,
|
| 174 |
+
use_sliding_window=False,
|
| 175 |
+
sliding_window=4096,
|
| 176 |
+
max_window_layers=80,
|
| 177 |
+
attention_dropout=0.0,
|
| 178 |
+
esm_config=None,
|
| 179 |
+
qformer_config=None,
|
| 180 |
+
rope_scaling=None,
|
| 181 |
+
protein_token_id=None,
|
| 182 |
+
**kwargs,
|
| 183 |
+
):
|
| 184 |
+
# Initialize ESM config
|
| 185 |
+
if isinstance(esm_config, dict):
|
| 186 |
+
self.esm_config = self.sub_configs["esm_config"](**esm_config)
|
| 187 |
+
elif esm_config is None:
|
| 188 |
+
self.esm_config = self.sub_configs["esm_config"]()
|
| 189 |
+
else:
|
| 190 |
+
self.esm_config = esm_config
|
| 191 |
+
|
| 192 |
+
# Initialize Q-Former config
|
| 193 |
+
if isinstance(qformer_config, dict):
|
| 194 |
+
self.qformer_config = self.sub_configs["qformer_config"](**qformer_config)
|
| 195 |
+
elif qformer_config is None:
|
| 196 |
+
self.qformer_config = self.sub_configs["qformer_config"]()
|
| 197 |
+
else:
|
| 198 |
+
self.qformer_config = qformer_config
|
| 199 |
+
|
| 200 |
+
# Language model configurations
|
| 201 |
+
self.vocab_size = vocab_size
|
| 202 |
+
self.max_position_embeddings = max_position_embeddings
|
| 203 |
+
self.hidden_size = hidden_size
|
| 204 |
+
self.intermediate_size = intermediate_size
|
| 205 |
+
self.num_hidden_layers = num_hidden_layers
|
| 206 |
+
self.num_attention_heads = num_attention_heads
|
| 207 |
+
self.use_sliding_window = use_sliding_window
|
| 208 |
+
self.sliding_window = sliding_window
|
| 209 |
+
self.max_window_layers = max_window_layers
|
| 210 |
+
|
| 211 |
+
# for backward compatibility
|
| 212 |
+
if num_key_value_heads is None:
|
| 213 |
+
num_key_value_heads = num_attention_heads
|
| 214 |
+
|
| 215 |
+
self.num_key_value_heads = num_key_value_heads
|
| 216 |
+
self.hidden_act = hidden_act
|
| 217 |
+
self.initializer_range = initializer_range
|
| 218 |
+
self.rms_norm_eps = rms_norm_eps
|
| 219 |
+
self.use_cache = use_cache
|
| 220 |
+
self.rope_theta = rope_theta
|
| 221 |
+
self.attention_dropout = attention_dropout
|
| 222 |
+
self.rope_scaling = rope_scaling
|
| 223 |
+
|
| 224 |
+
self.protein_token_id = protein_token_id
|
| 225 |
+
|
| 226 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 227 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 228 |
+
if self.rope_scaling["type"] == "mrope":
|
| 229 |
+
self.rope_scaling["type"] = "default"
|
| 230 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 231 |
+
|
| 232 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 233 |
+
|
| 234 |
+
__all__ = ["ProteinLLMConfig"]
|
BioReason/bioreason/models/pl/processing_pl.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Union, Dict, Any, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
from transformers import AutoTokenizer
|
| 8 |
+
from transformers.processing_utils import (
|
| 9 |
+
CommonKwargs,
|
| 10 |
+
ProcessingKwargs,
|
| 11 |
+
ProcessorMixin,
|
| 12 |
+
Unpack,
|
| 13 |
+
)
|
| 14 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 15 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 16 |
+
from transformers.utils import logging
|
| 17 |
+
|
| 18 |
+
from bioreason.utils.protein_utils import ProteinInput
|
| 19 |
+
|
| 20 |
+
class ProteinLLMProteinKwargs(CommonKwargs):
|
| 21 |
+
"""Keyword arguments specific to protein sequence processing"""
|
| 22 |
+
max_length_text: Optional[int]
|
| 23 |
+
max_length_protein: Optional[int]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ProteinLLMProcessorKwargs(ProcessingKwargs, total=False):
|
| 27 |
+
"""Processing keyword arguments for the ProteinLLM processor"""
|
| 28 |
+
protein_kwargs: ProteinLLMProteinKwargs
|
| 29 |
+
_defaults = {
|
| 30 |
+
"text_kwargs": {
|
| 31 |
+
"padding": False,
|
| 32 |
+
},
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
class ProteinLLMProcessor(ProcessorMixin):
|
| 36 |
+
r"""
|
| 37 |
+
Constructs a ProteinLLM processor which wraps an ESM protein tokenizer and a language model tokenizer into a single processor.
|
| 38 |
+
This processor handles both text and protein sequence processing to prepare inputs for the ProteinLLMModel.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
tokenizer (PreTrainedTokenizerBase, *optional*):
|
| 42 |
+
The text tokenizer used for processing text inputs.
|
| 43 |
+
protein_tokenizer (PreTrainedTokenizerBase, *optional*):
|
| 44 |
+
The protein tokenizer (ESM) used for processing protein sequences.
|
| 45 |
+
chat_template (`str`, *optional*):
|
| 46 |
+
A Jinja template for chat formatting. If None, will use the tokenizer's template.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
attributes = ["tokenizer", "protein_tokenizer"]
|
| 50 |
+
valid_kwargs = ["model", "chat_template"]
|
| 51 |
+
tokenizer_class = (
|
| 52 |
+
"Qwen2Tokenizer", "Qwen2TokenizerFast",
|
| 53 |
+
"GPT2TokenizerFast", "LlamaTokenizer", "LlamaTokenizerFast",
|
| 54 |
+
)
|
| 55 |
+
protein_tokenizer_class = ("EsmTokenizer",)
|
| 56 |
+
|
| 57 |
+
def __init__(
|
| 58 |
+
self, tokenizer=None, protein_tokenizer=None, chat_template=None, **kwargs
|
| 59 |
+
):
|
| 60 |
+
"""
|
| 61 |
+
Initialize the processor with text and protein tokenizers.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
tokenizer: Text tokenizer (usually from a language model)
|
| 65 |
+
protein_tokenizer: Protein tokenizer (usually ESM tokenizer)
|
| 66 |
+
chat_template: Template for formatting chat conversations
|
| 67 |
+
**kwargs: Additional arguments
|
| 68 |
+
"""
|
| 69 |
+
self.tokenizer = tokenizer
|
| 70 |
+
self.protein_tokenizer = protein_tokenizer
|
| 71 |
+
|
| 72 |
+
self.protein_token = (
|
| 73 |
+
"<|protein_pad|>"
|
| 74 |
+
if not hasattr(self.tokenizer, "protein_token")
|
| 75 |
+
else self.tokenizer.protein_token
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# Get chat template from tokenizer if not provided
|
| 79 |
+
if chat_template is None and hasattr(self.tokenizer, "chat_template"):
|
| 80 |
+
chat_template = self.tokenizer.chat_template
|
| 81 |
+
super().__init__(tokenizer, protein_tokenizer, chat_template=chat_template)
|
| 82 |
+
|
| 83 |
+
# The GRPO trainer might expect this to be set
|
| 84 |
+
if not hasattr(self.tokenizer, 'pad_token') or self.tokenizer.pad_token is None:
|
| 85 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 86 |
+
|
| 87 |
+
def tokenize_protein_sequences(
|
| 88 |
+
self,
|
| 89 |
+
batch_protein_sequences: List[List[str]],
|
| 90 |
+
max_length: int = 1024,
|
| 91 |
+
return_tensors: str = "pt",
|
| 92 |
+
device: str = "cuda",
|
| 93 |
+
) -> Dict[str, Any]:
|
| 94 |
+
"""
|
| 95 |
+
Tokenize a batch of protein sequences.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
batch_protein_sequences: List of lists of protein sequences per batch item
|
| 99 |
+
max_length: Maximum allowed length for protein sequences
|
| 100 |
+
return_tensors: Return format for tensors ("pt" for PyTorch)
|
| 101 |
+
device: Device to place tensors on
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
Dict containing:
|
| 105 |
+
- protein_tokenized: The tokenized protein sequences
|
| 106 |
+
- batch_idx_map: Mapping of which sequences belong to which batch item
|
| 107 |
+
"""
|
| 108 |
+
# Create a mapping to track which sequences belong to which batch item
|
| 109 |
+
batch_idx_map = []
|
| 110 |
+
all_sequences = []
|
| 111 |
+
|
| 112 |
+
# Flatten all sequences with batch tracking
|
| 113 |
+
for batch_idx, protein_sequences in enumerate(batch_protein_sequences):
|
| 114 |
+
for seq in protein_sequences:
|
| 115 |
+
all_sequences.append(seq)
|
| 116 |
+
batch_idx_map.append(batch_idx)
|
| 117 |
+
|
| 118 |
+
# If no sequences in the entire batch, return empty dict
|
| 119 |
+
if not all_sequences:
|
| 120 |
+
return {"protein_tokenized": None, "batch_idx_map": []}
|
| 121 |
+
|
| 122 |
+
# Tokenize all sequences at once
|
| 123 |
+
protein_tokenized = self.protein_tokenizer(
|
| 124 |
+
all_sequences,
|
| 125 |
+
padding=True,
|
| 126 |
+
truncation=True,
|
| 127 |
+
max_length=max_length,
|
| 128 |
+
return_tensors=return_tensors,
|
| 129 |
+
return_attention_mask=True,
|
| 130 |
+
add_special_tokens=True,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return {"protein_tokenized": protein_tokenized, "batch_idx_map": batch_idx_map}
|
| 134 |
+
|
| 135 |
+
def __call__(
|
| 136 |
+
self,
|
| 137 |
+
batch_protein_sequences: Optional[List[List[str]]] = None,
|
| 138 |
+
text: Optional[
|
| 139 |
+
Union[
|
| 140 |
+
TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
|
| 141 |
+
]
|
| 142 |
+
] = None,
|
| 143 |
+
max_length_text: int = 512,
|
| 144 |
+
max_length_protein: int = 1024,
|
| 145 |
+
return_tensors: str = "pt",
|
| 146 |
+
device: str = "cuda",
|
| 147 |
+
**kwargs: Unpack[ProteinLLMProcessorKwargs],
|
| 148 |
+
) -> BatchFeature:
|
| 149 |
+
"""
|
| 150 |
+
Process text and protein sequences for model input.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
batch_protein_sequences: List of lists of protein sequences per batch item
|
| 154 |
+
text: Input text or list of texts
|
| 155 |
+
max_length_text: Maximum length for text sequences
|
| 156 |
+
max_length_protein: Maximum length for protein sequences
|
| 157 |
+
return_tensors: Return format for tensors
|
| 158 |
+
device: Device to place tensors on
|
| 159 |
+
**kwargs: Additional processor keyword arguments
|
| 160 |
+
|
| 161 |
+
Returns:
|
| 162 |
+
BatchFeature with tokenized inputs for the model
|
| 163 |
+
"""
|
| 164 |
+
output_kwargs = self._merge_kwargs(
|
| 165 |
+
ProteinLLMProcessorKwargs,
|
| 166 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 167 |
+
**kwargs,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
# Ensure text is a list
|
| 171 |
+
if not isinstance(text, list):
|
| 172 |
+
text = [text]
|
| 173 |
+
|
| 174 |
+
protein_inputs = {}
|
| 175 |
+
if batch_protein_sequences is not None:
|
| 176 |
+
# Tokenize protein sequences
|
| 177 |
+
protein_processing_result = self.tokenize_protein_sequences(
|
| 178 |
+
batch_protein_sequences,
|
| 179 |
+
max_length=max_length_protein,
|
| 180 |
+
return_tensors=return_tensors,
|
| 181 |
+
device=device,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Replace protein tokens in text if needed
|
| 185 |
+
index = 0
|
| 186 |
+
for i in range(len(text)):
|
| 187 |
+
while self.protein_token in text[i]:
|
| 188 |
+
# For ESM tokenizer, calculate actual tokens excluding special tokens
|
| 189 |
+
protein_token_ids = protein_processing_result['protein_tokenized']['input_ids'][index]
|
| 190 |
+
# Exclude BOS and EOS tokens (typically first and last tokens in ESM)
|
| 191 |
+
num_protein_tokens = (protein_token_ids != self.protein_tokenizer.pad_token_id).sum().item() - 2
|
| 192 |
+
num_protein_tokens = max(1, num_protein_tokens) # Ensure at least 1 token
|
| 193 |
+
|
| 194 |
+
text[i] = text[i].replace(
|
| 195 |
+
self.protein_token, "<|placeholder|>" * num_protein_tokens, 1
|
| 196 |
+
)
|
| 197 |
+
index += 1
|
| 198 |
+
text[i] = text[i].replace("<|placeholder|>", self.protein_token)
|
| 199 |
+
|
| 200 |
+
# Add batch info to the output
|
| 201 |
+
protein_inputs = {
|
| 202 |
+
"protein_tokenized": protein_processing_result["protein_tokenized"],
|
| 203 |
+
"batch_idx_map": protein_processing_result["batch_idx_map"],
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
# Tokenize text
|
| 207 |
+
text_kwargs = output_kwargs.get("text_kwargs", {})
|
| 208 |
+
|
| 209 |
+
if 'padding' in text_kwargs:
|
| 210 |
+
del text_kwargs['padding']
|
| 211 |
+
|
| 212 |
+
text_inputs = self.tokenizer(
|
| 213 |
+
text,
|
| 214 |
+
max_length=max_length_text + 2 * max_length_protein,
|
| 215 |
+
return_tensors=return_tensors,
|
| 216 |
+
padding=True,
|
| 217 |
+
truncation=True,
|
| 218 |
+
**text_kwargs,
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# The BatchFeature should have all required fields for the model's forward pass
|
| 222 |
+
return BatchFeature(data={**text_inputs, **protein_inputs})
|
| 223 |
+
|
| 224 |
+
def batch_decode(self, *args, **kwargs) -> List[str]:
|
| 225 |
+
"""
|
| 226 |
+
This method forwards all its arguments to the tokenizer's batch_decode.
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
List of decoded strings
|
| 230 |
+
"""
|
| 231 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 232 |
+
|
| 233 |
+
def decode(self, *args, **kwargs) -> str:
|
| 234 |
+
"""
|
| 235 |
+
This method forwards all its arguments to the tokenizer's decode.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
Decoded string
|
| 239 |
+
"""
|
| 240 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 241 |
+
|
| 242 |
+
def post_process_protein_to_text(
|
| 243 |
+
self,
|
| 244 |
+
generated_outputs: torch.Tensor,
|
| 245 |
+
skip_special_tokens: bool = True,
|
| 246 |
+
**kwargs,
|
| 247 |
+
) -> List[str]:
|
| 248 |
+
"""
|
| 249 |
+
Post-process the model output to decode the text.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
generated_outputs: The token IDs generated by the model
|
| 253 |
+
skip_special_tokens: Whether to skip special tokens in the output
|
| 254 |
+
**kwargs: Additional arguments for the decoder
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
List of decoded strings
|
| 258 |
+
"""
|
| 259 |
+
return self.tokenizer.batch_decode(
|
| 260 |
+
generated_outputs,
|
| 261 |
+
skip_special_tokens=skip_special_tokens,
|
| 262 |
+
**kwargs,
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
@property
|
| 266 |
+
def model_input_names(self) -> List[str]:
|
| 267 |
+
"""
|
| 268 |
+
Get the input names expected by the model.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
List of input names
|
| 272 |
+
"""
|
| 273 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 274 |
+
protein_input_names = ["protein_tokenized", "batch_idx_map"]
|
| 275 |
+
|
| 276 |
+
return list(dict.fromkeys(tokenizer_input_names + protein_input_names))
|
BioReason/bioreason/models/protein_llm.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from transformers import (
|
| 6 |
+
AutoTokenizer,
|
| 7 |
+
AutoModelForCausalLM,
|
| 8 |
+
AutoModelForMaskedLM,
|
| 9 |
+
AutoModel, # 新增:用于加载BiomedBERT
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from typing import Optional, List, Dict, Any, Union, Tuple
|
| 14 |
+
|
| 15 |
+
from bioreason.utils.protein_utils import ProteinInput
|
| 16 |
+
from bioreason.models.pl.processing_pl import ProteinLLMProcessor
|
| 17 |
+
from bioreason.models.pl.chat_template_pl import CHAT_TEMPLATE
|
| 18 |
+
|
| 19 |
+
class BiomedBERTQFormer(nn.Module):
|
| 20 |
+
"""
|
| 21 |
+
使用预训练的BiomedBERT作为Q-Former的模型
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self, config, biomedbert_model_name="microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.config = config
|
| 27 |
+
|
| 28 |
+
# 直接加载预训练的BiomedBERT模型
|
| 29 |
+
self.biomedbert = AutoModel.from_pretrained(
|
| 30 |
+
biomedbert_model_name,
|
| 31 |
+
trust_remote_code=True,
|
| 32 |
+
add_cross_attention=True, # 启用交叉注意力机制
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# 获取BiomedBERT的配置
|
| 36 |
+
self.biomedbert_config = self.biomedbert.config
|
| 37 |
+
self.biomedbert_hidden_size = self.biomedbert_config.hidden_size
|
| 38 |
+
|
| 39 |
+
# Query tokens - 可学习的参数
|
| 40 |
+
self.query_tokens = nn.Parameter(
|
| 41 |
+
torch.zeros(1, config.qformer_config.num_query_tokens, self.biomedbert_hidden_size)
|
| 42 |
+
)
|
| 43 |
+
self.query_tokens.data.normal_(mean=0.0, std=0.02)
|
| 44 |
+
|
| 45 |
+
# 如果蛋白质编码器的维度与BiomedBERT不匹配,需要投影层
|
| 46 |
+
if config.esm_config.esm_hidden_size != self.biomedbert_hidden_size:
|
| 47 |
+
self.protein_projection = nn.Linear(
|
| 48 |
+
config.esm_config.esm_hidden_size,
|
| 49 |
+
self.biomedbert_hidden_size
|
| 50 |
+
)
|
| 51 |
+
else:
|
| 52 |
+
self.protein_projection = None
|
| 53 |
+
|
| 54 |
+
def forward(
|
| 55 |
+
self,
|
| 56 |
+
protein_embeds: torch.Tensor,
|
| 57 |
+
protein_attention_mask: torch.Tensor,
|
| 58 |
+
output_attentions: Optional[bool] = None,
|
| 59 |
+
output_hidden_states: Optional[bool] = None,
|
| 60 |
+
return_dict: Optional[bool] = None,
|
| 61 |
+
):
|
| 62 |
+
"""
|
| 63 |
+
Args:
|
| 64 |
+
protein_embeds: 来自ESM2的蛋白质嵌入 [batch_size, seq_len, esm_hidden_size]
|
| 65 |
+
protein_attention_mask: 蛋白质序列的注意力掩码 [batch_size, seq_len]
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
可用作语言模型输入的查询嵌入
|
| 69 |
+
"""
|
| 70 |
+
batch_size = protein_embeds.shape[0]
|
| 71 |
+
|
| 72 |
+
# 如果需要,将蛋白质嵌入投影到BiomedBERT维度
|
| 73 |
+
if self.protein_projection is not None:
|
| 74 |
+
protein_embeds = self.protein_projection(protein_embeds)
|
| 75 |
+
|
| 76 |
+
# 为批次扩展查询token
|
| 77 |
+
query_tokens = self.query_tokens.expand(batch_size, -1, -1)
|
| 78 |
+
|
| 79 |
+
# 为查询token创建注意力掩码
|
| 80 |
+
query_attention_mask = torch.ones(
|
| 81 |
+
query_tokens.shape[:2],
|
| 82 |
+
dtype=torch.long,
|
| 83 |
+
device=query_tokens.device
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# BiomedBERT前向传播,使用交叉注意力连接蛋白质嵌入
|
| 87 |
+
query_outputs = self.biomedbert(
|
| 88 |
+
inputs_embeds=query_tokens, # 使用嵌入而不是input_ids
|
| 89 |
+
attention_mask=query_attention_mask,
|
| 90 |
+
encoder_hidden_states=protein_embeds, # 蛋白质嵌入作为编码器状态
|
| 91 |
+
encoder_attention_mask=protein_attention_mask,
|
| 92 |
+
output_attentions=output_attentions,
|
| 93 |
+
output_hidden_states=output_hidden_states,
|
| 94 |
+
return_dict=return_dict,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
return query_outputs.last_hidden_state
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ProteinLLMModel(nn.Module):
|
| 101 |
+
"""
|
| 102 |
+
使用BiomedBERT作为Q-Former的蛋白质-语言模型
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
text_model_name: str,
|
| 108 |
+
protein_model_name: str,
|
| 109 |
+
biomedbert_model_name: str = "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext",
|
| 110 |
+
cache_dir: Optional[str] = None,
|
| 111 |
+
max_length_protein: int = 1024,
|
| 112 |
+
max_length_text: int = 512,
|
| 113 |
+
text_model_finetune: bool = True,
|
| 114 |
+
protein_model_finetune: bool = True,
|
| 115 |
+
biomedbert_finetune: bool = True, # 新增:是否微调BiomedBERT
|
| 116 |
+
qformer_num_query_tokens: int = 32,
|
| 117 |
+
):
|
| 118 |
+
"""
|
| 119 |
+
初始化ProteinLLMModel,使用BiomedBERT作为Q-Former
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
text_model_name: 文本模型名称
|
| 123 |
+
protein_model_name: ESM2蛋白质模型名称
|
| 124 |
+
biomedbert_model_name: BiomedBERT模型名称
|
| 125 |
+
cache_dir: 模型缓存目录
|
| 126 |
+
max_length_protein: 蛋白质序列最大长度
|
| 127 |
+
max_length_text: 文本序列最大长度
|
| 128 |
+
text_model_finetune: 是否微调文本模型
|
| 129 |
+
protein_model_finetune: 是否微调蛋白质模型
|
| 130 |
+
biomedbert_finetune: 是否微调BiomedBERT
|
| 131 |
+
qformer_num_query_tokens: Q-Former查询token数量
|
| 132 |
+
"""
|
| 133 |
+
super().__init__()
|
| 134 |
+
|
| 135 |
+
self.text_model_finetune = text_model_finetune
|
| 136 |
+
self.protein_model_finetune = protein_model_finetune
|
| 137 |
+
self.biomedbert_finetune = biomedbert_finetune
|
| 138 |
+
self.max_length_protein = max_length_protein
|
| 139 |
+
self.max_length_text = max_length_text
|
| 140 |
+
self.qformer_num_query_tokens = qformer_num_query_tokens
|
| 141 |
+
|
| 142 |
+
# 加载文本模型和分词器
|
| 143 |
+
self.text_model = AutoModelForCausalLM.from_pretrained(
|
| 144 |
+
text_model_name, cache_dir=cache_dir, trust_remote_code=True
|
| 145 |
+
)
|
| 146 |
+
self.text_tokenizer = AutoTokenizer.from_pretrained(text_model_name, trust_remote_code=True)
|
| 147 |
+
self.text_config = self.text_model.config
|
| 148 |
+
self.text_tokenizer.chat_template = CHAT_TEMPLATE
|
| 149 |
+
self.text_tokenizer.pad_token = self.text_tokenizer.eos_token
|
| 150 |
+
|
| 151 |
+
# 添加蛋白质特殊token
|
| 152 |
+
new_tokens = ["<|protein_start|>", "<|protein_pad|>", "<|protein_end|>"]
|
| 153 |
+
self.text_tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})
|
| 154 |
+
self.protein_token_id = self.text_tokenizer.convert_tokens_to_ids("<|protein_pad|>")
|
| 155 |
+
|
| 156 |
+
# 加载ESM2蛋白质模型和分词器
|
| 157 |
+
self.protein_model = AutoModelForMaskedLM.from_pretrained(
|
| 158 |
+
protein_model_name, cache_dir=cache_dir, trust_remote_code=True
|
| 159 |
+
)
|
| 160 |
+
self.protein_tokenizer = AutoTokenizer.from_pretrained(protein_model_name, trust_remote_code=True)
|
| 161 |
+
self.protein_config = self.protein_model.config
|
| 162 |
+
|
| 163 |
+
# 获取模型维度
|
| 164 |
+
self.text_hidden_size = self.text_config.hidden_size
|
| 165 |
+
self.protein_hidden_size = self.protein_config.hidden_size
|
| 166 |
+
|
| 167 |
+
# 创建Q-Former配置(简化版本,因为我们直接使用BiomedBERT)
|
| 168 |
+
from types import SimpleNamespace
|
| 169 |
+
qformer_config = SimpleNamespace(
|
| 170 |
+
qformer_config=SimpleNamespace(
|
| 171 |
+
num_query_tokens=qformer_num_query_tokens,
|
| 172 |
+
),
|
| 173 |
+
esm_config=SimpleNamespace(
|
| 174 |
+
esm_hidden_size=self.protein_hidden_size,
|
| 175 |
+
)
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# 初始化BiomedBERT Q-Former
|
| 179 |
+
self.qformer = BiomedBERTQFormer(qformer_config, biomedbert_model_name)
|
| 180 |
+
|
| 181 |
+
# 获取BiomedBERT的隐藏层大小
|
| 182 |
+
biomedbert_hidden_size = self.qformer.biomedbert_hidden_size
|
| 183 |
+
|
| 184 |
+
# 创建投影层,将BiomedBERT输出映射到文本模型嵌入空间
|
| 185 |
+
self.protein_projection = nn.Linear(biomedbert_hidden_size, self.text_hidden_size)
|
| 186 |
+
|
| 187 |
+
# 创建处理器
|
| 188 |
+
self.processor = ProteinLLMProcessor(
|
| 189 |
+
tokenizer=self.text_tokenizer,
|
| 190 |
+
protein_tokenizer=self.protein_tokenizer
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# 根据微调设置冻结模型参数
|
| 194 |
+
if not self.protein_model_finetune:
|
| 195 |
+
for param in self.protein_model.parameters():
|
| 196 |
+
param.requires_grad = False
|
| 197 |
+
|
| 198 |
+
if not self.text_model_finetune:
|
| 199 |
+
for param in self.text_model.parameters():
|
| 200 |
+
param.requires_grad = False
|
| 201 |
+
|
| 202 |
+
if not self.biomedbert_finetune:
|
| 203 |
+
for param in self.qformer.biomedbert.parameters():
|
| 204 |
+
param.requires_grad = False
|
| 205 |
+
|
| 206 |
+
def process_protein_embeddings(
|
| 207 |
+
self,
|
| 208 |
+
protein_tokenized: Dict[str, torch.Tensor],
|
| 209 |
+
batch_idx_map: List[int],
|
| 210 |
+
batch_size: int,
|
| 211 |
+
) -> List[torch.Tensor]:
|
| 212 |
+
"""
|
| 213 |
+
通过ESM2和BiomedBERT Q-Former处理蛋白质序列获得嵌入
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
protein_tokenized: 分词后的蛋白质序列
|
| 217 |
+
batch_idx_map: 每个序列到批次项的映射
|
| 218 |
+
batch_size: 批次中的项目数量
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
每个批次项的张量嵌入列表
|
| 222 |
+
"""
|
| 223 |
+
# 通过ESM2处理所有序列获得表示
|
| 224 |
+
with torch.no_grad() if not self.protein_model_finetune else torch.enable_grad():
|
| 225 |
+
esm_outputs = self.protein_model(
|
| 226 |
+
input_ids=protein_tokenized["input_ids"],
|
| 227 |
+
attention_mask=protein_tokenized["attention_mask"],
|
| 228 |
+
output_hidden_states=True,
|
| 229 |
+
)
|
| 230 |
+
protein_hidden_states = esm_outputs.hidden_states[-1]
|
| 231 |
+
|
| 232 |
+
# 通过BiomedBERT Q-Former处理获得查询嵌入
|
| 233 |
+
result = []
|
| 234 |
+
|
| 235 |
+
for batch_idx in range(batch_size):
|
| 236 |
+
# 找到属于此批次项的序列
|
| 237 |
+
seq_indices = [i for i, idx in enumerate(batch_idx_map) if idx == batch_idx]
|
| 238 |
+
|
| 239 |
+
if not seq_indices:
|
| 240 |
+
# 此批次项没有序列,创建虚拟嵌入
|
| 241 |
+
dummy_embeddings = torch.zeros(
|
| 242 |
+
(self.qformer_num_query_tokens, self.text_hidden_size),
|
| 243 |
+
device=protein_hidden_states.device,
|
| 244 |
+
dtype=protein_hidden_states.dtype
|
| 245 |
+
)
|
| 246 |
+
result.append(dummy_embeddings)
|
| 247 |
+
continue
|
| 248 |
+
|
| 249 |
+
# 获取此批次项的蛋白质嵌入
|
| 250 |
+
batch_protein_embeds = protein_hidden_states[seq_indices]
|
| 251 |
+
batch_attention_mask = protein_tokenized["attention_mask"][seq_indices]
|
| 252 |
+
|
| 253 |
+
# 如果有多个序列,沿序列维度连接
|
| 254 |
+
if len(seq_indices) > 1:
|
| 255 |
+
concat_embeds = batch_protein_embeds.view(1, -1, self.protein_hidden_size)
|
| 256 |
+
concat_mask = batch_attention_mask.view(1, -1)
|
| 257 |
+
else:
|
| 258 |
+
concat_embeds = batch_protein_embeds
|
| 259 |
+
concat_mask = batch_attention_mask
|
| 260 |
+
|
| 261 |
+
# 通过BiomedBERT Q-Former
|
| 262 |
+
with torch.no_grad() if not self.biomedbert_finetune else torch.enable_grad():
|
| 263 |
+
query_embeddings = self.qformer(
|
| 264 |
+
protein_embeds=concat_embeds,
|
| 265 |
+
protein_attention_mask=concat_mask,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# 投影到文本模型空间
|
| 269 |
+
projected_embeddings = self.protein_projection(query_embeddings.squeeze(0))
|
| 270 |
+
result.append(projected_embeddings)
|
| 271 |
+
|
| 272 |
+
return result
|
| 273 |
+
|
| 274 |
+
def forward(
|
| 275 |
+
self,
|
| 276 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 277 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 278 |
+
protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
|
| 279 |
+
batch_idx_map: Optional[List[int]] = None,
|
| 280 |
+
labels: Optional[torch.Tensor] = None,
|
| 281 |
+
**kwargs,
|
| 282 |
+
) -> torch.Tensor:
|
| 283 |
+
"""前向传播"""
|
| 284 |
+
if input_ids is None or attention_mask is None:
|
| 285 |
+
raise ValueError("input_ids and attention_mask must be provided")
|
| 286 |
+
|
| 287 |
+
batch_size = input_ids.shape[0]
|
| 288 |
+
text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
|
| 289 |
+
|
| 290 |
+
if protein_tokenized is not None and batch_idx_map:
|
| 291 |
+
batch_protein_embeds = self.process_protein_embeddings(protein_tokenized, batch_idx_map, batch_size)
|
| 292 |
+
|
| 293 |
+
# 用BiomedBERT Q-Former输出替换蛋白质token
|
| 294 |
+
mask = input_ids == self.protein_token_id
|
| 295 |
+
n_protein_tokens = mask.sum().item()
|
| 296 |
+
expected_protein_features = batch_size * self.qformer_num_query_tokens
|
| 297 |
+
|
| 298 |
+
if n_protein_tokens != expected_protein_features:
|
| 299 |
+
protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
|
| 300 |
+
n_protein_features = protein_embeds_flat.shape[0]
|
| 301 |
+
|
| 302 |
+
if n_protein_features > n_protein_tokens:
|
| 303 |
+
protein_embeds_flat = protein_embeds_flat[:n_protein_tokens]
|
| 304 |
+
elif n_protein_features < n_protein_tokens:
|
| 305 |
+
padding = torch.zeros(
|
| 306 |
+
(n_protein_tokens - n_protein_features, self.text_hidden_size),
|
| 307 |
+
device=protein_embeds_flat.device,
|
| 308 |
+
dtype=protein_embeds_flat.dtype
|
| 309 |
+
)
|
| 310 |
+
protein_embeds_flat = torch.cat([protein_embeds_flat, padding], dim=0)
|
| 311 |
+
else:
|
| 312 |
+
protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
|
| 313 |
+
|
| 314 |
+
protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
|
| 315 |
+
text_inputs_embeds[mask] = protein_embeds_flat
|
| 316 |
+
|
| 317 |
+
outputs = self.text_model(
|
| 318 |
+
inputs_embeds=text_inputs_embeds,
|
| 319 |
+
attention_mask=attention_mask,
|
| 320 |
+
labels=labels,
|
| 321 |
+
**kwargs,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
return outputs
|
| 325 |
+
|
| 326 |
+
def generate(
|
| 327 |
+
self,
|
| 328 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 329 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 330 |
+
protein_tokenized: Optional[Dict[str, torch.Tensor]] = None,
|
| 331 |
+
batch_idx_map: Optional[List[int]] = None,
|
| 332 |
+
**generation_kwargs,
|
| 333 |
+
) -> Union[torch.Tensor, List[str]]:
|
| 334 |
+
"""生成文本"""
|
| 335 |
+
if input_ids is None or attention_mask is None:
|
| 336 |
+
raise ValueError("input_ids and attention_mask must be provided")
|
| 337 |
+
|
| 338 |
+
batch_size = input_ids.shape[0]
|
| 339 |
+
text_inputs_embeds = self.text_model.get_input_embeddings()(input_ids)
|
| 340 |
+
|
| 341 |
+
if protein_tokenized is not None and batch_idx_map:
|
| 342 |
+
batch_protein_embeds = self.process_protein_embeddings(protein_tokenized, batch_idx_map, batch_size)
|
| 343 |
+
|
| 344 |
+
mask = input_ids == self.protein_token_id
|
| 345 |
+
n_protein_tokens = mask.sum().item()
|
| 346 |
+
expected_protein_features = batch_size * self.qformer_num_query_tokens
|
| 347 |
+
|
| 348 |
+
if n_protein_tokens != expected_protein_features:
|
| 349 |
+
protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
|
| 350 |
+
n_protein_features = protein_embeds_flat.shape[0]
|
| 351 |
+
|
| 352 |
+
if n_protein_features > n_protein_tokens:
|
| 353 |
+
protein_embeds_flat = protein_embeds_flat[:n_protein_tokens]
|
| 354 |
+
elif n_protein_features < n_protein_tokens:
|
| 355 |
+
padding = torch.zeros(
|
| 356 |
+
(n_protein_tokens - n_protein_features, self.text_hidden_size),
|
| 357 |
+
device=protein_embeds_flat.device,
|
| 358 |
+
dtype=protein_embeds_flat.dtype
|
| 359 |
+
)
|
| 360 |
+
protein_embeds_flat = torch.cat([protein_embeds_flat, padding], dim=0)
|
| 361 |
+
else:
|
| 362 |
+
protein_embeds_flat = torch.cat(batch_protein_embeds, dim=0)
|
| 363 |
+
|
| 364 |
+
protein_embeds_flat = protein_embeds_flat.to(dtype=text_inputs_embeds.dtype)
|
| 365 |
+
text_inputs_embeds[mask] = protein_embeds_flat
|
| 366 |
+
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
outputs = self.text_model.generate(
|
| 369 |
+
inputs_embeds=text_inputs_embeds,
|
| 370 |
+
attention_mask=attention_mask,
|
| 371 |
+
use_cache=True,
|
| 372 |
+
**generation_kwargs,
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
return outputs
|
BioReason/bioreason/models/protein_utils.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Optional, Dict, Any
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
@dataclass
|
| 6 |
+
class ProteinInput:
|
| 7 |
+
"""
|
| 8 |
+
Data class for protein-text input pairs.
|
| 9 |
+
|
| 10 |
+
This class encapsulates a text input along with associated protein sequences
|
| 11 |
+
for processing by the ESM2-LLM model.
|
| 12 |
+
"""
|
| 13 |
+
text: str
|
| 14 |
+
protein_sequences: List[str]
|
| 15 |
+
metadata: Optional[Dict[str, Any]] = None
|
| 16 |
+
|
| 17 |
+
def __post_init__(self):
|
| 18 |
+
"""Validate inputs after initialization."""
|
| 19 |
+
if not isinstance(self.text, str):
|
| 20 |
+
raise TypeError("text must be a string")
|
| 21 |
+
|
| 22 |
+
if not isinstance(self.protein_sequences, list):
|
| 23 |
+
raise TypeError("protein_sequences must be a list")
|
| 24 |
+
|
| 25 |
+
for i, seq in enumerate(self.protein_sequences):
|
| 26 |
+
if not isinstance(seq, str):
|
| 27 |
+
raise TypeError(f"protein_sequences[{i}] must be a string")
|
| 28 |
+
|
| 29 |
+
# Basic validation for protein sequence (amino acid characters)
|
| 30 |
+
valid_aa = set('ACDEFGHIKLMNPQRSTVWYX') # X for unknown
|
| 31 |
+
if not all(aa.upper() in valid_aa for aa in seq):
|
| 32 |
+
raise ValueError(f"protein_sequences[{i}] contains invalid amino acid characters")
|
| 33 |
+
|
| 34 |
+
def __len__(self):
|
| 35 |
+
"""Return the number of protein sequences."""
|
| 36 |
+
return len(self.protein_sequences)
|
| 37 |
+
|
| 38 |
+
def __getitem__(self, index):
|
| 39 |
+
"""Get a specific protein sequence by index."""
|
| 40 |
+
return self.protein_sequences[index]
|
| 41 |
+
|
| 42 |
+
def add_protein_sequence(self, sequence: str):
|
| 43 |
+
"""Add a protein sequence to the input."""
|
| 44 |
+
if not isinstance(sequence, str):
|
| 45 |
+
raise TypeError("sequence must be a string")
|
| 46 |
+
|
| 47 |
+
# Validate amino acid sequence
|
| 48 |
+
valid_aa = set('ACDEFGHIKLMNPQRSTVWYX')
|
| 49 |
+
if not all(aa.upper() in valid_aa for aa in sequence):
|
| 50 |
+
raise ValueError("sequence contains invalid amino acid characters")
|
| 51 |
+
|
| 52 |
+
self.protein_sequences.append(sequence)
|
| 53 |
+
|
| 54 |
+
def get_total_protein_length(self) -> int:
|
| 55 |
+
"""Get the total length of all protein sequences."""
|
| 56 |
+
return sum(len(seq) for seq in self.protein_sequences)
|
| 57 |
+
|
| 58 |
+
def get_protein_lengths(self) -> List[int]:
|
| 59 |
+
"""Get the length of each protein sequence."""
|
| 60 |
+
return [len(seq) for seq in self.protein_sequences]
|
| 61 |
+
|
| 62 |
+
def to_dict(self) -> Dict[str, Any]:
|
| 63 |
+
"""Convert to dictionary representation."""
|
| 64 |
+
return {
|
| 65 |
+
"text": self.text,
|
| 66 |
+
"protein_sequences": self.protein_sequences,
|
| 67 |
+
"metadata": self.metadata,
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
@classmethod
|
| 71 |
+
def from_dict(cls, data: Dict[str, Any]) -> "ProteinInput":
|
| 72 |
+
"""Create ProteinInput from dictionary."""
|
| 73 |
+
return cls(
|
| 74 |
+
text=data["text"],
|
| 75 |
+
protein_sequences=data["protein_sequences"],
|
| 76 |
+
metadata=data.get("metadata"),
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def validate_protein_sequence(sequence: str) -> bool:
|
| 81 |
+
"""
|
| 82 |
+
Validate if a string is a valid protein sequence.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
sequence: String to validate
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
True if valid, False otherwise
|
| 89 |
+
"""
|
| 90 |
+
if not isinstance(sequence, str):
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
valid_aa = set('ACDEFGHIKLMNPQRSTVWYX') # Standard amino acids + X for unknown
|
| 94 |
+
return all(aa.upper() in valid_aa for aa in sequence)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def clean_protein_sequence(sequence: str) -> str:
|
| 98 |
+
"""
|
| 99 |
+
Clean a protein sequence by removing invalid characters and whitespace.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
sequence: Raw protein sequence
|
| 103 |
+
|
| 104 |
+
Returns:
|
| 105 |
+
Cleaned protein sequence
|
| 106 |
+
"""
|
| 107 |
+
if not isinstance(sequence, str):
|
| 108 |
+
raise TypeError("sequence must be a string")
|
| 109 |
+
|
| 110 |
+
# Remove whitespace and convert to uppercase
|
| 111 |
+
cleaned = ''.join(sequence.split()).upper()
|
| 112 |
+
|
| 113 |
+
# Keep only valid amino acid characters
|
| 114 |
+
valid_aa = set('ACDEFGHIKLMNPQRSTVWYX')
|
| 115 |
+
cleaned = ''.join(aa for aa in cleaned if aa in valid_aa)
|
| 116 |
+
|
| 117 |
+
return cleaned
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def load_protein_sequences_from_fasta(filepath: str) -> List[Dict[str, str]]:
|
| 121 |
+
"""
|
| 122 |
+
Load protein sequences from a FASTA file.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
filepath: Path to FASTA file
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
List of dictionaries with 'id' and 'sequence' keys
|
| 129 |
+
"""
|
| 130 |
+
sequences = []
|
| 131 |
+
current_id = None
|
| 132 |
+
current_seq = []
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
with open(filepath, 'r') as f:
|
| 136 |
+
for line in f:
|
| 137 |
+
line = line.strip()
|
| 138 |
+
if line.startswith('>'):
|
| 139 |
+
# Save previous sequence if exists
|
| 140 |
+
if current_id is not None and current_seq:
|
| 141 |
+
sequences.append({
|
| 142 |
+
'id': current_id,
|
| 143 |
+
'sequence': clean_protein_sequence(''.join(current_seq))
|
| 144 |
+
})
|
| 145 |
+
|
| 146 |
+
# Start new sequence
|
| 147 |
+
current_id = line[1:] # Remove '>'
|
| 148 |
+
current_seq = []
|
| 149 |
+
elif line:
|
| 150 |
+
current_seq.append(line)
|
| 151 |
+
|
| 152 |
+
# Save last sequence
|
| 153 |
+
if current_id is not None and current_seq:
|
| 154 |
+
sequences.append({
|
| 155 |
+
'id': current_id,
|
| 156 |
+
'sequence': clean_protein_sequence(''.join(current_seq))
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
except FileNotFoundError:
|
| 160 |
+
raise FileNotFoundError(f"FASTA file not found: {filepath}")
|
| 161 |
+
except Exception as e:
|
| 162 |
+
raise Exception(f"Error reading FASTA file: {e}")
|
| 163 |
+
|
| 164 |
+
return sequences
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def create_protein_inputs_from_fasta(
|
| 168 |
+
fasta_filepath: str,
|
| 169 |
+
text_template: str = "Analyze this protein: {protein_id}",
|
| 170 |
+
include_sequence_in_text: bool = False,
|
| 171 |
+
) -> List[ProteinInput]:
|
| 172 |
+
"""
|
| 173 |
+
Create ProteinInput objects from a FASTA file.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
fasta_filepath: Path to FASTA file
|
| 177 |
+
text_template: Template for creating text (can use {protein_id})
|
| 178 |
+
include_sequence_in_text: Whether to include sequence in text
|
| 179 |
+
|
| 180 |
+
Returns:
|
| 181 |
+
List of ProteinInput objects
|
| 182 |
+
"""
|
| 183 |
+
sequences = load_protein_sequences_from_fasta(fasta_filepath)
|
| 184 |
+
|
| 185 |
+
protein_inputs = []
|
| 186 |
+
for seq_data in sequences:
|
| 187 |
+
# Create text from template
|
| 188 |
+
text = text_template.format(protein_id=seq_data['id'])
|
| 189 |
+
|
| 190 |
+
if include_sequence_in_text:
|
| 191 |
+
text += f" Sequence: {seq_data['sequence']}"
|
| 192 |
+
|
| 193 |
+
# Create ProteinInput
|
| 194 |
+
protein_input = ProteinInput(
|
| 195 |
+
text=text,
|
| 196 |
+
protein_sequences=[seq_data['sequence']],
|
| 197 |
+
metadata={'protein_id': seq_data['id']}
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
protein_inputs.append(protein_input)
|
| 201 |
+
|
| 202 |
+
return protein_inputs
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# Example usage and utility functions
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
# Example usage
|
| 208 |
+
protein_input = ProteinInput(
|
| 209 |
+
text="What is the function of this protein?",
|
| 210 |
+
protein_sequences=[
|
| 211 |
+
"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA",
|
| 212 |
+
"ACDEFGHIKLMNPQRSTVWY"
|
| 213 |
+
],
|
| 214 |
+
metadata={"source": "example"}
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
print(f"Number of protein sequences: {len(protein_input)}")
|
| 218 |
+
print(f"Total protein length: {protein_input.get_total_protein_length()}")
|
| 219 |
+
print(f"Individual lengths: {protein_input.get_protein_lengths()}")
|
| 220 |
+
|
| 221 |
+
# Test validation
|
| 222 |
+
try:
|
| 223 |
+
invalid_input = ProteinInput(
|
| 224 |
+
text="Test",
|
| 225 |
+
protein_sequences=["INVALID123"] # Contains numbers
|
| 226 |
+
)
|
| 227 |
+
except ValueError as e:
|
| 228 |
+
print(f"Validation caught invalid sequence: {e}")
|
| 229 |
+
|
| 230 |
+
# Test cleaning
|
| 231 |
+
dirty_sequence = " AC DE FG HI 123 KL "
|
| 232 |
+
clean_sequence = clean_protein_sequence(dirty_sequence)
|
| 233 |
+
print(f"Cleaned sequence: '{clean_sequence}'")
|
BioReason/bioreason/trainer/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .grpo_config import DNALLMGRPOConfig
|
| 2 |
+
from .grpo_trainer import DNALLMGRPOTrainer
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"DNALLMGRPOConfig",
|
| 6 |
+
"DNALLMGRPOTrainer",
|
| 7 |
+
]
|
BioReason/bioreason/trainer/demo_grpo.py
ADDED
|
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import textwrap
|
| 17 |
+
import warnings
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from typing import Any, Callable, Optional, Sized, Union
|
| 20 |
+
from unittest.mock import patch
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.data
|
| 24 |
+
import transformers
|
| 25 |
+
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
|
| 26 |
+
from accelerate.utils.other import is_compiled_module
|
| 27 |
+
from datasets import Dataset, IterableDataset
|
| 28 |
+
from packaging import version
|
| 29 |
+
from torch import nn
|
| 30 |
+
from torch.utils.data import Sampler
|
| 31 |
+
from transformers import (
|
| 32 |
+
AutoModelForCausalLM,
|
| 33 |
+
AutoModelForSequenceClassification,
|
| 34 |
+
AutoTokenizer,
|
| 35 |
+
GenerationConfig,
|
| 36 |
+
PreTrainedModel,
|
| 37 |
+
PreTrainedTokenizerBase,
|
| 38 |
+
Trainer,
|
| 39 |
+
TrainerCallback,
|
| 40 |
+
is_wandb_available,
|
| 41 |
+
)
|
| 42 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 43 |
+
from transformers.utils import is_peft_available
|
| 44 |
+
|
| 45 |
+
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
| 46 |
+
from trl.import_utils import is_vllm_available
|
| 47 |
+
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
| 48 |
+
from trl import SyncRefModelCallback
|
| 49 |
+
from trl import GRPOConfig
|
| 50 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url, pad, selective_log_softmax
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if is_peft_available():
|
| 54 |
+
from peft import PeftConfig, get_peft_model
|
| 55 |
+
|
| 56 |
+
if is_vllm_available():
|
| 57 |
+
from vllm import LLM, SamplingParams
|
| 58 |
+
|
| 59 |
+
if is_wandb_available():
|
| 60 |
+
import wandb
|
| 61 |
+
|
| 62 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 63 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 64 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RepeatRandomSampler(Sampler):
|
| 68 |
+
"""
|
| 69 |
+
Sampler that repeats the indices of a dataset N times.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
data_source (`Sized`):
|
| 73 |
+
Dataset to sample from.
|
| 74 |
+
repeat_count (`int`):
|
| 75 |
+
Number of times to repeat each index.
|
| 76 |
+
seed (`Optional[int]`):
|
| 77 |
+
Random seed for reproducibility (only affects this sampler).
|
| 78 |
+
|
| 79 |
+
Example:
|
| 80 |
+
```python
|
| 81 |
+
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
|
| 82 |
+
>>> list(sampler)
|
| 83 |
+
[2, 2, 0, 0, 3, 3, 1, 1]
|
| 84 |
+
```
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
|
| 88 |
+
self.data_source = data_source
|
| 89 |
+
self.repeat_count = repeat_count
|
| 90 |
+
self.num_samples = len(data_source)
|
| 91 |
+
self.seed = seed
|
| 92 |
+
self.generator = torch.Generator() # Create a local random generator
|
| 93 |
+
if seed is not None:
|
| 94 |
+
self.generator.manual_seed(seed)
|
| 95 |
+
|
| 96 |
+
def __iter__(self):
|
| 97 |
+
indexes = [
|
| 98 |
+
idx
|
| 99 |
+
for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| 100 |
+
for _ in range(self.repeat_count)
|
| 101 |
+
]
|
| 102 |
+
return iter(indexes)
|
| 103 |
+
|
| 104 |
+
def __len__(self):
|
| 105 |
+
return self.num_samples * self.repeat_count
|
| 106 |
+
|
| 107 |
+
# made this to test out the usual pipeline of GRPOTrainer data, and add my own debug messages
|
| 108 |
+
class FakeGRPOTrainer(Trainer):
|
| 109 |
+
"""
|
| 110 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
| 111 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
| 112 |
+
|
| 113 |
+
Example:
|
| 114 |
+
|
| 115 |
+
```python
|
| 116 |
+
from datasets import load_dataset
|
| 117 |
+
from trl import GRPOTrainer
|
| 118 |
+
|
| 119 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
| 120 |
+
|
| 121 |
+
def reward_func(completions, **kwargs):
|
| 122 |
+
# Dummy reward function that rewards completions with more unique letters.
|
| 123 |
+
return [float(len(set(completion))) for completion in completions]
|
| 124 |
+
|
| 125 |
+
trainer = GRPOTrainer(
|
| 126 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
| 127 |
+
reward_funcs=reward_func,
|
| 128 |
+
train_dataset=dataset,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
trainer.train()
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
model (`Union[str, PreTrainedModel]`):
|
| 136 |
+
Model to be trained. Can be either:
|
| 137 |
+
|
| 138 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
| 139 |
+
a path to a *directory* containing model weights saved using
|
| 140 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
| 141 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
| 142 |
+
in `args.model_init_kwargs`.
|
| 143 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
| 144 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
| 145 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
| 146 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
| 147 |
+
|
| 148 |
+
- A single reward function, such as:
|
| 149 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
| 150 |
+
path to a *directory* containing model weights saved using
|
| 151 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
| 152 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
| 153 |
+
keyword arguments in `args.model_init_kwargs`.
|
| 154 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
| 155 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
| 156 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
| 157 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
| 158 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
| 159 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
| 160 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
| 161 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
| 162 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 163 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
| 164 |
+
ignored. The format of the samples can be either:
|
| 165 |
+
|
| 166 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 167 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 168 |
+
and content).
|
| 169 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 170 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 171 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
| 172 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
| 173 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
| 174 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
| 175 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
| 176 |
+
|
| 177 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
| 178 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
| 179 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
| 180 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
| 181 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
| 182 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
| 183 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
| 184 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
| 185 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 186 |
+
|
| 187 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 188 |
+
method.
|
| 189 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
| 190 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
| 191 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
| 192 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
| 193 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
_tag_names = ["trl", "grpo"]
|
| 197 |
+
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
model: Union[str, PreTrainedModel],
|
| 201 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 202 |
+
args: GRPOConfig = None,
|
| 203 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 204 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 205 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 206 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
| 207 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 208 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 209 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 210 |
+
):
|
| 211 |
+
# Args
|
| 212 |
+
if args is None:
|
| 213 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 214 |
+
model_name = model_name.split("/")[-1]
|
| 215 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 216 |
+
|
| 217 |
+
# Models
|
| 218 |
+
# Trained model
|
| 219 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 220 |
+
if isinstance(model, str):
|
| 221 |
+
model_id = model
|
| 222 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 223 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 224 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 225 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 226 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 227 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 228 |
+
else:
|
| 229 |
+
raise ValueError(
|
| 230 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 231 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 232 |
+
)
|
| 233 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 234 |
+
model_init_kwargs["use_cache"] = (
|
| 235 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 236 |
+
)
|
| 237 |
+
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
| 238 |
+
else:
|
| 239 |
+
model_id = model.config._name_or_path
|
| 240 |
+
if args.model_init_kwargs is not None:
|
| 241 |
+
raise ValueError(
|
| 242 |
+
"You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
|
| 243 |
+
"This argument can only be used when the `model` argument is a string."
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if peft_config is not None:
|
| 247 |
+
model = get_peft_model(model, peft_config)
|
| 248 |
+
|
| 249 |
+
# Reference model
|
| 250 |
+
if is_deepspeed_zero3_enabled():
|
| 251 |
+
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
|
| 252 |
+
elif not is_peft_model(model):
|
| 253 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 254 |
+
self.ref_model = create_reference_model(model)
|
| 255 |
+
else:
|
| 256 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 257 |
+
# to revert to the initial model.
|
| 258 |
+
self.ref_model = None
|
| 259 |
+
|
| 260 |
+
# Processing class
|
| 261 |
+
if processing_class is None:
|
| 262 |
+
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
|
| 263 |
+
|
| 264 |
+
# Reward functions
|
| 265 |
+
if not isinstance(reward_funcs, list):
|
| 266 |
+
reward_funcs = [reward_funcs]
|
| 267 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 268 |
+
if isinstance(reward_func, str):
|
| 269 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 270 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 271 |
+
)
|
| 272 |
+
self.reward_funcs = reward_funcs
|
| 273 |
+
|
| 274 |
+
# Reward weights
|
| 275 |
+
if args.reward_weights is not None:
|
| 276 |
+
if len(args.reward_weights) != len(reward_funcs):
|
| 277 |
+
raise ValueError(
|
| 278 |
+
f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
|
| 279 |
+
f"functions ({len(reward_funcs)})"
|
| 280 |
+
)
|
| 281 |
+
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
|
| 282 |
+
else:
|
| 283 |
+
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
|
| 284 |
+
|
| 285 |
+
# Reward processing class
|
| 286 |
+
if reward_processing_classes is None:
|
| 287 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 288 |
+
elif not isinstance(reward_processing_classes, list):
|
| 289 |
+
reward_processing_classes = [reward_processing_classes]
|
| 290 |
+
else:
|
| 291 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 292 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
| 293 |
+
|
| 294 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
| 295 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 296 |
+
if reward_processing_class is None:
|
| 297 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
| 298 |
+
if reward_processing_class.pad_token_id is None:
|
| 299 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
| 300 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 301 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 302 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 303 |
+
reward_processing_classes[i] = reward_processing_class
|
| 304 |
+
self.reward_processing_classes = reward_processing_classes
|
| 305 |
+
|
| 306 |
+
# Data collator
|
| 307 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 308 |
+
return features
|
| 309 |
+
|
| 310 |
+
# Training arguments
|
| 311 |
+
self.max_prompt_length = args.max_prompt_length
|
| 312 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 313 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 314 |
+
self.use_vllm = args.use_vllm
|
| 315 |
+
|
| 316 |
+
self.beta = args.beta
|
| 317 |
+
|
| 318 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 319 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 320 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 321 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 322 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 323 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 324 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 325 |
+
|
| 326 |
+
# Initialize the metrics
|
| 327 |
+
self._metrics = defaultdict(list)
|
| 328 |
+
self.log_completions = args.log_completions
|
| 329 |
+
|
| 330 |
+
super().__init__(
|
| 331 |
+
model=model,
|
| 332 |
+
args=args,
|
| 333 |
+
data_collator=data_collator,
|
| 334 |
+
train_dataset=train_dataset,
|
| 335 |
+
eval_dataset=eval_dataset,
|
| 336 |
+
processing_class=processing_class,
|
| 337 |
+
callbacks=callbacks,
|
| 338 |
+
optimizers=optimizers,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 342 |
+
num_processes = self.accelerator.num_processes
|
| 343 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 344 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 345 |
+
if self.num_generations not in possible_values:
|
| 346 |
+
raise ValueError(
|
| 347 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 348 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 349 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 350 |
+
)
|
| 351 |
+
if self.args.eval_strategy != "no":
|
| 352 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 353 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 354 |
+
if self.num_generations not in possible_values:
|
| 355 |
+
raise ValueError(
|
| 356 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 357 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 358 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 362 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 363 |
+
# it's safer to set it in all cases.
|
| 364 |
+
set_seed(args.seed, device_specific=True)
|
| 365 |
+
|
| 366 |
+
if self.use_vllm:
|
| 367 |
+
if not is_vllm_available():
|
| 368 |
+
raise ImportError(
|
| 369 |
+
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
|
| 370 |
+
"`pip install vllm` to use it."
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
if self.accelerator.is_main_process:
|
| 374 |
+
vllm_device = self.args.vllm_device
|
| 375 |
+
if vllm_device == "auto":
|
| 376 |
+
if torch.cuda.device_count() == 1:
|
| 377 |
+
vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
|
| 378 |
+
else:
|
| 379 |
+
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
|
| 380 |
+
# Check that the requested device is available
|
| 381 |
+
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
|
| 382 |
+
raise ValueError(
|
| 383 |
+
f"The requested device for vllm ({vllm_device}) is not available. You are likely using vLLM "
|
| 384 |
+
"without restricting the number of GPUs for training. Set the `--num_processes` argument to a "
|
| 385 |
+
"value lower than the number of GPUs available on your machine—typically, reducing it by one "
|
| 386 |
+
f"is sufficient. In your case: `--num_processes {torch.cuda.device_count() - 1}`."
|
| 387 |
+
)
|
| 388 |
+
# Check that the requested device is not also used for training
|
| 389 |
+
if vllm_device in {f"cuda:{idx}" for idx in range(self.accelerator.num_processes)}:
|
| 390 |
+
warnings.warn(
|
| 391 |
+
f"The requested device {vllm_device} is also being used for training. For higher throughput "
|
| 392 |
+
"and to avoid out-of-memory errors, it is recommended to use a dedicated device for vLLM. "
|
| 393 |
+
"If this is intentional, you may ignore this warning but should adjust "
|
| 394 |
+
"`vllm_gpu_memory_utilization` accordingly."
|
| 395 |
+
)
|
| 396 |
+
# vLLM is not compatible with accelerate. So we need to patch it to make sure we can (1) place the vLLM
|
| 397 |
+
# model on the desired device (world_size_patch) and (2) avoid a test that is not designed for our
|
| 398 |
+
# setting (profiling_patch).
|
| 399 |
+
world_size_patch = patch("torch.distributed.get_world_size", return_value=1)
|
| 400 |
+
profiling_patch = patch(
|
| 401 |
+
"vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling", return_value=None
|
| 402 |
+
)
|
| 403 |
+
with world_size_patch, profiling_patch:
|
| 404 |
+
self.llm = LLM(
|
| 405 |
+
model=model.name_or_path,
|
| 406 |
+
device=vllm_device,
|
| 407 |
+
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
|
| 408 |
+
dtype=self.args.vllm_dtype,
|
| 409 |
+
# Automatic Prefix Caching caches the KV cache of existing queries, so that a new query can
|
| 410 |
+
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
| 411 |
+
# This is particularly useful here because we generate completions from the same prompts.
|
| 412 |
+
enable_prefix_caching=True,
|
| 413 |
+
max_model_len=self.args.vllm_max_model_len,
|
| 414 |
+
)
|
| 415 |
+
self.sampling_params = SamplingParams(
|
| 416 |
+
temperature=args.temperature,
|
| 417 |
+
max_tokens=self.max_completion_length,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
|
| 421 |
+
|
| 422 |
+
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
|
| 423 |
+
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
|
| 424 |
+
# synchronize all processes after vLLM has been fully initialized.
|
| 425 |
+
self.accelerator.wait_for_everyone()
|
| 426 |
+
else:
|
| 427 |
+
self.generation_config = GenerationConfig(
|
| 428 |
+
max_new_tokens=self.max_completion_length,
|
| 429 |
+
do_sample=True,
|
| 430 |
+
temperature=args.temperature,
|
| 431 |
+
pad_token_id=processing_class.pad_token_id,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 435 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 436 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 437 |
+
self.model_accepts_loss_kwargs = False
|
| 438 |
+
|
| 439 |
+
# Add tags to the model
|
| 440 |
+
self.model.add_model_tags(self._tag_names)
|
| 441 |
+
|
| 442 |
+
if self.ref_model is not None:
|
| 443 |
+
if self.is_deepspeed_enabled:
|
| 444 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 445 |
+
else:
|
| 446 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 447 |
+
|
| 448 |
+
if args.sync_ref_model:
|
| 449 |
+
self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
|
| 450 |
+
|
| 451 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 452 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 453 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| 454 |
+
|
| 455 |
+
def _set_signature_columns_if_needed(self):
|
| 456 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 457 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 458 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 459 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 460 |
+
if self._signature_columns is None:
|
| 461 |
+
self._signature_columns = ["prompt"]
|
| 462 |
+
|
| 463 |
+
def _get_train_sampler(self) -> Sampler:
|
| 464 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
| 465 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
| 466 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
| 467 |
+
# preventing discrepancies in group formation.
|
| 468 |
+
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
|
| 469 |
+
|
| 470 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
| 471 |
+
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
|
| 472 |
+
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
|
| 473 |
+
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
|
| 474 |
+
# preventing discrepancies in group formation.
|
| 475 |
+
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
|
| 476 |
+
|
| 477 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 478 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
| 479 |
+
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
| 480 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
| 481 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 482 |
+
|
| 483 |
+
input_ids = input_ids[:, -logits_to_keep:]
|
| 484 |
+
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
| 485 |
+
# See https://github.com/huggingface/trl/issues/2770
|
| 486 |
+
logits = logits[:, -logits_to_keep:]
|
| 487 |
+
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
| 488 |
+
|
| 489 |
+
def _move_model_to_vllm(self):
|
| 490 |
+
with unwrap_model_for_generation(
|
| 491 |
+
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
|
| 492 |
+
) as unwrapped_model:
|
| 493 |
+
if is_compiled_module(unwrapped_model):
|
| 494 |
+
unwrapped_model = unwrapped_model._orig_mod
|
| 495 |
+
if is_peft_model(unwrapped_model):
|
| 496 |
+
unwrapped_model.merge_adapter()
|
| 497 |
+
state_dict = unwrapped_model.state_dict()
|
| 498 |
+
# Remove base_model and base_layer prefixes
|
| 499 |
+
state_dict = {
|
| 500 |
+
k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
|
| 501 |
+
}
|
| 502 |
+
# Remove values with adapter prefix (example: "_lora")
|
| 503 |
+
state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
|
| 504 |
+
# When module to save, remove its prefix and discard the original module
|
| 505 |
+
state_dict = {
|
| 506 |
+
k.replace("modules_to_save.default.", ""): v
|
| 507 |
+
for k, v in state_dict.items()
|
| 508 |
+
if "original_module" not in k
|
| 509 |
+
}
|
| 510 |
+
else:
|
| 511 |
+
state_dict = unwrapped_model.state_dict()
|
| 512 |
+
if self.accelerator.is_main_process:
|
| 513 |
+
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
|
| 514 |
+
llm_model.load_weights(state_dict.items())
|
| 515 |
+
# Unmerge the adapter to restore the model to its original state.
|
| 516 |
+
# This must be done after loading weights to ensure they correspond to the merged state.
|
| 517 |
+
if is_peft_model(unwrapped_model):
|
| 518 |
+
unwrapped_model.unmerge_adapter()
|
| 519 |
+
|
| 520 |
+
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
| 521 |
+
device = self.accelerator.device
|
| 522 |
+
prompts = [x["prompt"] for x in inputs]
|
| 523 |
+
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
|
| 524 |
+
prompt_inputs = self.processing_class(
|
| 525 |
+
prompts_text, return_tensors="pt", padding=True, padding_side="left", add_special_tokens=False
|
| 526 |
+
)
|
| 527 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 528 |
+
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
| 529 |
+
|
| 530 |
+
if self.max_prompt_length is not None:
|
| 531 |
+
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
| 532 |
+
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
| 533 |
+
|
| 534 |
+
# Generate completions using either vLLM or regular generation
|
| 535 |
+
if self.args.use_vllm:
|
| 536 |
+
# First, have main process load weights if needed
|
| 537 |
+
if self.state.global_step != self._last_loaded_step:
|
| 538 |
+
self._move_model_to_vllm()
|
| 539 |
+
self._last_loaded_step = self.state.global_step
|
| 540 |
+
|
| 541 |
+
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
|
| 542 |
+
all_prompts_text = gather_object(prompts_text)
|
| 543 |
+
if self.accelerator.is_main_process:
|
| 544 |
+
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
|
| 545 |
+
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
|
| 546 |
+
else:
|
| 547 |
+
completion_ids = [None] * len(all_prompts_text)
|
| 548 |
+
# Broadcast the completions from the main process to all processes, ensuring each process receives its
|
| 549 |
+
# corresponding slice.
|
| 550 |
+
completion_ids = broadcast_object_list(completion_ids, from_process=0)
|
| 551 |
+
process_slice = slice(
|
| 552 |
+
self.accelerator.process_index * len(prompts),
|
| 553 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 554 |
+
)
|
| 555 |
+
completion_ids = completion_ids[process_slice]
|
| 556 |
+
|
| 557 |
+
# Pad the completions, and concatenate them with the prompts
|
| 558 |
+
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
|
| 559 |
+
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
|
| 560 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 561 |
+
else:
|
| 562 |
+
print("about to generate!!")
|
| 563 |
+
# Regular generation path
|
| 564 |
+
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
| 565 |
+
prompt_completion_ids = unwrapped_model.generate(
|
| 566 |
+
prompt_ids, attention_mask=prompt_mask, generation_config=self.generation_config
|
| 567 |
+
)
|
| 568 |
+
|
| 569 |
+
print('prompts_ids', prompt_ids, 'attention_mask', prompt_mask)
|
| 570 |
+
print('prompt_completion_ids', prompt_completion_ids)
|
| 571 |
+
print('prompt len', prompt_ids.size(1))
|
| 572 |
+
|
| 573 |
+
# Compute prompt length and extract completion ids
|
| 574 |
+
prompt_length = prompt_ids.size(1)
|
| 575 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
| 576 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 577 |
+
|
| 578 |
+
# Mask everything after the first EOS token
|
| 579 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 580 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 581 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 582 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 583 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 584 |
+
|
| 585 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 586 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
|
| 587 |
+
|
| 588 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 589 |
+
|
| 590 |
+
with torch.inference_mode():
|
| 591 |
+
if self.ref_model is not None:
|
| 592 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 593 |
+
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
|
| 594 |
+
)
|
| 595 |
+
else:
|
| 596 |
+
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
| 597 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 598 |
+
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# Decode the generated completions
|
| 602 |
+
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 603 |
+
if is_conversational(inputs[0]):
|
| 604 |
+
completions = []
|
| 605 |
+
for prompt, completion in zip(prompts, completions_text):
|
| 606 |
+
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
|
| 607 |
+
completions.append([{"role": "assistant", "content": bootstrap + completion}])
|
| 608 |
+
else:
|
| 609 |
+
completions = completions_text
|
| 610 |
+
|
| 611 |
+
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
| 612 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
| 613 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
| 614 |
+
):
|
| 615 |
+
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
| 616 |
+
if is_conversational(inputs[0]):
|
| 617 |
+
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
| 618 |
+
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
| 619 |
+
else:
|
| 620 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
| 621 |
+
reward_inputs = reward_processing_class(
|
| 622 |
+
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
| 623 |
+
)
|
| 624 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
| 625 |
+
with torch.inference_mode():
|
| 626 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
| 627 |
+
else:
|
| 628 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
| 629 |
+
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
|
| 630 |
+
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
|
| 631 |
+
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
| 632 |
+
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
| 633 |
+
|
| 634 |
+
# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
|
| 635 |
+
# completions may be distributed across processes
|
| 636 |
+
rewards_per_func = gather(rewards_per_func)
|
| 637 |
+
|
| 638 |
+
# Apply weights to each reward function's output and sum
|
| 639 |
+
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
|
| 640 |
+
|
| 641 |
+
# Compute grouped-wise rewards
|
| 642 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 643 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 644 |
+
|
| 645 |
+
# Normalize the rewards to compute the advantages
|
| 646 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 647 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 648 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 649 |
+
|
| 650 |
+
# Slice to keep only the local part of the data
|
| 651 |
+
process_slice = slice(
|
| 652 |
+
self.accelerator.process_index * len(prompts),
|
| 653 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 654 |
+
)
|
| 655 |
+
advantages = advantages[process_slice]
|
| 656 |
+
|
| 657 |
+
# Log the metrics
|
| 658 |
+
reward_per_func = rewards_per_func.mean(0)
|
| 659 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 660 |
+
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
|
| 661 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
| 662 |
+
else:
|
| 663 |
+
reward_func_name = reward_func.__name__
|
| 664 |
+
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
| 665 |
+
|
| 666 |
+
self._metrics["reward"].append(rewards.mean().item())
|
| 667 |
+
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
|
| 668 |
+
|
| 669 |
+
if (
|
| 670 |
+
self.log_completions
|
| 671 |
+
and self.state.global_step % self.args.logging_steps == 0
|
| 672 |
+
and "wandb" in self.args.report_to
|
| 673 |
+
):
|
| 674 |
+
import pandas as pd
|
| 675 |
+
|
| 676 |
+
# For logging
|
| 677 |
+
table = {
|
| 678 |
+
"step": [str(self.state.global_step)] * len(rewards),
|
| 679 |
+
"prompt": gather_object(prompts_text),
|
| 680 |
+
"completion": gather_object(completions_text),
|
| 681 |
+
"reward": rewards.tolist(),
|
| 682 |
+
}
|
| 683 |
+
df = pd.DataFrame(table)
|
| 684 |
+
|
| 685 |
+
if wandb.run is not None and self.accelerator.is_main_process:
|
| 686 |
+
wandb.log({"completions": wandb.Table(dataframe=df)})
|
| 687 |
+
|
| 688 |
+
return {
|
| 689 |
+
"prompt_ids": prompt_ids,
|
| 690 |
+
"prompt_mask": prompt_mask,
|
| 691 |
+
"completion_ids": completion_ids,
|
| 692 |
+
"completion_mask": completion_mask,
|
| 693 |
+
"ref_per_token_logps": ref_per_token_logps,
|
| 694 |
+
"advantages": advantages,
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 698 |
+
if return_outputs:
|
| 699 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 700 |
+
# Compute the per-token log probabilities for the model
|
| 701 |
+
|
| 702 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 703 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 704 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 705 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 706 |
+
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
| 707 |
+
|
| 708 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
| 709 |
+
|
| 710 |
+
# Compute the KL divergence between the model and the reference model
|
| 711 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 712 |
+
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 713 |
+
|
| 714 |
+
# x - x.detach() allows for preserving gradients from x
|
| 715 |
+
advantages = inputs["advantages"]
|
| 716 |
+
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
| 717 |
+
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
|
| 718 |
+
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 719 |
+
|
| 720 |
+
# Log the metrics
|
| 721 |
+
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
| 722 |
+
self._metrics["completion_length"].append(completion_length)
|
| 723 |
+
|
| 724 |
+
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 725 |
+
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 726 |
+
|
| 727 |
+
return loss
|
| 728 |
+
|
| 729 |
+
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
|
| 730 |
+
inputs = self._prepare_inputs(inputs)
|
| 731 |
+
print("about to loss")
|
| 732 |
+
with torch.no_grad():
|
| 733 |
+
with self.compute_loss_context_manager():
|
| 734 |
+
loss = self.compute_loss(model, inputs)
|
| 735 |
+
loss = loss.mean().detach()
|
| 736 |
+
print("loss computed")
|
| 737 |
+
return loss, None, None
|
| 738 |
+
|
| 739 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 740 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
| 741 |
+
|
| 742 |
+
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
|
| 743 |
+
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
|
| 744 |
+
if next(iter(logs.keys())).startswith("eval_"):
|
| 745 |
+
metrics = {f"eval_{key}": val for key, val in metrics.items()}
|
| 746 |
+
|
| 747 |
+
logs = {**logs, **metrics}
|
| 748 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 749 |
+
super().log(logs, start_time)
|
| 750 |
+
else: # transformers<=4.46
|
| 751 |
+
super().log(logs)
|
| 752 |
+
self._metrics.clear()
|
| 753 |
+
|
| 754 |
+
def create_model_card(
|
| 755 |
+
self,
|
| 756 |
+
model_name: Optional[str] = None,
|
| 757 |
+
dataset_name: Optional[str] = None,
|
| 758 |
+
tags: Union[str, list[str], None] = None,
|
| 759 |
+
):
|
| 760 |
+
"""
|
| 761 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 762 |
+
|
| 763 |
+
Args:
|
| 764 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 765 |
+
Name of the model.
|
| 766 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 767 |
+
Name of the dataset used for training.
|
| 768 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 769 |
+
Tags to be associated with the model card.
|
| 770 |
+
"""
|
| 771 |
+
if not self.is_world_process_zero():
|
| 772 |
+
return
|
| 773 |
+
|
| 774 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 775 |
+
base_model = self.model.config._name_or_path
|
| 776 |
+
else:
|
| 777 |
+
base_model = None
|
| 778 |
+
|
| 779 |
+
tags = tags or []
|
| 780 |
+
if isinstance(tags, str):
|
| 781 |
+
tags = [tags]
|
| 782 |
+
|
| 783 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 784 |
+
tags.append("unsloth")
|
| 785 |
+
|
| 786 |
+
citation = textwrap.dedent(
|
| 787 |
+
"""\
|
| 788 |
+
@article{zhihong2024deepseekmath,
|
| 789 |
+
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
| 790 |
+
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
| 791 |
+
year = 2024,
|
| 792 |
+
eprint = {arXiv:2402.03300},
|
| 793 |
+
}
|
| 794 |
+
"""
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
model_card = generate_model_card(
|
| 798 |
+
base_model=base_model,
|
| 799 |
+
model_name=model_name,
|
| 800 |
+
hub_model_id=self.hub_model_id,
|
| 801 |
+
dataset_name=dataset_name,
|
| 802 |
+
tags=tags,
|
| 803 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 804 |
+
comet_url=get_comet_experiment_url(),
|
| 805 |
+
trainer_name="GRPO",
|
| 806 |
+
trainer_citation=citation,
|
| 807 |
+
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
| 808 |
+
paper_id="2402.03300",
|
| 809 |
+
)
|
| 810 |
+
|
| 811 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
BioReason/bioreason/trainer/grpo_config.py
ADDED
|
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import Optional, Union
|
| 17 |
+
|
| 18 |
+
from transformers import TrainingArguments
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
|
| 22 |
+
class DNALLMGRPOConfig(TrainingArguments):
|
| 23 |
+
r"""
|
| 24 |
+
Configuration class for the [`GRPOTrainer`].
|
| 25 |
+
|
| 26 |
+
Only the parameters specific to GRPO training are listed here. For details on other parameters, refer to the
|
| 27 |
+
[`~transformers.TrainingArguments`] documentation.
|
| 28 |
+
|
| 29 |
+
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
| 30 |
+
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
| 31 |
+
command line.
|
| 32 |
+
|
| 33 |
+
Parameters:
|
| 34 |
+
> Parameters that control the model and reference model
|
| 35 |
+
|
| 36 |
+
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
| 37 |
+
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
| 38 |
+
argument of the [`GRPOTrainer`] is provided as a string.
|
| 39 |
+
|
| 40 |
+
> Parameters that control the data preprocessing
|
| 41 |
+
|
| 42 |
+
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
| 43 |
+
Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
|
| 44 |
+
requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
|
| 45 |
+
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
| 46 |
+
Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left.
|
| 47 |
+
num_generations (`int` or `None`, *optional*, defaults to `8`):
|
| 48 |
+
Number of generations per prompt to sample. The global batch size (num_processes * per_device_batch_size)
|
| 49 |
+
must be divisible by this value.
|
| 50 |
+
max_completion_length (`int` or `None`, *optional*, defaults to `256`):
|
| 51 |
+
Maximum length of the generated completion.
|
| 52 |
+
ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
|
| 53 |
+
This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
|
| 54 |
+
improving generation speed. However, disabling this option allows training models that exceed the VRAM
|
| 55 |
+
capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible
|
| 56 |
+
with vLLM generation.
|
| 57 |
+
|
| 58 |
+
> Parameters that control generation
|
| 59 |
+
|
| 60 |
+
temperature (`float`, defaults to `0.9`):
|
| 61 |
+
Temperature for sampling. The higher the temperature, the more random the completions.
|
| 62 |
+
top_p (`float`, *optional*, defaults to `1.0`):
|
| 63 |
+
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
|
| 64 |
+
`1.0` to consider all tokens.
|
| 65 |
+
top_k (`int` or `None`, *optional*, defaults to `50`):
|
| 66 |
+
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
|
| 67 |
+
disabled.
|
| 68 |
+
min_p (`float` or `None`, *optional*, defaults to `None`):
|
| 69 |
+
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
|
| 70 |
+
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
|
| 71 |
+
repetition_penalty (`float`, *optional*, defaults to `1.0`):
|
| 72 |
+
Float that penalizes new tokens based on whether they appear in the prompt and the generated text so far.
|
| 73 |
+
Values > `1.0` encourage the model to use new tokens, while values < `1.0` encourage the model to repeat
|
| 74 |
+
tokens.
|
| 75 |
+
cache_implementation (`str` or `None`, *optional*, defaults to `None`):
|
| 76 |
+
Implementation of the cache method for faster generation when use_vllm is set to False.
|
| 77 |
+
|
| 78 |
+
> Parameters that control generation acceleration powered by vLLM
|
| 79 |
+
|
| 80 |
+
use_vllm (`bool`, *optional*, defaults to `False`):
|
| 81 |
+
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
| 82 |
+
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
| 83 |
+
vllm_device (`str`, *optional*, defaults to `"auto"`):
|
| 84 |
+
Device where vLLM generation will run, e.g. `"cuda:1"`. If set to `"auto"` (default), the system will
|
| 85 |
+
automatically select the next available GPU after the last one used for training. This assumes that
|
| 86 |
+
training has not already occupied all available GPUs. If only one device is available, the device will be
|
| 87 |
+
shared between both training and vLLM.
|
| 88 |
+
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.9`):
|
| 89 |
+
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the
|
| 90 |
+
device dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus
|
| 91 |
+
improve the model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors
|
| 92 |
+
during initialization.
|
| 93 |
+
vllm_dtype (`str`, *optional*, defaults to `"auto"`):
|
| 94 |
+
Data type to use for vLLM generation. If set to `"auto"`, the data type will be automatically determined
|
| 95 |
+
based on the model configuration. Find the supported values in the vLLM documentation.
|
| 96 |
+
vllm_max_model_len (`int` or `None`, *optional*, defaults to `None`):
|
| 97 |
+
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
|
| 98 |
+
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
|
| 99 |
+
context size, which might be much larger than the KV cache, leading to inefficiencies.
|
| 100 |
+
vllm_enable_prefix_caching (`bool`, *optional*, defaults to `True`):
|
| 101 |
+
Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and the hardware
|
| 102 |
+
support this feature.
|
| 103 |
+
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
|
| 104 |
+
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
|
| 105 |
+
|
| 106 |
+
> Parameters that control the training
|
| 107 |
+
|
| 108 |
+
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
| 109 |
+
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
| 110 |
+
[`~transformers.TrainingArguments`].
|
| 111 |
+
beta (`float`, *optional*, defaults to `0.04`):
|
| 112 |
+
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
|
| 113 |
+
speed, but may be numerically unstable for long training runs.
|
| 114 |
+
num_iterations (`int`, *optional*, defaults to `1`):
|
| 115 |
+
Number of iterations per batch (denoted as μ in the algorithm).
|
| 116 |
+
epsilon (`float`, *optional*, defaults to `0.2`):
|
| 117 |
+
Epsilon value for clipping.
|
| 118 |
+
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
|
| 119 |
+
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
|
| 120 |
+
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
|
| 121 |
+
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
|
| 122 |
+
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
|
| 123 |
+
weighted equally with weight `1.0`.
|
| 124 |
+
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
| 125 |
+
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
| 126 |
+
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
| 127 |
+
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
| 128 |
+
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
|
| 129 |
+
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
| 130 |
+
between the current policy and the previous reference policy during updates. The reference policy is
|
| 131 |
+
updated according to the equation: `π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you
|
| 132 |
+
must set `sync_ref_model=True`.
|
| 133 |
+
ref_model_sync_steps (`int`, *optional*, defaults to `512`):
|
| 134 |
+
τ parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which determines how
|
| 135 |
+
frequently the current policy is synchronized with the reference policy. To use this parameter, you must
|
| 136 |
+
set `sync_ref_model=True`.
|
| 137 |
+
|
| 138 |
+
> Parameters that control the logging
|
| 139 |
+
|
| 140 |
+
log_completions (`bool`, *optional*, defaults to `False`):
|
| 141 |
+
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is
|
| 142 |
+
installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`.
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
# Parameters that control the model and reference model
|
| 146 |
+
model_init_kwargs: Optional[dict] = field(
|
| 147 |
+
default=None,
|
| 148 |
+
metadata={
|
| 149 |
+
"help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
|
| 150 |
+
"argument of the `GRPOTrainer` is provided as a string."
|
| 151 |
+
},
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Parameters that control the data preprocessing
|
| 155 |
+
# The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
|
| 156 |
+
# additional columns to compute the reward
|
| 157 |
+
remove_unused_columns: Optional[bool] = field(
|
| 158 |
+
default=False,
|
| 159 |
+
metadata={
|
| 160 |
+
"help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
|
| 161 |
+
"that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
|
| 162 |
+
},
|
| 163 |
+
)
|
| 164 |
+
max_prompt_length: Optional[int] = field(
|
| 165 |
+
default=512,
|
| 166 |
+
metadata={
|
| 167 |
+
"help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
|
| 168 |
+
},
|
| 169 |
+
)
|
| 170 |
+
num_generations: Optional[int] = field(
|
| 171 |
+
default=8,
|
| 172 |
+
metadata={
|
| 173 |
+
"help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
|
| 174 |
+
"must be divisible by this value."
|
| 175 |
+
},
|
| 176 |
+
)
|
| 177 |
+
max_completion_length: Optional[int] = field(
|
| 178 |
+
default=800,
|
| 179 |
+
metadata={"help": "Maximum length of the generated completion."},
|
| 180 |
+
)
|
| 181 |
+
ds3_gather_for_generation: bool = field(
|
| 182 |
+
default=True,
|
| 183 |
+
metadata={
|
| 184 |
+
"help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
|
| 185 |
+
"generation, improving generation speed. However, disabling this option allows training models that "
|
| 186 |
+
"exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
|
| 187 |
+
"is not compatible with vLLM generation."
|
| 188 |
+
},
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
# Parameters that control generation
|
| 192 |
+
temperature: float = field(
|
| 193 |
+
default=0.6,
|
| 194 |
+
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
| 195 |
+
)
|
| 196 |
+
top_p: float = field(
|
| 197 |
+
default=0.95,
|
| 198 |
+
metadata={
|
| 199 |
+
"help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
|
| 200 |
+
"Set to 1.0 to consider all tokens."
|
| 201 |
+
},
|
| 202 |
+
)
|
| 203 |
+
top_k: Optional[int] = field(
|
| 204 |
+
default=20,
|
| 205 |
+
metadata={
|
| 206 |
+
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
|
| 207 |
+
"top-k-filtering is disabled."
|
| 208 |
+
},
|
| 209 |
+
)
|
| 210 |
+
min_p: Optional[float] = field(
|
| 211 |
+
default=None,
|
| 212 |
+
metadata={
|
| 213 |
+
"help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
|
| 214 |
+
"must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
|
| 215 |
+
},
|
| 216 |
+
)
|
| 217 |
+
repetition_penalty: float = field(
|
| 218 |
+
default=1.0,
|
| 219 |
+
metadata={
|
| 220 |
+
"help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
|
| 221 |
+
"text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
|
| 222 |
+
"to repeat tokens."
|
| 223 |
+
},
|
| 224 |
+
)
|
| 225 |
+
cache_implementation: Optional[str] = field(
|
| 226 |
+
default=None,
|
| 227 |
+
metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# Parameters that control generation acceleration powered by vLLM
|
| 231 |
+
use_vllm: Optional[bool] = field(
|
| 232 |
+
default=False,
|
| 233 |
+
metadata={
|
| 234 |
+
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept "
|
| 235 |
+
"unused for training, as vLLM will require one for generation. vLLM must be installed "
|
| 236 |
+
"(`pip install vllm`)."
|
| 237 |
+
},
|
| 238 |
+
)
|
| 239 |
+
vllm_device: Optional[str] = field(
|
| 240 |
+
default="auto",
|
| 241 |
+
metadata={
|
| 242 |
+
"help": "Device where vLLM generation will run, e.g. 'cuda:1'. If set to 'auto' (default), the system "
|
| 243 |
+
"will automatically select the next available GPU after the last one used for training. This assumes "
|
| 244 |
+
"that training has not already occupied all available GPUs."
|
| 245 |
+
},
|
| 246 |
+
)
|
| 247 |
+
vllm_gpu_memory_utilization: float = field(
|
| 248 |
+
default=0.9,
|
| 249 |
+
metadata={
|
| 250 |
+
"help": "Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV "
|
| 251 |
+
"cache on the device dedicated to generation powered by vLLM. Higher values will increase the KV cache "
|
| 252 |
+
"size and thus improve the model's throughput. However, if the value is too high, it may cause "
|
| 253 |
+
"out-of-memory (OOM) errors during initialization."
|
| 254 |
+
},
|
| 255 |
+
)
|
| 256 |
+
vllm_dtype: Optional[str] = field(
|
| 257 |
+
default="auto",
|
| 258 |
+
metadata={
|
| 259 |
+
"help": "Data type to use for vLLM generation. If set to 'auto', the data type will be automatically "
|
| 260 |
+
"determined based on the model configuration. Find the supported values in the vLLM documentation."
|
| 261 |
+
},
|
| 262 |
+
)
|
| 263 |
+
vllm_max_model_len: Optional[int] = field(
|
| 264 |
+
default=None,
|
| 265 |
+
metadata={
|
| 266 |
+
"help": "If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced "
|
| 267 |
+
"`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model "
|
| 268 |
+
"context size, which might be much larger than the KV cache, leading to inefficiencies."
|
| 269 |
+
},
|
| 270 |
+
)
|
| 271 |
+
vllm_enable_prefix_caching: Optional[bool] = field(
|
| 272 |
+
default=True,
|
| 273 |
+
metadata={
|
| 274 |
+
"help": "Whether to enable prefix caching in vLLM. If set to `True` (default), ensure that the model and "
|
| 275 |
+
"the hardware support this feature."
|
| 276 |
+
},
|
| 277 |
+
)
|
| 278 |
+
vllm_guided_decoding_regex: Optional[str] = field(
|
| 279 |
+
default=None,
|
| 280 |
+
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
# Parameters that control the training
|
| 284 |
+
learning_rate: float = field(
|
| 285 |
+
default=1e-6,
|
| 286 |
+
metadata={
|
| 287 |
+
"help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
|
| 288 |
+
"`transformers.TrainingArguments`."
|
| 289 |
+
},
|
| 290 |
+
)
|
| 291 |
+
beta: float = field(
|
| 292 |
+
default=0.04,
|
| 293 |
+
metadata={
|
| 294 |
+
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
|
| 295 |
+
"training speed, but may be numerically unstable for long training runs."
|
| 296 |
+
},
|
| 297 |
+
)
|
| 298 |
+
num_iterations: int = field(
|
| 299 |
+
default=1,
|
| 300 |
+
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
|
| 301 |
+
)
|
| 302 |
+
epsilon: float = field(
|
| 303 |
+
default=0.2,
|
| 304 |
+
metadata={"help": "Epsilon value for clipping."},
|
| 305 |
+
)
|
| 306 |
+
epsilon_high: Optional[float] = field(
|
| 307 |
+
default=None,
|
| 308 |
+
metadata={
|
| 309 |
+
"help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
|
| 310 |
+
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
|
| 311 |
+
},
|
| 312 |
+
)
|
| 313 |
+
reward_weights: Optional[list[float]] = field(
|
| 314 |
+
default=None,
|
| 315 |
+
metadata={
|
| 316 |
+
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
|
| 317 |
+
"rewards are weighted equally with weight `1.0`."
|
| 318 |
+
},
|
| 319 |
+
)
|
| 320 |
+
sync_ref_model: bool = field(
|
| 321 |
+
default=False,
|
| 322 |
+
metadata={
|
| 323 |
+
"help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
|
| 324 |
+
"steps, using the `ref_model_mixup_alpha` parameter."
|
| 325 |
+
},
|
| 326 |
+
)
|
| 327 |
+
ref_model_mixup_alpha: float = field(
|
| 328 |
+
default=0.6,
|
| 329 |
+
metadata={
|
| 330 |
+
"help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
|
| 331 |
+
"previous reference policy during updates. The reference policy is updated according to the equation: "
|
| 332 |
+
"`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
|
| 333 |
+
},
|
| 334 |
+
)
|
| 335 |
+
ref_model_sync_steps: int = field(
|
| 336 |
+
default=512,
|
| 337 |
+
metadata={
|
| 338 |
+
"help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
|
| 339 |
+
"synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
|
| 340 |
+
},
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# Parameters that control the logging
|
| 344 |
+
log_completions: bool = field(
|
| 345 |
+
default=True,
|
| 346 |
+
metadata={
|
| 347 |
+
"help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
|
| 348 |
+
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
|
| 349 |
+
},
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
report_to: Union[None, str, list[str]] = field(
|
| 353 |
+
default="wandb", metadata={"help": "The list of integrations to report the results and logs to."}
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
logging_first_step: bool = field(default=False, metadata={"help": "Log the first global_step"})
|
| 357 |
+
logging_steps: float = field(
|
| 358 |
+
default=2,
|
| 359 |
+
metadata={
|
| 360 |
+
"help": (
|
| 361 |
+
"Log every X updates steps. Should be an integer or a float in range `[0,1)`. "
|
| 362 |
+
"If smaller than 1, will be interpreted as ratio of total training steps."
|
| 363 |
+
)
|
| 364 |
+
},
|
| 365 |
+
)
|
BioReason/bioreason/trainer/grpo_trainer.py
ADDED
|
@@ -0,0 +1,905 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import textwrap
|
| 18 |
+
import pandas as pd
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from typing import Any, Callable, Optional, Union, Sized
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.data
|
| 24 |
+
import transformers
|
| 25 |
+
from datasets import Dataset, IterableDataset
|
| 26 |
+
from packaging import version
|
| 27 |
+
from transformers import (
|
| 28 |
+
AriaForConditionalGeneration,
|
| 29 |
+
AriaProcessor,
|
| 30 |
+
AutoModelForCausalLM,
|
| 31 |
+
AutoModelForSequenceClassification,
|
| 32 |
+
AutoProcessor,
|
| 33 |
+
AutoTokenizer,
|
| 34 |
+
GenerationConfig,
|
| 35 |
+
PreTrainedModel,
|
| 36 |
+
PreTrainedTokenizerBase,
|
| 37 |
+
Qwen2VLForConditionalGeneration,
|
| 38 |
+
Qwen2_5_VLForConditionalGeneration,
|
| 39 |
+
Trainer,
|
| 40 |
+
TrainerCallback,
|
| 41 |
+
is_wandb_available,
|
| 42 |
+
)
|
| 43 |
+
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
| 44 |
+
from transformers.utils import is_peft_available
|
| 45 |
+
|
| 46 |
+
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
| 47 |
+
from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
| 48 |
+
from trl.trainer.grpo_config import GRPOConfig
|
| 49 |
+
from trl.trainer.utils import generate_model_card, get_comet_experiment_url
|
| 50 |
+
# from trl import GRPOTrainer
|
| 51 |
+
|
| 52 |
+
from accelerate.utils import is_peft_model, set_seed, gather_object
|
| 53 |
+
import PIL.Image
|
| 54 |
+
|
| 55 |
+
import copy
|
| 56 |
+
from torch.utils.data import Sampler
|
| 57 |
+
import warnings
|
| 58 |
+
|
| 59 |
+
if is_peft_available():
|
| 60 |
+
from peft import PeftConfig, get_peft_model, prepare_model_for_kbit_training
|
| 61 |
+
|
| 62 |
+
if is_wandb_available():
|
| 63 |
+
import wandb
|
| 64 |
+
|
| 65 |
+
from bioreason.dna_modules.dna_module import DNABaseModule
|
| 66 |
+
from bioreason.trainer import DNALLMGRPOConfig
|
| 67 |
+
# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
|
| 68 |
+
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
|
| 69 |
+
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RepeatRandomSampler(Sampler):
|
| 73 |
+
"""
|
| 74 |
+
Sampler that repeats the indices of a dataset in a structured manner.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
data_source (`Sized`):
|
| 78 |
+
Dataset to sample from.
|
| 79 |
+
mini_repeat_count (`int`):
|
| 80 |
+
Number of times to repeat each index per batch.
|
| 81 |
+
batch_size (`int`, *optional*, defaults to `1`):
|
| 82 |
+
Number of unique indices per batch.
|
| 83 |
+
repeat_count (`int`, *optional*, defaults to `1`):
|
| 84 |
+
Number of times to repeat the full sampling process.
|
| 85 |
+
seed (`int` or `None`, *optional*, defaults to `None`):
|
| 86 |
+
Random seed for reproducibility.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
def __init__(
|
| 90 |
+
self,
|
| 91 |
+
data_source: Sized,
|
| 92 |
+
mini_repeat_count: int,
|
| 93 |
+
batch_size: int = 1,
|
| 94 |
+
repeat_count: int = 1,
|
| 95 |
+
seed: Optional[int] = None,
|
| 96 |
+
):
|
| 97 |
+
self.data_source = data_source
|
| 98 |
+
self.mini_repeat_count = mini_repeat_count
|
| 99 |
+
self.batch_size = batch_size
|
| 100 |
+
self.repeat_count = repeat_count
|
| 101 |
+
self.num_samples = len(data_source)
|
| 102 |
+
self.seed = seed
|
| 103 |
+
self.generator = torch.Generator()
|
| 104 |
+
if seed is not None:
|
| 105 |
+
self.generator.manual_seed(seed)
|
| 106 |
+
|
| 107 |
+
def __iter__(self):
|
| 108 |
+
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
| 109 |
+
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
| 110 |
+
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
| 111 |
+
|
| 112 |
+
for chunk in indexes:
|
| 113 |
+
for _ in range(self.repeat_count):
|
| 114 |
+
for index in chunk:
|
| 115 |
+
for _ in range(self.mini_repeat_count):
|
| 116 |
+
yield index
|
| 117 |
+
|
| 118 |
+
def __len__(self) -> int:
|
| 119 |
+
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class DNALLMGRPOTrainer(Trainer):
|
| 123 |
+
"""
|
| 124 |
+
Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
|
| 125 |
+
paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
|
| 126 |
+
|
| 127 |
+
Example:
|
| 128 |
+
|
| 129 |
+
```python
|
| 130 |
+
from datasets import load_dataset
|
| 131 |
+
from trl import GRPOTrainer
|
| 132 |
+
|
| 133 |
+
dataset = load_dataset("trl-lib/tldr", split="train")
|
| 134 |
+
|
| 135 |
+
trainer = GRPOTrainer(
|
| 136 |
+
model="Qwen/Qwen2-0.5B-Instruct",
|
| 137 |
+
reward_funcs="weqweasdas/RM-Gemma-2B",
|
| 138 |
+
train_dataset=dataset,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
trainer.train()
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Args:
|
| 145 |
+
model (`Union[str, PreTrainedModel]`):
|
| 146 |
+
Model to be trained. Can be either:
|
| 147 |
+
|
| 148 |
+
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
| 149 |
+
a path to a *directory* containing model weights saved using
|
| 150 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
| 151 |
+
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments
|
| 152 |
+
in `args.model_init_kwargs`.
|
| 153 |
+
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
| 154 |
+
reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
|
| 155 |
+
Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
|
| 156 |
+
functions with the prompts and completions and sum the rewards. Can be either:
|
| 157 |
+
|
| 158 |
+
- A single reward function, such as:
|
| 159 |
+
- A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
|
| 160 |
+
path to a *directory* containing model weights saved using
|
| 161 |
+
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
|
| 162 |
+
using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
|
| 163 |
+
keyword arguments in `args.model_init_kwargs`.
|
| 164 |
+
- A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
|
| 165 |
+
- A custom reward function: The function is provided with the prompts and the generated completions,
|
| 166 |
+
plus any additional columns in the dataset. It should return a list of rewards. For more details, see
|
| 167 |
+
[Using a custom reward function](#using-a-custom-reward-function).
|
| 168 |
+
- A list of reward functions, where each item can independently be any of the above types. Mixing different
|
| 169 |
+
types within the list (e.g., a string model ID and a custom reward function) is allowed.
|
| 170 |
+
args ([`GRPOConfig`], *optional*, defaults to `None`):
|
| 171 |
+
Configuration for this trainer. If `None`, a default configuration is used.
|
| 172 |
+
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
|
| 173 |
+
Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is
|
| 174 |
+
ignored. The format of the samples can be either:
|
| 175 |
+
|
| 176 |
+
- [Standard](dataset_formats#standard): Each sample contains plain text.
|
| 177 |
+
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
|
| 178 |
+
and content).
|
| 179 |
+
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
|
| 180 |
+
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
|
| 181 |
+
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
| 182 |
+
Processing class used to process the data. The padding side must be set to "left". If `None`, the
|
| 183 |
+
processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
|
| 184 |
+
reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
|
| 185 |
+
Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
|
| 186 |
+
|
| 187 |
+
- A single processing class: Used when `reward_funcs` contains only one reward function.
|
| 188 |
+
- A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
|
| 189 |
+
If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
|
| 190 |
+
`None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
|
| 191 |
+
For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
|
| 192 |
+
the corresponding entries in `reward_processing_classes` are ignored.
|
| 193 |
+
callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
|
| 194 |
+
List of callbacks to customize the training loop. Will add those to the list of default callbacks
|
| 195 |
+
detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
|
| 196 |
+
|
| 197 |
+
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
|
| 198 |
+
method.
|
| 199 |
+
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
|
| 200 |
+
A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
|
| 201 |
+
model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
|
| 202 |
+
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
|
| 203 |
+
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(
|
| 207 |
+
self,
|
| 208 |
+
model: Union[str, PreTrainedModel],
|
| 209 |
+
reward_funcs: Union[RewardFunc, list[RewardFunc]],
|
| 210 |
+
args: DNALLMGRPOConfig = None,
|
| 211 |
+
dna_module: DNABaseModule = None,
|
| 212 |
+
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
| 213 |
+
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
|
| 214 |
+
processing_class: Optional[PreTrainedTokenizerBase] = None,
|
| 215 |
+
reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
|
| 216 |
+
callbacks: Optional[list[TrainerCallback]] = None,
|
| 217 |
+
optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
|
| 218 |
+
peft_config: Optional["PeftConfig"] = None,
|
| 219 |
+
freeze_dna_modules: Optional[bool] = False,
|
| 220 |
+
attn_implementation: str = "flash_attention_2",
|
| 221 |
+
torch_dtype: str = "bfloat16",
|
| 222 |
+
**kwargs,
|
| 223 |
+
):
|
| 224 |
+
# Args
|
| 225 |
+
if args is None:
|
| 226 |
+
model_name = model if isinstance(model, str) else model.config._name_or_path
|
| 227 |
+
model_name = model_name.split("/")[-1]
|
| 228 |
+
args = GRPOConfig(f"{model_name}-GRPO")
|
| 229 |
+
|
| 230 |
+
self.dna_module = dna_module
|
| 231 |
+
|
| 232 |
+
# Models
|
| 233 |
+
# Trained model
|
| 234 |
+
model_init_kwargs = args.model_init_kwargs or {}
|
| 235 |
+
# FIXME
|
| 236 |
+
# Remember to modify it in the invernvl
|
| 237 |
+
model_init_kwargs["attn_implementation"] = attn_implementation
|
| 238 |
+
if model_init_kwargs.get("torch_dtype") is None:
|
| 239 |
+
model_init_kwargs["torch_dtype"] = torch_dtype
|
| 240 |
+
|
| 241 |
+
assert not isinstance(model, str), "model must NOT be a string in the current implementation"
|
| 242 |
+
|
| 243 |
+
torch_dtype = model_init_kwargs.get("torch_dtype")
|
| 244 |
+
if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
|
| 245 |
+
pass # torch_dtype is already a torch.dtype or "auto" or None
|
| 246 |
+
elif isinstance(torch_dtype, str): # it's a str, but not "auto"
|
| 247 |
+
torch_dtype = getattr(torch, torch_dtype)
|
| 248 |
+
else:
|
| 249 |
+
raise ValueError(
|
| 250 |
+
"Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
|
| 251 |
+
f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
|
| 252 |
+
)
|
| 253 |
+
# Disable caching if gradient checkpointing is enabled (not supported)
|
| 254 |
+
model_init_kwargs["use_cache"] = (
|
| 255 |
+
False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
# LoRA
|
| 259 |
+
self.dna_modules_keywords = self.dna_module.get_dnallm_modules_keywords()
|
| 260 |
+
if peft_config is not None:
|
| 261 |
+
print("Applying LoRA...")
|
| 262 |
+
def find_all_linear_names(model, multimodal_keywords):
|
| 263 |
+
cls = torch.nn.Linear
|
| 264 |
+
lora_module_names = set()
|
| 265 |
+
for name, module in model.named_modules():
|
| 266 |
+
print('name:', name, 'module:', module)
|
| 267 |
+
# LoRA is not applied to the DNA modules
|
| 268 |
+
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
|
| 269 |
+
continue
|
| 270 |
+
if isinstance(module, cls):
|
| 271 |
+
lora_module_names.add(name)
|
| 272 |
+
for m in lora_module_names: # needed for 16-bit
|
| 273 |
+
if "embed_tokens" in m:
|
| 274 |
+
lora_module_names.remove(m)
|
| 275 |
+
return list(lora_module_names)
|
| 276 |
+
target_modules = find_all_linear_names(model, self.dna_modules_keywords)
|
| 277 |
+
peft_config.target_modules = target_modules
|
| 278 |
+
model = prepare_model_for_kbit_training(model)
|
| 279 |
+
model = get_peft_model(model, peft_config)
|
| 280 |
+
|
| 281 |
+
# Freeze DNA modules
|
| 282 |
+
if freeze_dna_modules:
|
| 283 |
+
print("Freezing DNA modules...")
|
| 284 |
+
for p in model.dna_model.parameters():
|
| 285 |
+
p.requires_grad = False
|
| 286 |
+
|
| 287 |
+
# Make projection layer trainable
|
| 288 |
+
for p in model.dna_projection.parameters():
|
| 289 |
+
p.required_grad = True
|
| 290 |
+
|
| 291 |
+
# Compute the number of trainable parameters and print the parameter that is trainable
|
| 292 |
+
trainable_params = [p for p in model.parameters() if p.requires_grad]
|
| 293 |
+
total_params = sum(p.numel() for p in trainable_params)
|
| 294 |
+
# for n, p in model.named_parameters():
|
| 295 |
+
# if p.requires_grad:
|
| 296 |
+
# print(n, p.shape)
|
| 297 |
+
print(f"Total trainable parameters: {total_params}")
|
| 298 |
+
|
| 299 |
+
# Enable gradient checkpointing if requested
|
| 300 |
+
if args.gradient_checkpointing:
|
| 301 |
+
model = self._enable_gradient_checkpointing(model, args)
|
| 302 |
+
|
| 303 |
+
# Reference model
|
| 304 |
+
self.beta = args.beta
|
| 305 |
+
if self.beta == 0.0:
|
| 306 |
+
# If beta is 0.0, the reference model is not needed
|
| 307 |
+
self.ref_model = None
|
| 308 |
+
elif is_deepspeed_zero3_enabled():
|
| 309 |
+
self.ref_model = model_cls.from_pretrained(model_id, **model_init_kwargs)
|
| 310 |
+
elif is_peft_model(model):
|
| 311 |
+
# If PEFT is used, the reference model is not needed since the adapter can be disabled
|
| 312 |
+
# to revert to the initial model.
|
| 313 |
+
self.ref_model = None
|
| 314 |
+
else:
|
| 315 |
+
# If PEFT configuration is not provided, create a reference model based on the initial model.
|
| 316 |
+
self.ref_model = create_reference_model(model)
|
| 317 |
+
|
| 318 |
+
# Processing class
|
| 319 |
+
if processing_class is None:
|
| 320 |
+
processing_cls = self.dna_module.get_processing_class()
|
| 321 |
+
|
| 322 |
+
#if isinstance(model.text_model)
|
| 323 |
+
processing_class = processing_cls(tokenizer=model.text_tokenizer, dna_tokenizer=model.dna_tokenizer)
|
| 324 |
+
# print(model.tokenizer.chat_template)
|
| 325 |
+
for component, processing_keyword in self.dna_module.get_custom_processing_keywords():
|
| 326 |
+
if processing_keyword in kwargs:
|
| 327 |
+
# If we cannot find component in processing_class, return the processing_class itself
|
| 328 |
+
processing_component = getattr(processing_class, component, processing_class)
|
| 329 |
+
setattr(processing_component, processing_keyword, kwargs[processing_keyword])
|
| 330 |
+
if getattr(processing_class, "tokenizer", None) is not None:
|
| 331 |
+
pad_token_id = processing_class.tokenizer.pad_token_id
|
| 332 |
+
processing_class.pad_token_id = pad_token_id
|
| 333 |
+
processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
|
| 334 |
+
else:
|
| 335 |
+
assert isinstance(processing_class, PreTrainedTokenizerBase), "processing_class must be an instance of PreTrainedTokenizerBase if it has no tokenizer attribute"
|
| 336 |
+
pad_token_id = processing_class.pad_token_id
|
| 337 |
+
|
| 338 |
+
self.dna_module.post_model_init(model, processing_class)
|
| 339 |
+
self.dna_module.post_model_init(self.ref_model, processing_class)
|
| 340 |
+
|
| 341 |
+
# Reward functions
|
| 342 |
+
if not isinstance(reward_funcs, list):
|
| 343 |
+
reward_funcs = [reward_funcs]
|
| 344 |
+
for i, reward_func in enumerate(reward_funcs):
|
| 345 |
+
if isinstance(reward_func, str):
|
| 346 |
+
reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
|
| 347 |
+
reward_func, num_labels=1, **model_init_kwargs
|
| 348 |
+
)
|
| 349 |
+
self.reward_funcs = reward_funcs
|
| 350 |
+
|
| 351 |
+
# Reward processing class
|
| 352 |
+
if reward_processing_classes is None:
|
| 353 |
+
reward_processing_classes = [None] * len(reward_funcs)
|
| 354 |
+
elif not isinstance(reward_processing_classes, list):
|
| 355 |
+
reward_processing_classes = [reward_processing_classes]
|
| 356 |
+
else:
|
| 357 |
+
if len(reward_processing_classes) != len(reward_funcs):
|
| 358 |
+
raise ValueError("The number of reward processing classes must match the number of reward functions.")
|
| 359 |
+
|
| 360 |
+
for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
|
| 361 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 362 |
+
if reward_processing_class is None:
|
| 363 |
+
reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
|
| 364 |
+
if reward_processing_class.pad_token_id is None:
|
| 365 |
+
reward_processing_class.pad_token = reward_processing_class.eos_token
|
| 366 |
+
# The reward model computes the reward for the latest non-padded token in the input sequence.
|
| 367 |
+
# So it's important to set the pad token ID to the padding token ID of the processing class.
|
| 368 |
+
reward_func.config.pad_token_id = reward_processing_class.pad_token_id
|
| 369 |
+
reward_processing_classes[i] = reward_processing_class
|
| 370 |
+
self.reward_processing_classes = reward_processing_classes
|
| 371 |
+
|
| 372 |
+
# Data collator
|
| 373 |
+
def data_collator(features): # No data collation is needed in GRPO
|
| 374 |
+
return features
|
| 375 |
+
|
| 376 |
+
# Training arguments
|
| 377 |
+
self.max_prompt_length = args.max_prompt_length
|
| 378 |
+
self.max_prompt_length = None
|
| 379 |
+
if args.max_prompt_length is not None:
|
| 380 |
+
warnings.warn("Setting max_prompt_length is currently not supported, it has been set to None")
|
| 381 |
+
|
| 382 |
+
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
|
| 383 |
+
self.num_generations = args.num_generations # = G in the GRPO paper
|
| 384 |
+
self.generation_config = GenerationConfig(
|
| 385 |
+
max_new_tokens=self.max_completion_length,
|
| 386 |
+
do_sample=True,
|
| 387 |
+
temperature=0.6,
|
| 388 |
+
top_p=0.95,
|
| 389 |
+
top_k=20,
|
| 390 |
+
pad_token_id=pad_token_id,
|
| 391 |
+
)
|
| 392 |
+
if hasattr(self.dna_module, "get_eos_token_id"): # For InternVL
|
| 393 |
+
self.generation_config.eos_token_id = self.dna_module.get_eos_token_id(processing_class)
|
| 394 |
+
self.beta = args.beta
|
| 395 |
+
self.epsilon_low = args.epsilon
|
| 396 |
+
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
|
| 397 |
+
|
| 398 |
+
# Multi-step
|
| 399 |
+
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
| 400 |
+
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle
|
| 401 |
+
self._step = 0
|
| 402 |
+
# Buffer the batch to reuse generated outputs across multiple updates
|
| 403 |
+
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
| 404 |
+
|
| 405 |
+
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
| 406 |
+
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
| 407 |
+
# "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
|
| 408 |
+
# "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
|
| 409 |
+
# suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
|
| 410 |
+
# This acts as a flag to indicate that the warning has already been issued.
|
| 411 |
+
model.warnings_issued["estimate_tokens"] = True
|
| 412 |
+
|
| 413 |
+
# Initialize the metrics
|
| 414 |
+
self._metrics = defaultdict(list)
|
| 415 |
+
self.log_completions = args.log_completions
|
| 416 |
+
|
| 417 |
+
super().__init__(
|
| 418 |
+
model=model,
|
| 419 |
+
args=args,
|
| 420 |
+
data_collator=data_collator,
|
| 421 |
+
train_dataset=train_dataset,
|
| 422 |
+
eval_dataset=eval_dataset,
|
| 423 |
+
processing_class=processing_class,
|
| 424 |
+
callbacks=callbacks,
|
| 425 |
+
optimizers=optimizers,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
# Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
|
| 429 |
+
num_processes = self.accelerator.num_processes
|
| 430 |
+
global_batch_size = args.per_device_train_batch_size * num_processes
|
| 431 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 432 |
+
if self.num_generations not in possible_values:
|
| 433 |
+
raise ValueError(
|
| 434 |
+
f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
|
| 435 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
|
| 436 |
+
f"batch size, the valid values for the number of generations are: {possible_values}."
|
| 437 |
+
)
|
| 438 |
+
if self.args.eval_strategy != "no":
|
| 439 |
+
global_batch_size = args.per_device_eval_batch_size * num_processes
|
| 440 |
+
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
|
| 441 |
+
if self.num_generations not in possible_values:
|
| 442 |
+
raise ValueError(
|
| 443 |
+
f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
|
| 444 |
+
f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
|
| 445 |
+
f"eval batch size, the valid values for the number of generations are: {possible_values}."
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
|
| 449 |
+
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
|
| 450 |
+
# it's safer to set it in all cases.
|
| 451 |
+
set_seed(args.seed, device_specific=True)
|
| 452 |
+
|
| 453 |
+
# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
|
| 454 |
+
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
|
| 455 |
+
# self.model_accepts_loss_kwargs to False to enable scaling.
|
| 456 |
+
self.model_accepts_loss_kwargs = False
|
| 457 |
+
|
| 458 |
+
if self.ref_model is not None:
|
| 459 |
+
# if self.is_deepspeed_enabled:
|
| 460 |
+
if is_deepspeed_zero3_enabled():
|
| 461 |
+
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
| 462 |
+
else:
|
| 463 |
+
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
| 464 |
+
|
| 465 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 466 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 467 |
+
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
| 468 |
+
|
| 469 |
+
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
|
| 470 |
+
"""Enables gradient checkpointing for the model."""
|
| 471 |
+
# Ensure use_cache is disabled
|
| 472 |
+
model.config.use_cache = False
|
| 473 |
+
|
| 474 |
+
# Enable gradient checkpointing on the base model for PEFT
|
| 475 |
+
if is_peft_model(model):
|
| 476 |
+
model.base_model.gradient_checkpointing_enable()
|
| 477 |
+
# Enable gradient checkpointing for non-PEFT models
|
| 478 |
+
else:
|
| 479 |
+
if getattr(model, "language_model", None) is not None:
|
| 480 |
+
# For InternVL; these operations are copied from the original training script of InternVL
|
| 481 |
+
model.language_model.config.use_cache = False
|
| 482 |
+
model.dna_model.gradient_checkpointing = True
|
| 483 |
+
model.dna_model.encoder.gradient_checkpointing = True
|
| 484 |
+
model.language_model._set_gradient_checkpointing()
|
| 485 |
+
# This line is necessary, otherwise the `model.gradient_checkpointing_enable()` will be executed during the training process, leading to an error since InternVL does not support this operation.
|
| 486 |
+
args.gradient_checkpointing = False
|
| 487 |
+
else:
|
| 488 |
+
model.gradient_checkpointing_enable()
|
| 489 |
+
|
| 490 |
+
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
|
| 491 |
+
use_reentrant = (
|
| 492 |
+
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
if use_reentrant:
|
| 496 |
+
model.enable_input_require_grads()
|
| 497 |
+
|
| 498 |
+
return model
|
| 499 |
+
|
| 500 |
+
def _set_signature_columns_if_needed(self):
|
| 501 |
+
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
| 502 |
+
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
| 503 |
+
# In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
|
| 504 |
+
# Instead, we set them to the columns expected by the `training_step` method, hence the override.
|
| 505 |
+
if self._signature_columns is None:
|
| 506 |
+
self._signature_columns = ["prompt"]
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# Get the per-token log probabilities for the completions for the model and the reference model
|
| 510 |
+
def _get_per_token_logps(self, model, input_ids, attention_mask, **custom_multimodal_inputs):
|
| 511 |
+
logits = model(input_ids=input_ids, attention_mask=attention_mask, **custom_multimodal_inputs).logits # (B, L, V)
|
| 512 |
+
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
| 513 |
+
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
|
| 514 |
+
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
|
| 515 |
+
per_token_logps = []
|
| 516 |
+
for logits_row, input_ids_row in zip(logits, input_ids):
|
| 517 |
+
log_probs = logits_row.log_softmax(dim=-1)
|
| 518 |
+
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
|
| 519 |
+
per_token_logps.append(token_log_prob)
|
| 520 |
+
return torch.stack(per_token_logps)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def _prepare_inputs(self, inputs):
|
| 524 |
+
# Simple pass-through, just like original
|
| 525 |
+
return inputs
|
| 526 |
+
|
| 527 |
+
def _get_key_from_inputs(self, x, key):
|
| 528 |
+
ele = x.get(key, None)
|
| 529 |
+
assert ele is not None, f"The key {key} is not found in the input"
|
| 530 |
+
if isinstance(ele, list):
|
| 531 |
+
return [e for e in ele]
|
| 532 |
+
else:
|
| 533 |
+
return [ele]
|
| 534 |
+
|
| 535 |
+
def _generate_and_score_completions(self, inputs: dict[str, Union[torch.Tensor, Any]], model) -> dict[str, Union[torch.Tensor, Any]]:
|
| 536 |
+
device = self.accelerator.device
|
| 537 |
+
prompts = [x["prompt"] for x in inputs]
|
| 538 |
+
prompts_text = self.dna_module.prepare_prompt(self.processing_class, inputs)
|
| 539 |
+
# Handle both pre-loaded images and image paths
|
| 540 |
+
batch_dna_sequences = []
|
| 541 |
+
print("_generate_and_score_completions (GRPO):")
|
| 542 |
+
for x in inputs:
|
| 543 |
+
#print('---')
|
| 544 |
+
#print(x)
|
| 545 |
+
if 'dna_sequences' in x:
|
| 546 |
+
dnas = self._get_key_from_inputs(x, "dna_sequences")
|
| 547 |
+
|
| 548 |
+
for dna in dnas:
|
| 549 |
+
# clean if desired
|
| 550 |
+
pass
|
| 551 |
+
batch_dna_sequences.append(dnas)
|
| 552 |
+
# NOTE: typically appends dna, so dna_sequences is all the dna in one list
|
| 553 |
+
# odd. trying this instead
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
prompt_inputs = self.dna_module.prepare_model_inputs(
|
| 557 |
+
self.processing_class,
|
| 558 |
+
model,
|
| 559 |
+
prompts_text,
|
| 560 |
+
batch_dna_sequences,
|
| 561 |
+
return_tensors="pt",
|
| 562 |
+
padding=True,
|
| 563 |
+
padding_side="left",
|
| 564 |
+
add_special_tokens=False,
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
| 568 |
+
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
| 569 |
+
|
| 570 |
+
# max_prompt_length is not supported yet
|
| 571 |
+
# if self.max_prompt_length is not None:
|
| 572 |
+
# prompt_ids = prompt_ids[:, -self.max_prompt_length :]
|
| 573 |
+
# prompt_inputs["input_ids"] = prompt_ids
|
| 574 |
+
# prompt_mask = prompt_mask[:, -self.max_prompt_length :]
|
| 575 |
+
# prompt_inputs["attention_mask"] = prompt_mask
|
| 576 |
+
|
| 577 |
+
# Generate completions
|
| 578 |
+
start = time.time()
|
| 579 |
+
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
| 580 |
+
kwargs = {k: v for k, v in prompt_inputs.items() if k not in self.dna_module.get_non_generate_params()}
|
| 581 |
+
generate_returned_result = unwrapped_model.generate(
|
| 582 |
+
**kwargs,
|
| 583 |
+
generation_config=self.generation_config
|
| 584 |
+
)
|
| 585 |
+
end = time.time()
|
| 586 |
+
print(f"Generation time: {end - start:.9f} seconds")
|
| 587 |
+
prompt_length = prompt_ids.size(1)
|
| 588 |
+
if not self.dna_module.is_embeds_input():
|
| 589 |
+
prompt_completion_ids = generate_returned_result
|
| 590 |
+
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
| 591 |
+
completion_ids = prompt_completion_ids[:, prompt_length:]
|
| 592 |
+
else:
|
| 593 |
+
# In this case, the input of the LLM backbone is the embedding of the combination of the image and text prompt
|
| 594 |
+
# So the returned result of the `generate` method only contains the completion ids
|
| 595 |
+
completion_ids = generate_returned_result
|
| 596 |
+
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 597 |
+
|
| 598 |
+
# Mask everything after the first EOS token
|
| 599 |
+
# print('completion:', completion_ids)
|
| 600 |
+
# print('generate_returned_result', generate_returned_result, generate_returned_result.shape)
|
| 601 |
+
# print('prompt_inputs["input_ids"]', prompt_inputs["input_ids"], prompt_inputs["input_ids"].shape)
|
| 602 |
+
# print('prompt_ids', prompt_ids, prompt_ids.shape)
|
| 603 |
+
# print('prompt_length', prompt_length)
|
| 604 |
+
# print('prompt_completion_ids', prompt_completion_ids, prompt_completion_ids.shape)
|
| 605 |
+
is_eos = completion_ids == self.processing_class.eos_token_id
|
| 606 |
+
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
|
| 607 |
+
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
|
| 608 |
+
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
|
| 609 |
+
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
|
| 610 |
+
|
| 611 |
+
# Concatenate prompt_mask with completion_mask for logit computation
|
| 612 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
|
| 613 |
+
|
| 614 |
+
# Get the multimodal inputs
|
| 615 |
+
multimodal_keywords = self.dna_module.get_custom_multimodal_keywords()
|
| 616 |
+
multimodal_inputs = {k: prompt_inputs[k] if k in prompt_inputs else None for k in multimodal_keywords}
|
| 617 |
+
with torch.no_grad():
|
| 618 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
|
| 619 |
+
# computation here, and use per_token_logps.detach() instead.
|
| 620 |
+
if self.num_iterations > 1:
|
| 621 |
+
old_per_token_logps = self._get_per_token_logps(
|
| 622 |
+
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 623 |
+
)
|
| 624 |
+
old_per_token_logps = old_per_token_logps[:, prompt_length - 1:]
|
| 625 |
+
else:
|
| 626 |
+
old_per_token_logps = None
|
| 627 |
+
|
| 628 |
+
if self.beta == 0.0:
|
| 629 |
+
ref_per_token_logps = None
|
| 630 |
+
elif self.ref_model is not None:
|
| 631 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 632 |
+
self.ref_model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 633 |
+
)
|
| 634 |
+
else:
|
| 635 |
+
with self.accelerator.unwrap_model(model).disable_adapter():
|
| 636 |
+
ref_per_token_logps = self._get_per_token_logps(
|
| 637 |
+
model, prompt_completion_ids, attention_mask, **multimodal_inputs
|
| 638 |
+
)
|
| 639 |
+
if ref_per_token_logps is not None:
|
| 640 |
+
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
|
| 641 |
+
|
| 642 |
+
# Decode the generated completions
|
| 643 |
+
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
| 644 |
+
if is_conversational(inputs[0]):
|
| 645 |
+
completions = [[{"role": "assistant", "content": completion}] for completion in completions_text]
|
| 646 |
+
else:
|
| 647 |
+
completions = completions_text
|
| 648 |
+
# Compute the rewards
|
| 649 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 650 |
+
print("Reward calculation...")
|
| 651 |
+
rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
|
| 652 |
+
for i, (reward_func, reward_processing_class) in enumerate(
|
| 653 |
+
zip(self.reward_funcs, self.reward_processing_classes)
|
| 654 |
+
):
|
| 655 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 656 |
+
if is_conversational(inputs[0]):
|
| 657 |
+
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
|
| 658 |
+
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
|
| 659 |
+
else:
|
| 660 |
+
texts = [p + c for p, c in zip(prompts, completions)]
|
| 661 |
+
reward_inputs = reward_processing_class(
|
| 662 |
+
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
|
| 663 |
+
)
|
| 664 |
+
reward_inputs = super()._prepare_inputs(reward_inputs)
|
| 665 |
+
with torch.inference_mode():
|
| 666 |
+
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
|
| 667 |
+
else:
|
| 668 |
+
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
|
| 669 |
+
reward_kwargs = {key: [] for key in inputs[0].keys() if key not in ["prompt", "completion"]}
|
| 670 |
+
for key in reward_kwargs:
|
| 671 |
+
for example in inputs:
|
| 672 |
+
# No need to duplicate prompts as we're not generating multiple completions per prompt
|
| 673 |
+
# reward_kwargs[key].extend([example[key]] * self.num_generations)
|
| 674 |
+
reward_kwargs[key].extend([example[key]])
|
| 675 |
+
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
|
| 676 |
+
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
|
| 677 |
+
|
| 678 |
+
# Gather rewards across processes
|
| 679 |
+
rewards_per_func = self.accelerator.gather(rewards_per_func)
|
| 680 |
+
|
| 681 |
+
# Sum the rewards from all reward functions
|
| 682 |
+
rewards = rewards_per_func.sum(dim=1)
|
| 683 |
+
|
| 684 |
+
# Compute grouped-wise rewards
|
| 685 |
+
# Each group consists of num_generations completions for the same prompt
|
| 686 |
+
mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
|
| 687 |
+
std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
|
| 688 |
+
|
| 689 |
+
# Normalize the rewards to compute the advantages
|
| 690 |
+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 691 |
+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
| 692 |
+
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
| 693 |
+
|
| 694 |
+
# Get only the local slice of advantages
|
| 695 |
+
process_slice = slice(
|
| 696 |
+
self.accelerator.process_index * len(prompts),
|
| 697 |
+
(self.accelerator.process_index + 1) * len(prompts),
|
| 698 |
+
)
|
| 699 |
+
advantages = advantages[process_slice]
|
| 700 |
+
|
| 701 |
+
# Log the metrics
|
| 702 |
+
print("Logging metrics...")
|
| 703 |
+
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
|
| 704 |
+
self._metrics["completion_length"].append(completion_length)
|
| 705 |
+
|
| 706 |
+
reward_per_func = self.accelerator.gather_for_metrics(rewards_per_func).mean(0)
|
| 707 |
+
for i, reward_func in enumerate(self.reward_funcs):
|
| 708 |
+
if isinstance(reward_func, PreTrainedModel):
|
| 709 |
+
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
|
| 710 |
+
else:
|
| 711 |
+
reward_func_name = reward_func.__name__
|
| 712 |
+
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
|
| 713 |
+
|
| 714 |
+
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
| 715 |
+
|
| 716 |
+
self._metrics["reward_std"].append(self.accelerator.gather_for_metrics(std_grouped_rewards).mean().item())
|
| 717 |
+
|
| 718 |
+
print(self.log_completions, self.state.global_step, self.args.logging_steps, self.args.report_to)
|
| 719 |
+
if (
|
| 720 |
+
self.log_completions
|
| 721 |
+
and self.state.global_step % self.args.logging_steps == 0
|
| 722 |
+
and "wandb" in self.args.report_to
|
| 723 |
+
):
|
| 724 |
+
timestamp = time.time()
|
| 725 |
+
|
| 726 |
+
# Get the length of one of the other arrays
|
| 727 |
+
num_items = len(gather_object(prompts_text))
|
| 728 |
+
|
| 729 |
+
table = {
|
| 730 |
+
"step": [f"{self.state.global_step}_{timestamp}"] * num_items, # Repeat to match length
|
| 731 |
+
"prompt": gather_object(prompts_text),
|
| 732 |
+
"completion": gather_object(completions_text),
|
| 733 |
+
"reward": rewards.tolist(),
|
| 734 |
+
}
|
| 735 |
+
df = pd.DataFrame(table)
|
| 736 |
+
|
| 737 |
+
if wandb.run is not None and self.accelerator.is_main_process:
|
| 738 |
+
wandb.log({f"completions_{self.state.global_step}_{timestamp}": wandb.Table(dataframe=df)})
|
| 739 |
+
|
| 740 |
+
return {
|
| 741 |
+
"prompt_ids": prompt_ids,
|
| 742 |
+
"prompt_mask": prompt_mask,
|
| 743 |
+
"completion_ids": completion_ids,
|
| 744 |
+
"completion_mask": completion_mask,
|
| 745 |
+
"old_per_token_logps": old_per_token_logps,
|
| 746 |
+
"ref_per_token_logps": ref_per_token_logps,
|
| 747 |
+
"advantages": advantages,
|
| 748 |
+
"multimodal_inputs": multimodal_inputs
|
| 749 |
+
}
|
| 750 |
+
|
| 751 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 752 |
+
if return_outputs:
|
| 753 |
+
raise ValueError("The GRPOTrainer does not support returning outputs")
|
| 754 |
+
|
| 755 |
+
# Check if we need to generate new completions or use buffered ones
|
| 756 |
+
print("index 1")
|
| 757 |
+
if self.state.global_step % self.num_iterations == 0:
|
| 758 |
+
inputs = self._generate_and_score_completions(inputs, model)
|
| 759 |
+
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
|
| 760 |
+
else:
|
| 761 |
+
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
|
| 762 |
+
self._step += 1
|
| 763 |
+
|
| 764 |
+
print("index 2")
|
| 765 |
+
# Get the prepared inputs
|
| 766 |
+
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
|
| 767 |
+
completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
|
| 768 |
+
multimodal_inputs = inputs["multimodal_inputs"]
|
| 769 |
+
|
| 770 |
+
# Concatenate for full sequence
|
| 771 |
+
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
|
| 772 |
+
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
|
| 773 |
+
print("index 3")
|
| 774 |
+
# Get the current policy's log probabilities
|
| 775 |
+
|
| 776 |
+
print("index 4")
|
| 777 |
+
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, **multimodal_inputs)
|
| 778 |
+
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
|
| 779 |
+
per_token_logps = per_token_logps[:, prompt_ids.size(1) - 1:]
|
| 780 |
+
|
| 781 |
+
# Get the advantages from inputs
|
| 782 |
+
advantages = inputs["advantages"]
|
| 783 |
+
print("index 5")
|
| 784 |
+
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
|
| 785 |
+
# and use per_token_logps.detach() instead
|
| 786 |
+
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
|
| 787 |
+
|
| 788 |
+
# Compute the policy ratio and clipped version
|
| 789 |
+
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
|
| 790 |
+
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
|
| 791 |
+
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
|
| 792 |
+
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
|
| 793 |
+
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
|
| 794 |
+
print("index 6")
|
| 795 |
+
# Add KL penalty if beta > 0
|
| 796 |
+
if self.beta > 0:
|
| 797 |
+
ref_per_token_logps = inputs["ref_per_token_logps"]
|
| 798 |
+
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
|
| 799 |
+
per_token_loss = per_token_loss + self.beta * per_token_kl
|
| 800 |
+
|
| 801 |
+
# Log KL divergence
|
| 802 |
+
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 803 |
+
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
|
| 804 |
+
|
| 805 |
+
# Compute final loss
|
| 806 |
+
print("Computing final loss...")
|
| 807 |
+
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
|
| 808 |
+
|
| 809 |
+
# Log clip ratio
|
| 810 |
+
is_clipped = (per_token_loss1 < per_token_loss2).float()
|
| 811 |
+
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
|
| 812 |
+
self._metrics["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
|
| 813 |
+
|
| 814 |
+
return loss
|
| 815 |
+
|
| 816 |
+
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
|
| 817 |
+
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
|
| 818 |
+
logs = {**logs, **metrics}
|
| 819 |
+
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
| 820 |
+
super().log(logs, start_time)
|
| 821 |
+
else: # transformers<=4.46
|
| 822 |
+
super().log(logs)
|
| 823 |
+
self._metrics.clear()
|
| 824 |
+
|
| 825 |
+
def create_model_card(
|
| 826 |
+
self,
|
| 827 |
+
model_name: Optional[str] = None,
|
| 828 |
+
dataset_name: Optional[str] = None,
|
| 829 |
+
tags: Union[str, list[str], None] = None,
|
| 830 |
+
):
|
| 831 |
+
"""
|
| 832 |
+
Creates a draft of a model card using the information available to the `Trainer`.
|
| 833 |
+
|
| 834 |
+
Args:
|
| 835 |
+
model_name (`str` or `None`, *optional*, defaults to `None`):
|
| 836 |
+
Name of the model.
|
| 837 |
+
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
| 838 |
+
Name of the dataset used for training.
|
| 839 |
+
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
| 840 |
+
Tags to be associated with the model card.
|
| 841 |
+
"""
|
| 842 |
+
if not self.is_world_process_zero():
|
| 843 |
+
return
|
| 844 |
+
|
| 845 |
+
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
| 846 |
+
base_model = self.model.config._name_or_path
|
| 847 |
+
else:
|
| 848 |
+
base_model = None
|
| 849 |
+
|
| 850 |
+
tags = tags or []
|
| 851 |
+
if isinstance(tags, str):
|
| 852 |
+
tags = [tags]
|
| 853 |
+
|
| 854 |
+
if hasattr(self.model.config, "unsloth_version"):
|
| 855 |
+
tags.append("unsloth")
|
| 856 |
+
|
| 857 |
+
citation = textwrap.dedent(
|
| 858 |
+
"""\
|
| 859 |
+
@article{zhihong2024deepseekmath,
|
| 860 |
+
title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
|
| 861 |
+
author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
|
| 862 |
+
year = 2024,
|
| 863 |
+
eprint = {arXiv:2402.03300},
|
| 864 |
+
"""
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
model_card = generate_model_card(
|
| 868 |
+
base_model=base_model,
|
| 869 |
+
model_name=model_name,
|
| 870 |
+
hub_model_id=self.hub_model_id,
|
| 871 |
+
dataset_name=dataset_name,
|
| 872 |
+
tags=tags,
|
| 873 |
+
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
| 874 |
+
comet_url=get_comet_experiment_url(),
|
| 875 |
+
trainer_name="GRPO",
|
| 876 |
+
trainer_citation=citation,
|
| 877 |
+
paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
|
| 878 |
+
paper_id="2402.03300",
|
| 879 |
+
)
|
| 880 |
+
|
| 881 |
+
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
| 882 |
+
|
| 883 |
+
def _get_train_sampler(self) -> Sampler:
|
| 884 |
+
"""Returns a sampler that ensures proper data sampling for GRPO training."""
|
| 885 |
+
effective_batch_size = (
|
| 886 |
+
self.args.per_device_train_batch_size
|
| 887 |
+
* self.accelerator.num_processes
|
| 888 |
+
* self.args.gradient_accumulation_steps
|
| 889 |
+
)
|
| 890 |
+
|
| 891 |
+
return RepeatRandomSampler(
|
| 892 |
+
data_source=self.train_dataset,
|
| 893 |
+
mini_repeat_count=self.num_generations,
|
| 894 |
+
batch_size=effective_batch_size // self.num_generations,
|
| 895 |
+
repeat_count=self.num_iterations,
|
| 896 |
+
seed=self.args.seed,
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
def _get_eval_sampler(self, eval_dataset) -> Sampler:
|
| 900 |
+
"""Returns a sampler for evaluation."""
|
| 901 |
+
return RepeatRandomSampler(
|
| 902 |
+
data_source=eval_dataset,
|
| 903 |
+
mini_repeat_count=self.num_generations,
|
| 904 |
+
seed=self.args.seed,
|
| 905 |
+
)
|
BioReason/bioreason/utils/__init__.py
ADDED
|
File without changes
|
BioReason/bioreason/utils/dna_utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import TYPE_CHECKING, Callable, Optional, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from transformers.utils import is_torch_available
|
| 6 |
+
|
| 7 |
+
if is_torch_available():
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
DNAInput = Union[
|
| 11 |
+
str, list[int], np.ndarray, "torch.Tensor", list[str], list[list[int]], list[np.ndarray], list["torch.Tensor"]
|
| 12 |
+
] # noqa
|
BioReason/data/BioReasoning_DataCuration_KEGG.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
BioReason/data/Clinvar_Coding.ipynb
ADDED
|
@@ -0,0 +1,2481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "83c9cd1f",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"## Setup and Data Preparation\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Initial setup steps to prepare the working environment and extract ClinVar data."
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "markdown",
|
| 15 |
+
"id": "81a36253-9050-4d58-96cd-8238aae51e0e",
|
| 16 |
+
"metadata": {},
|
| 17 |
+
"source": [
|
| 18 |
+
"# ClinVar Coding Variants Data Processing\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"This notebook processes ClinVar coding variants data by extracting additional information including gene names, gene IDs, and associated diseases from ClinVar XML records.\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"## Overview\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"The workflow includes:\n",
|
| 25 |
+
"1. **Data Extraction**: Filter ClinVar entries from VEP-annotated pathogenic coding variants\n",
|
| 26 |
+
"2. **XML Processing**: Parse ClinVar XML records to extract gene and disease information\n",
|
| 27 |
+
"3. **Gene Annotation**: Map gene IDs to gene names using NCBI Entrez utilities\n",
|
| 28 |
+
"4. **Data Integration**: Combine all information into a comprehensive dataset\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"## Requirements\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"- Python 3.7+\n",
|
| 33 |
+
"- pandas library\n",
|
| 34 |
+
"- xml.etree.ElementTree (built-in)\n",
|
| 35 |
+
"- NCBI Entrez Direct tools (for gene name mapping)\n",
|
| 36 |
+
"- Input data: VEP-annotated pathogenic coding variants CSV file\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"## Data Structure\n",
|
| 39 |
+
"\n",
|
| 40 |
+
"The processing creates a dataset with the following key columns:\n",
|
| 41 |
+
"- Variant information (chromosome, position, alleles)\n",
|
| 42 |
+
"- ClinVar ID and significance\n",
|
| 43 |
+
"- Gene symbols and IDs\n",
|
| 44 |
+
"- Associated disease/phenotype information"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "cb351234-50a3-4061-81ce-bdce5343e790",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"# Create working directory for ClinVar data processing\n",
|
| 55 |
+
"import os\n",
|
| 56 |
+
"os.makedirs('clinvar', exist_ok=True)\n",
|
| 57 |
+
"print(\"✅ Created 'clinvar' directory\")"
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": null,
|
| 63 |
+
"id": "443ccab8-50a1-45ae-950c-8425eb318e93",
|
| 64 |
+
"metadata": {},
|
| 65 |
+
"outputs": [],
|
| 66 |
+
"source": [
|
| 67 |
+
"import os\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# Navigate to clinvar directory\n",
|
| 70 |
+
"os.chdir('clinvar')\n",
|
| 71 |
+
"print(f\"📁 Current working directory: {os.getcwd()}\")\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"with open('vep_pathogenic_coding.csv') as infile, open('clinvar_coding_raw.csv', 'w') as outfile:\n",
|
| 74 |
+
" for line in infile:\n",
|
| 75 |
+
" if 'ClinVar' in line:\n",
|
| 76 |
+
" outfile.write(line)"
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"execution_count": null,
|
| 82 |
+
"id": "e1f92675-b85c-4baa-8680-9c3776e04ac9",
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"outputs": [],
|
| 85 |
+
"source": [
|
| 86 |
+
"# Extract ClinVar entries from VEP-annotated pathogenic coding variants\n",
|
| 87 |
+
"# Note: Update the input file path to match your data location\n",
|
| 88 |
+
"input_file = \"../data/vep_pathogenic_coding.csv\" # Adjust path as needed\n",
|
| 89 |
+
"output_file = \"clinvar_coding_raw.csv\"\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# Use shell command to filter ClinVar entries\n",
|
| 92 |
+
"import subprocess\n",
|
| 93 |
+
"try:\n",
|
| 94 |
+
" result = subprocess.run(\n",
|
| 95 |
+
" [\"grep\", \"ClinVar\", input_file],\n",
|
| 96 |
+
" capture_output=True,\n",
|
| 97 |
+
" text=True,\n",
|
| 98 |
+
" check=True\n",
|
| 99 |
+
" )\n",
|
| 100 |
+
" \n",
|
| 101 |
+
" with open(output_file, 'w') as f:\n",
|
| 102 |
+
" f.write(result.stdout)\n",
|
| 103 |
+
" \n",
|
| 104 |
+
" print(f\"✅ Extracted ClinVar entries to {output_file}\")\n",
|
| 105 |
+
" print(f\"📊 Found {len(result.stdout.strip().split('\\n'))} ClinVar entries\")\n",
|
| 106 |
+
" \n",
|
| 107 |
+
"except subprocess.CalledProcessError:\n",
|
| 108 |
+
" print(f\"❌ Error: Could not find ClinVar entries in {input_file}\")\n",
|
| 109 |
+
" print(\"Please ensure the input file exists and contains ClinVar annotations\")\n",
|
| 110 |
+
"except FileNotFoundError:\n",
|
| 111 |
+
" print(f\"❌ Error: Input file {input_file} not found\")\n",
|
| 112 |
+
" print(\"Please update the input_file path to point to your VEP-annotated data\")"
|
| 113 |
+
]
|
| 114 |
+
},
|
| 115 |
+
{
|
| 116 |
+
"cell_type": "code",
|
| 117 |
+
"execution_count": null,
|
| 118 |
+
"id": "7e560308-135b-4189-9146-ff50845839a4",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"outputs": [],
|
| 121 |
+
"source": [
|
| 122 |
+
"# Extract ClinVar IDs from the filtered data (assuming ID is in column 8)\n",
|
| 123 |
+
"# Note: Adjust column number if your data structure is different\n",
|
| 124 |
+
"import pandas as pd\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"try:\n",
|
| 127 |
+
" # Read the raw ClinVar data to determine structure\n",
|
| 128 |
+
" df_temp = pd.read_csv(\"clinvar_coding_raw.csv\")\n",
|
| 129 |
+
" print(f\"📋 Data shape: {df_temp.shape}\")\n",
|
| 130 |
+
" print(f\"📋 Columns: {list(df_temp.columns)}\")\n",
|
| 131 |
+
" \n",
|
| 132 |
+
" # Extract ClinVar IDs (adjust column index as needed)\n",
|
| 133 |
+
" # Column 8 corresponds to index 7 in Python (0-based)\n",
|
| 134 |
+
" if df_temp.shape[1] >= 8:\n",
|
| 135 |
+
" clinvar_ids = df_temp.iloc[:, 7] # 8th column (0-based index 7)\n",
|
| 136 |
+
" \n",
|
| 137 |
+
" # Save IDs to file\n",
|
| 138 |
+
" with open(\"Clinvar_ID.txt\", 'w') as f:\n",
|
| 139 |
+
" for id_val in clinvar_ids:\n",
|
| 140 |
+
" if pd.notna(id_val):\n",
|
| 141 |
+
" f.write(f\"{id_val}\\n\")\n",
|
| 142 |
+
" \n",
|
| 143 |
+
" print(f\"✅ Extracted {len(clinvar_ids.dropna())} ClinVar IDs to Clinvar_ID.txt\")\n",
|
| 144 |
+
" else:\n",
|
| 145 |
+
" print(f\"❌ Error: Expected at least 8 columns, found {df_temp.shape[1]}\")\n",
|
| 146 |
+
" \n",
|
| 147 |
+
"except FileNotFoundError:\n",
|
| 148 |
+
" print(\"❌ Error: clinvar_coding_raw.csv not found\")\n",
|
| 149 |
+
" print(\"Please run the previous cell first to extract ClinVar data\")\n",
|
| 150 |
+
"except Exception as e:\n",
|
| 151 |
+
" print(f\"❌ Error processing ClinVar data: {e}\")"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "code",
|
| 156 |
+
"execution_count": null,
|
| 157 |
+
"id": "53b0dfd8-8d49-4c3f-adb4-4c6bfbffcfa9",
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"outputs": [],
|
| 160 |
+
"source": [
|
| 161 |
+
"chmod +x Clinvar_esearch.sh\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"## XML Data Retrieval\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"**Note**: This step requires creating a shell script (`Clinvar_esearch.sh`) to fetch XML data from NCBI.\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"The script should:\n",
|
| 168 |
+
"1. Read ClinVar IDs from `Clinvar_ID.txt`\n",
|
| 169 |
+
"2. Use NCBI Entrez Direct tools to fetch XML records\n",
|
| 170 |
+
"3. Save XML files in a `data/` subdirectory\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"Example script content:\n",
|
| 173 |
+
"```bash\n",
|
| 174 |
+
"#!/bin/bash\n",
|
| 175 |
+
"mkdir -p data\n",
|
| 176 |
+
"while read -r id; do\n",
|
| 177 |
+
" esearch -db clinvar -query \"$id\" | efetch -format xml > \"data/${id}.xml\"\n",
|
| 178 |
+
" echo \"Downloaded XML for ClinVar ID: $id\"\n",
|
| 179 |
+
"done < Clinvar_ID.txt\n",
|
| 180 |
+
"```\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"**Prerequisites**: Install NCBI Entrez Direct tools:\n",
|
| 183 |
+
"- macOS: `brew install brewsci/bio/edirect`\n",
|
| 184 |
+
"- Linux: Follow NCBI EDirect installation guide"
|
| 185 |
+
]
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"cell_type": "code",
|
| 189 |
+
"execution_count": null,
|
| 190 |
+
"id": "0755ad6d",
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"outputs": [],
|
| 193 |
+
"source": [
|
| 194 |
+
"# Parsing XML for Gene and Disease\n",
|
| 195 |
+
"\n",
|
| 196 |
+
"# Make the ClinVar search script executable and run it\n",
|
| 197 |
+
"# Note: This assumes you have created the Clinvar_esearch.sh script\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"import os\n",
|
| 200 |
+
"import subprocess\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"script_path = \"Clinvar_esearch.sh\"\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"if os.path.exists(script_path):\n",
|
| 205 |
+
" # Make script executable\n",
|
| 206 |
+
" os.chmod(script_path, 0o755)\n",
|
| 207 |
+
" print(f\"✅ Made {script_path} executable\")\n",
|
| 208 |
+
" \n",
|
| 209 |
+
" # Optionally run the script (uncomment if you want to execute automatically)\n",
|
| 210 |
+
" # print(\"🚀 Running ClinVar XML download script...\")\n",
|
| 211 |
+
" # result = subprocess.run([f\"./{script_path}\"], capture_output=True, text=True)\n",
|
| 212 |
+
" # if result.returncode == 0:\n",
|
| 213 |
+
" # print(\"✅ XML download completed successfully\")\n",
|
| 214 |
+
" # else:\n",
|
| 215 |
+
" # print(f\"❌ Script execution failed: {result.stderr}\")\n",
|
| 216 |
+
"else:\n",
|
| 217 |
+
" print(f\"⚠️ Warning: {script_path} not found\")\n",
|
| 218 |
+
" print(\"Please create this script manually to download ClinVar XML data\")\n",
|
| 219 |
+
" print(\"See the documentation in the previous cell for script template\")"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "code",
|
| 224 |
+
"execution_count": null,
|
| 225 |
+
"id": "d21a188b-a0dc-4af2-9b71-5a44d8cd4673",
|
| 226 |
+
"metadata": {},
|
| 227 |
+
"outputs": [],
|
| 228 |
+
"source": [
|
| 229 |
+
"# Import required libraries\n",
|
| 230 |
+
"import pandas as pd\n",
|
| 231 |
+
"import xml.etree.ElementTree as ET\n",
|
| 232 |
+
"import json\n",
|
| 233 |
+
"import os\n",
|
| 234 |
+
"from pathlib import Path\n",
|
| 235 |
+
"\n",
|
| 236 |
+
"print(\"📚 Libraries imported successfully\")\n",
|
| 237 |
+
"print(f\"📁 Current directory: {os.getcwd()}\")\n",
|
| 238 |
+
"print(f\"📊 Pandas version: {pd.__version__}\")"
|
| 239 |
+
]
|
| 240 |
+
},
|
| 241 |
+
{
|
| 242 |
+
"cell_type": "code",
|
| 243 |
+
"execution_count": null,
|
| 244 |
+
"id": "1365615b-ee81-4df0-9fca-df001e9f01d4",
|
| 245 |
+
"metadata": {},
|
| 246 |
+
"outputs": [],
|
| 247 |
+
"source": [
|
| 248 |
+
"# Load the raw ClinVar data\n",
|
| 249 |
+
"try:\n",
|
| 250 |
+
" clinvar_raw = pd.read_csv(\"clinvar_coding_raw.csv\")\n",
|
| 251 |
+
" print(f\"✅ Loaded ClinVar data: {clinvar_raw.shape[0]} rows, {clinvar_raw.shape[1]} columns\")\n",
|
| 252 |
+
" print(f\"📋 Columns: {list(clinvar_raw.columns)[:10]}\") # Show first 10 columns\n",
|
| 253 |
+
" \n",
|
| 254 |
+
"except FileNotFoundError:\n",
|
| 255 |
+
" print(\"❌ Error: clinvar_coding_raw.csv not found\")\n",
|
| 256 |
+
" print(\"Please run the data extraction steps first\")\n",
|
| 257 |
+
" clinvar_raw = None\n",
|
| 258 |
+
"except Exception as e:\n",
|
| 259 |
+
" print(f\"❌ Error loading data: {e}\")\n",
|
| 260 |
+
" clinvar_raw = None"
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "code",
|
| 265 |
+
"execution_count": null,
|
| 266 |
+
"id": "7144ddf2-abf7-4680-b578-d4bd4b7195ea",
|
| 267 |
+
"metadata": {},
|
| 268 |
+
"outputs": [],
|
| 269 |
+
"source": [
|
| 270 |
+
"# Remove unnecessary columns to streamline the dataset\n",
|
| 271 |
+
"# Note: Adjust column names based on your actual data structure\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"if clinvar_raw is not None:\n",
|
| 274 |
+
" columns_to_remove = [\n",
|
| 275 |
+
" \"GENOMIC_MUTATION_ID\", \"N_SAMPLES\", \"TOTAL_SAMPLES\", \"FREQ\", \n",
|
| 276 |
+
" \"OMIM\", \"PMID\", \"AC\", \"AN\", \"AF\", \"MAF\", \"MAC\"\n",
|
| 277 |
+
" ]\n",
|
| 278 |
+
" \n",
|
| 279 |
+
" # Only remove columns that actually exist in the dataset\n",
|
| 280 |
+
" existing_columns = [col for col in columns_to_remove if col in clinvar_raw.columns]\n",
|
| 281 |
+
" missing_columns = [col for col in columns_to_remove if col not in clinvar_raw.columns]\n",
|
| 282 |
+
" \n",
|
| 283 |
+
" if existing_columns:\n",
|
| 284 |
+
" clinvar_raw = clinvar_raw.drop(columns=existing_columns)\n",
|
| 285 |
+
" print(f\"✅ Removed {len(existing_columns)} columns: {existing_columns}\")\n",
|
| 286 |
+
" \n",
|
| 287 |
+
" if missing_columns:\n",
|
| 288 |
+
" print(f\"ℹ️ Columns not found (skipped): {missing_columns}\")\n",
|
| 289 |
+
" \n",
|
| 290 |
+
" print(f\"📊 Remaining columns: {clinvar_raw.shape[1]}\")\n",
|
| 291 |
+
"else:\n",
|
| 292 |
+
" print(\"⚠️ Skipping column removal - data not loaded\")"
|
| 293 |
+
]
|
| 294 |
+
},
|
| 295 |
+
{
|
| 296 |
+
"cell_type": "code",
|
| 297 |
+
"execution_count": null,
|
| 298 |
+
"id": "fbffd3cd-7df3-43e2-8d73-01f54e8d1da6",
|
| 299 |
+
"metadata": {},
|
| 300 |
+
"outputs": [
|
| 301 |
+
{
|
| 302 |
+
"data": {
|
| 303 |
+
"text/html": [
|
| 304 |
+
"<div>\n",
|
| 305 |
+
"<style scoped>\n",
|
| 306 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 307 |
+
" vertical-align: middle;\n",
|
| 308 |
+
" }\n",
|
| 309 |
+
"\n",
|
| 310 |
+
" .dataframe tbody tr th {\n",
|
| 311 |
+
" vertical-align: top;\n",
|
| 312 |
+
" }\n",
|
| 313 |
+
"\n",
|
| 314 |
+
" .dataframe thead th {\n",
|
| 315 |
+
" text-align: right;\n",
|
| 316 |
+
" }\n",
|
| 317 |
+
"</style>\n",
|
| 318 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 319 |
+
" <thead>\n",
|
| 320 |
+
" <tr style=\"text-align: right;\">\n",
|
| 321 |
+
" <th></th>\n",
|
| 322 |
+
" <th>CHROM</th>\n",
|
| 323 |
+
" <th>POS</th>\n",
|
| 324 |
+
" <th>REF</th>\n",
|
| 325 |
+
" <th>ALT</th>\n",
|
| 326 |
+
" <th>LABEL</th>\n",
|
| 327 |
+
" <th>SOURCE</th>\n",
|
| 328 |
+
" <th>CONSEQUENCE</th>\n",
|
| 329 |
+
" <th>ID</th>\n",
|
| 330 |
+
" <th>REVIEW_STATUS</th>\n",
|
| 331 |
+
" <th>GENE</th>\n",
|
| 332 |
+
" <th>split</th>\n",
|
| 333 |
+
" <th>INT_LABEL</th>\n",
|
| 334 |
+
" </tr>\n",
|
| 335 |
+
" </thead>\n",
|
| 336 |
+
" <tbody>\n",
|
| 337 |
+
" <tr>\n",
|
| 338 |
+
" <th>0</th>\n",
|
| 339 |
+
" <td>chr1</td>\n",
|
| 340 |
+
" <td>976215</td>\n",
|
| 341 |
+
" <td>A</td>\n",
|
| 342 |
+
" <td>G</td>\n",
|
| 343 |
+
" <td>Pathogenic</td>\n",
|
| 344 |
+
" <td>ClinVar</td>\n",
|
| 345 |
+
" <td>missense_variant</td>\n",
|
| 346 |
+
" <td>1320032</td>\n",
|
| 347 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 348 |
+
" <td>NaN</td>\n",
|
| 349 |
+
" <td>train</td>\n",
|
| 350 |
+
" <td>1</td>\n",
|
| 351 |
+
" </tr>\n",
|
| 352 |
+
" <tr>\n",
|
| 353 |
+
" <th>1</th>\n",
|
| 354 |
+
" <td>chr1</td>\n",
|
| 355 |
+
" <td>1050449</td>\n",
|
| 356 |
+
" <td>G</td>\n",
|
| 357 |
+
" <td>A</td>\n",
|
| 358 |
+
" <td>Pathogenic</td>\n",
|
| 359 |
+
" <td>ClinVar</td>\n",
|
| 360 |
+
" <td>missense_variant</td>\n",
|
| 361 |
+
" <td>1284257</td>\n",
|
| 362 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 363 |
+
" <td>NaN</td>\n",
|
| 364 |
+
" <td>train</td>\n",
|
| 365 |
+
" <td>1</td>\n",
|
| 366 |
+
" </tr>\n",
|
| 367 |
+
" <tr>\n",
|
| 368 |
+
" <th>2</th>\n",
|
| 369 |
+
" <td>chr1</td>\n",
|
| 370 |
+
" <td>1050575</td>\n",
|
| 371 |
+
" <td>G</td>\n",
|
| 372 |
+
" <td>C</td>\n",
|
| 373 |
+
" <td>Pathogenic</td>\n",
|
| 374 |
+
" <td>ClinVar</td>\n",
|
| 375 |
+
" <td>missense_variant</td>\n",
|
| 376 |
+
" <td>18241</td>\n",
|
| 377 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 378 |
+
" <td>NaN</td>\n",
|
| 379 |
+
" <td>train</td>\n",
|
| 380 |
+
" <td>1</td>\n",
|
| 381 |
+
" </tr>\n",
|
| 382 |
+
" <tr>\n",
|
| 383 |
+
" <th>3</th>\n",
|
| 384 |
+
" <td>chr1</td>\n",
|
| 385 |
+
" <td>1213738</td>\n",
|
| 386 |
+
" <td>G</td>\n",
|
| 387 |
+
" <td>A</td>\n",
|
| 388 |
+
" <td>Pathogenic</td>\n",
|
| 389 |
+
" <td>ClinVar</td>\n",
|
| 390 |
+
" <td>missense_variant</td>\n",
|
| 391 |
+
" <td>96692</td>\n",
|
| 392 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 393 |
+
" <td>NaN</td>\n",
|
| 394 |
+
" <td>train</td>\n",
|
| 395 |
+
" <td>1</td>\n",
|
| 396 |
+
" </tr>\n",
|
| 397 |
+
" <tr>\n",
|
| 398 |
+
" <th>4</th>\n",
|
| 399 |
+
" <td>chr1</td>\n",
|
| 400 |
+
" <td>1232279</td>\n",
|
| 401 |
+
" <td>A</td>\n",
|
| 402 |
+
" <td>G</td>\n",
|
| 403 |
+
" <td>Pathogenic</td>\n",
|
| 404 |
+
" <td>ClinVar</td>\n",
|
| 405 |
+
" <td>initiatior_codon_variant,missense_variant</td>\n",
|
| 406 |
+
" <td>60484</td>\n",
|
| 407 |
+
" <td>criteria_provided,_multiple_submitters,_no_con...</td>\n",
|
| 408 |
+
" <td>NaN</td>\n",
|
| 409 |
+
" <td>train</td>\n",
|
| 410 |
+
" <td>1</td>\n",
|
| 411 |
+
" </tr>\n",
|
| 412 |
+
" <tr>\n",
|
| 413 |
+
" <th>...</th>\n",
|
| 414 |
+
" <td>...</td>\n",
|
| 415 |
+
" <td>...</td>\n",
|
| 416 |
+
" <td>...</td>\n",
|
| 417 |
+
" <td>...</td>\n",
|
| 418 |
+
" <td>...</td>\n",
|
| 419 |
+
" <td>...</td>\n",
|
| 420 |
+
" <td>...</td>\n",
|
| 421 |
+
" <td>...</td>\n",
|
| 422 |
+
" <td>...</td>\n",
|
| 423 |
+
" <td>...</td>\n",
|
| 424 |
+
" <td>...</td>\n",
|
| 425 |
+
" <td>...</td>\n",
|
| 426 |
+
" </tr>\n",
|
| 427 |
+
" <tr>\n",
|
| 428 |
+
" <th>22249</th>\n",
|
| 429 |
+
" <td>chrY</td>\n",
|
| 430 |
+
" <td>2787412</td>\n",
|
| 431 |
+
" <td>C</td>\n",
|
| 432 |
+
" <td>T</td>\n",
|
| 433 |
+
" <td>Pathogenic</td>\n",
|
| 434 |
+
" <td>ClinVar</td>\n",
|
| 435 |
+
" <td>missense_variant</td>\n",
|
| 436 |
+
" <td>9747</td>\n",
|
| 437 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 438 |
+
" <td>NaN</td>\n",
|
| 439 |
+
" <td>train</td>\n",
|
| 440 |
+
" <td>1</td>\n",
|
| 441 |
+
" </tr>\n",
|
| 442 |
+
" <tr>\n",
|
| 443 |
+
" <th>22250</th>\n",
|
| 444 |
+
" <td>chrY</td>\n",
|
| 445 |
+
" <td>2787426</td>\n",
|
| 446 |
+
" <td>C</td>\n",
|
| 447 |
+
" <td>G</td>\n",
|
| 448 |
+
" <td>Pathogenic</td>\n",
|
| 449 |
+
" <td>ClinVar</td>\n",
|
| 450 |
+
" <td>missense_variant</td>\n",
|
| 451 |
+
" <td>9739</td>\n",
|
| 452 |
+
" <td>criteria_provided,_single_submitter</td>\n",
|
| 453 |
+
" <td>NaN</td>\n",
|
| 454 |
+
" <td>train</td>\n",
|
| 455 |
+
" <td>1</td>\n",
|
| 456 |
+
" </tr>\n",
|
| 457 |
+
" <tr>\n",
|
| 458 |
+
" <th>22251</th>\n",
|
| 459 |
+
" <td>chrY</td>\n",
|
| 460 |
+
" <td>2787515</td>\n",
|
| 461 |
+
" <td>C</td>\n",
|
| 462 |
+
" <td>A</td>\n",
|
| 463 |
+
" <td>Pathogenic</td>\n",
|
| 464 |
+
" <td>ClinVar</td>\n",
|
| 465 |
+
" <td>missense_variant</td>\n",
|
| 466 |
+
" <td>492908</td>\n",
|
| 467 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 468 |
+
" <td>NaN</td>\n",
|
| 469 |
+
" <td>train</td>\n",
|
| 470 |
+
" <td>1</td>\n",
|
| 471 |
+
" </tr>\n",
|
| 472 |
+
" <tr>\n",
|
| 473 |
+
" <th>22252</th>\n",
|
| 474 |
+
" <td>chrY</td>\n",
|
| 475 |
+
" <td>2787551</td>\n",
|
| 476 |
+
" <td>C</td>\n",
|
| 477 |
+
" <td>T</td>\n",
|
| 478 |
+
" <td>Pathogenic</td>\n",
|
| 479 |
+
" <td>ClinVar</td>\n",
|
| 480 |
+
" <td>missense_variant</td>\n",
|
| 481 |
+
" <td>9754</td>\n",
|
| 482 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 483 |
+
" <td>NaN</td>\n",
|
| 484 |
+
" <td>train</td>\n",
|
| 485 |
+
" <td>1</td>\n",
|
| 486 |
+
" </tr>\n",
|
| 487 |
+
" <tr>\n",
|
| 488 |
+
" <th>22253</th>\n",
|
| 489 |
+
" <td>chrY</td>\n",
|
| 490 |
+
" <td>7063898</td>\n",
|
| 491 |
+
" <td>A</td>\n",
|
| 492 |
+
" <td>T</td>\n",
|
| 493 |
+
" <td>Pathogenic</td>\n",
|
| 494 |
+
" <td>ClinVar</td>\n",
|
| 495 |
+
" <td>missense_variant</td>\n",
|
| 496 |
+
" <td>625467</td>\n",
|
| 497 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 498 |
+
" <td>NaN</td>\n",
|
| 499 |
+
" <td>train</td>\n",
|
| 500 |
+
" <td>1</td>\n",
|
| 501 |
+
" </tr>\n",
|
| 502 |
+
" </tbody>\n",
|
| 503 |
+
"</table>\n",
|
| 504 |
+
"<p>22254 rows × 12 columns</p>\n",
|
| 505 |
+
"</div>"
|
| 506 |
+
],
|
| 507 |
+
"text/plain": [
|
| 508 |
+
" CHROM POS REF ALT LABEL SOURCE \\\n",
|
| 509 |
+
"0 chr1 976215 A G Pathogenic ClinVar \n",
|
| 510 |
+
"1 chr1 1050449 G A Pathogenic ClinVar \n",
|
| 511 |
+
"2 chr1 1050575 G C Pathogenic ClinVar \n",
|
| 512 |
+
"3 chr1 1213738 G A Pathogenic ClinVar \n",
|
| 513 |
+
"4 chr1 1232279 A G Pathogenic ClinVar \n",
|
| 514 |
+
"... ... ... .. .. ... ... \n",
|
| 515 |
+
"22249 chrY 2787412 C T Pathogenic ClinVar \n",
|
| 516 |
+
"22250 chrY 2787426 C G Pathogenic ClinVar \n",
|
| 517 |
+
"22251 chrY 2787515 C A Pathogenic ClinVar \n",
|
| 518 |
+
"22252 chrY 2787551 C T Pathogenic ClinVar \n",
|
| 519 |
+
"22253 chrY 7063898 A T Pathogenic ClinVar \n",
|
| 520 |
+
"\n",
|
| 521 |
+
" CONSEQUENCE ID \\\n",
|
| 522 |
+
"0 missense_variant 1320032 \n",
|
| 523 |
+
"1 missense_variant 1284257 \n",
|
| 524 |
+
"2 missense_variant 18241 \n",
|
| 525 |
+
"3 missense_variant 96692 \n",
|
| 526 |
+
"4 initiatior_codon_variant,missense_variant 60484 \n",
|
| 527 |
+
"... ... ... \n",
|
| 528 |
+
"22249 missense_variant 9747 \n",
|
| 529 |
+
"22250 missense_variant 9739 \n",
|
| 530 |
+
"22251 missense_variant 492908 \n",
|
| 531 |
+
"22252 missense_variant 9754 \n",
|
| 532 |
+
"22253 missense_variant 625467 \n",
|
| 533 |
+
"\n",
|
| 534 |
+
" REVIEW_STATUS GENE split \\\n",
|
| 535 |
+
"0 no_assertion_criteria_provided NaN train \n",
|
| 536 |
+
"1 no_assertion_criteria_provided NaN train \n",
|
| 537 |
+
"2 no_assertion_criteria_provided NaN train \n",
|
| 538 |
+
"3 no_assertion_criteria_provided NaN train \n",
|
| 539 |
+
"4 criteria_provided,_multiple_submitters,_no_con... NaN train \n",
|
| 540 |
+
"... ... ... ... \n",
|
| 541 |
+
"22249 no_assertion_criteria_provided NaN train \n",
|
| 542 |
+
"22250 criteria_provided,_single_submitter NaN train \n",
|
| 543 |
+
"22251 no_assertion_criteria_provided NaN train \n",
|
| 544 |
+
"22252 no_assertion_criteria_provided NaN train \n",
|
| 545 |
+
"22253 no_assertion_criteria_provided NaN train \n",
|
| 546 |
+
"\n",
|
| 547 |
+
" INT_LABEL \n",
|
| 548 |
+
"0 1 \n",
|
| 549 |
+
"1 1 \n",
|
| 550 |
+
"2 1 \n",
|
| 551 |
+
"3 1 \n",
|
| 552 |
+
"4 1 \n",
|
| 553 |
+
"... ... \n",
|
| 554 |
+
"22249 1 \n",
|
| 555 |
+
"22250 1 \n",
|
| 556 |
+
"22251 1 \n",
|
| 557 |
+
"22252 1 \n",
|
| 558 |
+
"22253 1 \n",
|
| 559 |
+
"\n",
|
| 560 |
+
"[22254 rows x 12 columns]"
|
| 561 |
+
]
|
| 562 |
+
},
|
| 563 |
+
"execution_count": 17,
|
| 564 |
+
"metadata": {},
|
| 565 |
+
"output_type": "execute_result"
|
| 566 |
+
}
|
| 567 |
+
],
|
| 568 |
+
"source": [
|
| 569 |
+
"clinvar_raw\n",
|
| 570 |
+
"\n",
|
| 571 |
+
"# Preview the cleaned dataset\n",
|
| 572 |
+
"if clinvar_raw is not None:\n",
|
| 573 |
+
" print(f\"📊 Dataset shape: {clinvar_raw.shape}\")\n",
|
| 574 |
+
" print(f\"📋 Column names: {list(clinvar_raw.columns)}\")\n",
|
| 575 |
+
" print(\"\\n🔍 First few rows:\")\n",
|
| 576 |
+
" display(clinvar_raw.head())\n",
|
| 577 |
+
" \n",
|
| 578 |
+
" # Check for any null values\n",
|
| 579 |
+
" null_counts = clinvar_raw.isnull().sum()\n",
|
| 580 |
+
" if null_counts.sum() > 0:\n",
|
| 581 |
+
" print(\"\\n⚠️ Null values found:\")\n",
|
| 582 |
+
" print(null_counts[null_counts > 0])\n",
|
| 583 |
+
"else:\n",
|
| 584 |
+
" print(\"❌ No data to display\")"
|
| 585 |
+
]
|
| 586 |
+
},
|
| 587 |
+
{
|
| 588 |
+
"cell_type": "code",
|
| 589 |
+
"execution_count": null,
|
| 590 |
+
"id": "e380634b-0c22-4d1e-8520-6fc5728e7de5",
|
| 591 |
+
"metadata": {},
|
| 592 |
+
"outputs": [],
|
| 593 |
+
"source": [
|
| 594 |
+
"# Add new columns for gene information\n",
|
| 595 |
+
"if clinvar_raw is not None:\n",
|
| 596 |
+
" clinvar_raw['GENE_ID'] = \"\"\n",
|
| 597 |
+
" clinvar_raw['GENE'] = \"\"\n",
|
| 598 |
+
" print(\"✅ Added GENE_ID and GENE columns\")\n",
|
| 599 |
+
" print(f\"📊 Updated dataset shape: {clinvar_raw.shape}\")\n",
|
| 600 |
+
"else:\n",
|
| 601 |
+
" print(\"⚠️ Cannot add columns - data not loaded\")"
|
| 602 |
+
]
|
| 603 |
+
},
|
| 604 |
+
{
|
| 605 |
+
"cell_type": "code",
|
| 606 |
+
"execution_count": null,
|
| 607 |
+
"id": "92b159f5-694d-4ee4-9616-1ebf00f71904",
|
| 608 |
+
"metadata": {},
|
| 609 |
+
"outputs": [
|
| 610 |
+
{
|
| 611 |
+
"data": {
|
| 612 |
+
"text/html": [
|
| 613 |
+
"<div>\n",
|
| 614 |
+
"<style scoped>\n",
|
| 615 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 616 |
+
" vertical-align: middle;\n",
|
| 617 |
+
" }\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" .dataframe tbody tr th {\n",
|
| 620 |
+
" vertical-align: top;\n",
|
| 621 |
+
" }\n",
|
| 622 |
+
"\n",
|
| 623 |
+
" .dataframe thead th {\n",
|
| 624 |
+
" text-align: right;\n",
|
| 625 |
+
" }\n",
|
| 626 |
+
"</style>\n",
|
| 627 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 628 |
+
" <thead>\n",
|
| 629 |
+
" <tr style=\"text-align: right;\">\n",
|
| 630 |
+
" <th></th>\n",
|
| 631 |
+
" <th>CHROM</th>\n",
|
| 632 |
+
" <th>POS</th>\n",
|
| 633 |
+
" <th>REF</th>\n",
|
| 634 |
+
" <th>ALT</th>\n",
|
| 635 |
+
" <th>LABEL</th>\n",
|
| 636 |
+
" <th>SOURCE</th>\n",
|
| 637 |
+
" <th>CONSEQUENCE</th>\n",
|
| 638 |
+
" <th>ID</th>\n",
|
| 639 |
+
" <th>REVIEW_STATUS</th>\n",
|
| 640 |
+
" <th>GENE</th>\n",
|
| 641 |
+
" <th>split</th>\n",
|
| 642 |
+
" <th>INT_LABEL</th>\n",
|
| 643 |
+
" <th>GENE_ID</th>\n",
|
| 644 |
+
" </tr>\n",
|
| 645 |
+
" </thead>\n",
|
| 646 |
+
" <tbody>\n",
|
| 647 |
+
" <tr>\n",
|
| 648 |
+
" <th>0</th>\n",
|
| 649 |
+
" <td>chr1</td>\n",
|
| 650 |
+
" <td>976215</td>\n",
|
| 651 |
+
" <td>A</td>\n",
|
| 652 |
+
" <td>G</td>\n",
|
| 653 |
+
" <td>Pathogenic</td>\n",
|
| 654 |
+
" <td>ClinVar</td>\n",
|
| 655 |
+
" <td>missense_variant</td>\n",
|
| 656 |
+
" <td>1320032</td>\n",
|
| 657 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 658 |
+
" <td></td>\n",
|
| 659 |
+
" <td>train</td>\n",
|
| 660 |
+
" <td>1</td>\n",
|
| 661 |
+
" <td></td>\n",
|
| 662 |
+
" </tr>\n",
|
| 663 |
+
" <tr>\n",
|
| 664 |
+
" <th>1</th>\n",
|
| 665 |
+
" <td>chr1</td>\n",
|
| 666 |
+
" <td>1050449</td>\n",
|
| 667 |
+
" <td>G</td>\n",
|
| 668 |
+
" <td>A</td>\n",
|
| 669 |
+
" <td>Pathogenic</td>\n",
|
| 670 |
+
" <td>ClinVar</td>\n",
|
| 671 |
+
" <td>missense_variant</td>\n",
|
| 672 |
+
" <td>1284257</td>\n",
|
| 673 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 674 |
+
" <td></td>\n",
|
| 675 |
+
" <td>train</td>\n",
|
| 676 |
+
" <td>1</td>\n",
|
| 677 |
+
" <td></td>\n",
|
| 678 |
+
" </tr>\n",
|
| 679 |
+
" <tr>\n",
|
| 680 |
+
" <th>2</th>\n",
|
| 681 |
+
" <td>chr1</td>\n",
|
| 682 |
+
" <td>1050575</td>\n",
|
| 683 |
+
" <td>G</td>\n",
|
| 684 |
+
" <td>C</td>\n",
|
| 685 |
+
" <td>Pathogenic</td>\n",
|
| 686 |
+
" <td>ClinVar</td>\n",
|
| 687 |
+
" <td>missense_variant</td>\n",
|
| 688 |
+
" <td>18241</td>\n",
|
| 689 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 690 |
+
" <td></td>\n",
|
| 691 |
+
" <td>train</td>\n",
|
| 692 |
+
" <td>1</td>\n",
|
| 693 |
+
" <td></td>\n",
|
| 694 |
+
" </tr>\n",
|
| 695 |
+
" <tr>\n",
|
| 696 |
+
" <th>3</th>\n",
|
| 697 |
+
" <td>chr1</td>\n",
|
| 698 |
+
" <td>1213738</td>\n",
|
| 699 |
+
" <td>G</td>\n",
|
| 700 |
+
" <td>A</td>\n",
|
| 701 |
+
" <td>Pathogenic</td>\n",
|
| 702 |
+
" <td>ClinVar</td>\n",
|
| 703 |
+
" <td>missense_variant</td>\n",
|
| 704 |
+
" <td>96692</td>\n",
|
| 705 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 706 |
+
" <td></td>\n",
|
| 707 |
+
" <td>train</td>\n",
|
| 708 |
+
" <td>1</td>\n",
|
| 709 |
+
" <td></td>\n",
|
| 710 |
+
" </tr>\n",
|
| 711 |
+
" <tr>\n",
|
| 712 |
+
" <th>4</th>\n",
|
| 713 |
+
" <td>chr1</td>\n",
|
| 714 |
+
" <td>1232279</td>\n",
|
| 715 |
+
" <td>A</td>\n",
|
| 716 |
+
" <td>G</td>\n",
|
| 717 |
+
" <td>Pathogenic</td>\n",
|
| 718 |
+
" <td>ClinVar</td>\n",
|
| 719 |
+
" <td>initiatior_codon_variant,missense_variant</td>\n",
|
| 720 |
+
" <td>60484</td>\n",
|
| 721 |
+
" <td>criteria_provided,_multiple_submitters,_no_con...</td>\n",
|
| 722 |
+
" <td></td>\n",
|
| 723 |
+
" <td>train</td>\n",
|
| 724 |
+
" <td>1</td>\n",
|
| 725 |
+
" <td></td>\n",
|
| 726 |
+
" </tr>\n",
|
| 727 |
+
" <tr>\n",
|
| 728 |
+
" <th>...</th>\n",
|
| 729 |
+
" <td>...</td>\n",
|
| 730 |
+
" <td>...</td>\n",
|
| 731 |
+
" <td>...</td>\n",
|
| 732 |
+
" <td>...</td>\n",
|
| 733 |
+
" <td>...</td>\n",
|
| 734 |
+
" <td>...</td>\n",
|
| 735 |
+
" <td>...</td>\n",
|
| 736 |
+
" <td>...</td>\n",
|
| 737 |
+
" <td>...</td>\n",
|
| 738 |
+
" <td>...</td>\n",
|
| 739 |
+
" <td>...</td>\n",
|
| 740 |
+
" <td>...</td>\n",
|
| 741 |
+
" <td>...</td>\n",
|
| 742 |
+
" </tr>\n",
|
| 743 |
+
" <tr>\n",
|
| 744 |
+
" <th>22249</th>\n",
|
| 745 |
+
" <td>chrY</td>\n",
|
| 746 |
+
" <td>2787412</td>\n",
|
| 747 |
+
" <td>C</td>\n",
|
| 748 |
+
" <td>T</td>\n",
|
| 749 |
+
" <td>Pathogenic</td>\n",
|
| 750 |
+
" <td>ClinVar</td>\n",
|
| 751 |
+
" <td>missense_variant</td>\n",
|
| 752 |
+
" <td>9747</td>\n",
|
| 753 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 754 |
+
" <td></td>\n",
|
| 755 |
+
" <td>train</td>\n",
|
| 756 |
+
" <td>1</td>\n",
|
| 757 |
+
" <td></td>\n",
|
| 758 |
+
" </tr>\n",
|
| 759 |
+
" <tr>\n",
|
| 760 |
+
" <th>22250</th>\n",
|
| 761 |
+
" <td>chrY</td>\n",
|
| 762 |
+
" <td>2787426</td>\n",
|
| 763 |
+
" <td>C</td>\n",
|
| 764 |
+
" <td>G</td>\n",
|
| 765 |
+
" <td>Pathogenic</td>\n",
|
| 766 |
+
" <td>ClinVar</td>\n",
|
| 767 |
+
" <td>missense_variant</td>\n",
|
| 768 |
+
" <td>9739</td>\n",
|
| 769 |
+
" <td>criteria_provided,_single_submitter</td>\n",
|
| 770 |
+
" <td></td>\n",
|
| 771 |
+
" <td>train</td>\n",
|
| 772 |
+
" <td>1</td>\n",
|
| 773 |
+
" <td></td>\n",
|
| 774 |
+
" </tr>\n",
|
| 775 |
+
" <tr>\n",
|
| 776 |
+
" <th>22251</th>\n",
|
| 777 |
+
" <td>chrY</td>\n",
|
| 778 |
+
" <td>2787515</td>\n",
|
| 779 |
+
" <td>C</td>\n",
|
| 780 |
+
" <td>A</td>\n",
|
| 781 |
+
" <td>Pathogenic</td>\n",
|
| 782 |
+
" <td>ClinVar</td>\n",
|
| 783 |
+
" <td>missense_variant</td>\n",
|
| 784 |
+
" <td>492908</td>\n",
|
| 785 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 786 |
+
" <td></td>\n",
|
| 787 |
+
" <td>train</td>\n",
|
| 788 |
+
" <td>1</td>\n",
|
| 789 |
+
" <td></td>\n",
|
| 790 |
+
" </tr>\n",
|
| 791 |
+
" <tr>\n",
|
| 792 |
+
" <th>22252</th>\n",
|
| 793 |
+
" <td>chrY</td>\n",
|
| 794 |
+
" <td>2787551</td>\n",
|
| 795 |
+
" <td>C</td>\n",
|
| 796 |
+
" <td>T</td>\n",
|
| 797 |
+
" <td>Pathogenic</td>\n",
|
| 798 |
+
" <td>ClinVar</td>\n",
|
| 799 |
+
" <td>missense_variant</td>\n",
|
| 800 |
+
" <td>9754</td>\n",
|
| 801 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 802 |
+
" <td></td>\n",
|
| 803 |
+
" <td>train</td>\n",
|
| 804 |
+
" <td>1</td>\n",
|
| 805 |
+
" <td></td>\n",
|
| 806 |
+
" </tr>\n",
|
| 807 |
+
" <tr>\n",
|
| 808 |
+
" <th>22253</th>\n",
|
| 809 |
+
" <td>chrY</td>\n",
|
| 810 |
+
" <td>7063898</td>\n",
|
| 811 |
+
" <td>A</td>\n",
|
| 812 |
+
" <td>T</td>\n",
|
| 813 |
+
" <td>Pathogenic</td>\n",
|
| 814 |
+
" <td>ClinVar</td>\n",
|
| 815 |
+
" <td>missense_variant</td>\n",
|
| 816 |
+
" <td>625467</td>\n",
|
| 817 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 818 |
+
" <td></td>\n",
|
| 819 |
+
" <td>train</td>\n",
|
| 820 |
+
" <td>1</td>\n",
|
| 821 |
+
" <td></td>\n",
|
| 822 |
+
" </tr>\n",
|
| 823 |
+
" </tbody>\n",
|
| 824 |
+
"</table>\n",
|
| 825 |
+
"<p>22254 rows × 13 columns</p>\n",
|
| 826 |
+
"</div>"
|
| 827 |
+
],
|
| 828 |
+
"text/plain": [
|
| 829 |
+
" CHROM POS REF ALT LABEL SOURCE \\\n",
|
| 830 |
+
"0 chr1 976215 A G Pathogenic ClinVar \n",
|
| 831 |
+
"1 chr1 1050449 G A Pathogenic ClinVar \n",
|
| 832 |
+
"2 chr1 1050575 G C Pathogenic ClinVar \n",
|
| 833 |
+
"3 chr1 1213738 G A Pathogenic ClinVar \n",
|
| 834 |
+
"4 chr1 1232279 A G Pathogenic ClinVar \n",
|
| 835 |
+
"... ... ... .. .. ... ... \n",
|
| 836 |
+
"22249 chrY 2787412 C T Pathogenic ClinVar \n",
|
| 837 |
+
"22250 chrY 2787426 C G Pathogenic ClinVar \n",
|
| 838 |
+
"22251 chrY 2787515 C A Pathogenic ClinVar \n",
|
| 839 |
+
"22252 chrY 2787551 C T Pathogenic ClinVar \n",
|
| 840 |
+
"22253 chrY 7063898 A T Pathogenic ClinVar \n",
|
| 841 |
+
"\n",
|
| 842 |
+
" CONSEQUENCE ID \\\n",
|
| 843 |
+
"0 missense_variant 1320032 \n",
|
| 844 |
+
"1 missense_variant 1284257 \n",
|
| 845 |
+
"2 missense_variant 18241 \n",
|
| 846 |
+
"3 missense_variant 96692 \n",
|
| 847 |
+
"4 initiatior_codon_variant,missense_variant 60484 \n",
|
| 848 |
+
"... ... ... \n",
|
| 849 |
+
"22249 missense_variant 9747 \n",
|
| 850 |
+
"22250 missense_variant 9739 \n",
|
| 851 |
+
"22251 missense_variant 492908 \n",
|
| 852 |
+
"22252 missense_variant 9754 \n",
|
| 853 |
+
"22253 missense_variant 625467 \n",
|
| 854 |
+
"\n",
|
| 855 |
+
" REVIEW_STATUS GENE split \\\n",
|
| 856 |
+
"0 no_assertion_criteria_provided train \n",
|
| 857 |
+
"1 no_assertion_criteria_provided train \n",
|
| 858 |
+
"2 no_assertion_criteria_provided train \n",
|
| 859 |
+
"3 no_assertion_criteria_provided train \n",
|
| 860 |
+
"4 criteria_provided,_multiple_submitters,_no_con... train \n",
|
| 861 |
+
"... ... ... ... \n",
|
| 862 |
+
"22249 no_assertion_criteria_provided train \n",
|
| 863 |
+
"22250 criteria_provided,_single_submitter train \n",
|
| 864 |
+
"22251 no_assertion_criteria_provided train \n",
|
| 865 |
+
"22252 no_assertion_criteria_provided train \n",
|
| 866 |
+
"22253 no_assertion_criteria_provided train \n",
|
| 867 |
+
"\n",
|
| 868 |
+
" INT_LABEL GENE_ID \n",
|
| 869 |
+
"0 1 \n",
|
| 870 |
+
"1 1 \n",
|
| 871 |
+
"2 1 \n",
|
| 872 |
+
"3 1 \n",
|
| 873 |
+
"4 1 \n",
|
| 874 |
+
"... ... ... \n",
|
| 875 |
+
"22249 1 \n",
|
| 876 |
+
"22250 1 \n",
|
| 877 |
+
"22251 1 \n",
|
| 878 |
+
"22252 1 \n",
|
| 879 |
+
"22253 1 \n",
|
| 880 |
+
"\n",
|
| 881 |
+
"[22254 rows x 13 columns]"
|
| 882 |
+
]
|
| 883 |
+
},
|
| 884 |
+
"execution_count": 34,
|
| 885 |
+
"metadata": {},
|
| 886 |
+
"output_type": "execute_result"
|
| 887 |
+
}
|
| 888 |
+
],
|
| 889 |
+
"source": [
|
| 890 |
+
"clinvar_raw\n",
|
| 891 |
+
"\n",
|
| 892 |
+
"# Display updated dataset with new columns\n",
|
| 893 |
+
"if clinvar_raw is not None:\n",
|
| 894 |
+
" print(f\"📊 Dataset with new columns: {clinvar_raw.shape}\")\n",
|
| 895 |
+
" print(f\"📋 All columns: {list(clinvar_raw.columns)}\")\n",
|
| 896 |
+
" display(clinvar_raw.head())\n",
|
| 897 |
+
"else:\n",
|
| 898 |
+
" print(\"❌ No data to display\")"
|
| 899 |
+
]
|
| 900 |
+
},
|
| 901 |
+
{
|
| 902 |
+
"cell_type": "code",
|
| 903 |
+
"execution_count": null,
|
| 904 |
+
"id": "f36db716-392a-46a8-a404-d78165a4623c",
|
| 905 |
+
"metadata": {},
|
| 906 |
+
"outputs": [],
|
| 907 |
+
"source": [
|
| 908 |
+
"import pandas as pd\n",
|
| 909 |
+
"import xml.etree.ElementTree as ET\n",
|
| 910 |
+
"import os\n",
|
| 911 |
+
"\n",
|
| 912 |
+
"# Parse ClinVar XML files to extract gene information\n",
|
| 913 |
+
"# This processes each ClinVar ID and extracts gene symbols and IDs from XML records\n",
|
| 914 |
+
"\n",
|
| 915 |
+
"if clinvar_raw is not None:\n",
|
| 916 |
+
" # Load list of ClinVar IDs\n",
|
| 917 |
+
" try:\n",
|
| 918 |
+
" with open(\"Clinvar_ID.txt\", \"r\") as f:\n",
|
| 919 |
+
" clinvar_ids = [line.strip() for line in f if line.strip()]\n",
|
| 920 |
+
" \n",
|
| 921 |
+
" print(f\"📋 Processing {len(clinvar_ids)} ClinVar IDs\")\n",
|
| 922 |
+
" \n",
|
| 923 |
+
" processed_count = 0\n",
|
| 924 |
+
" error_count = 0\n",
|
| 925 |
+
" \n",
|
| 926 |
+
" # Process each ClinVar ID\n",
|
| 927 |
+
" for i, clinvar_id in enumerate(clinvar_ids):\n",
|
| 928 |
+
" if i % 100 == 0: # Progress indicator\n",
|
| 929 |
+
" print(f\"📊 Processing ID {i+1}/{len(clinvar_ids)}...\")\n",
|
| 930 |
+
" \n",
|
| 931 |
+
" try:\n",
|
| 932 |
+
" id_int = int(clinvar_id)\n",
|
| 933 |
+
" xml_path = f'data/{clinvar_id}.xml'\n",
|
| 934 |
+
" \n",
|
| 935 |
+
" # Check if XML file exists\n",
|
| 936 |
+
" if not os.path.exists(xml_path):\n",
|
| 937 |
+
" print(f\"⚠️ XML file not found: {xml_path}\")\n",
|
| 938 |
+
" continue\n",
|
| 939 |
+
" \n",
|
| 940 |
+
" # Parse XML file\n",
|
| 941 |
+
" with open(xml_path, 'r', encoding='utf-8') as file:\n",
|
| 942 |
+
" tree = ET.parse(file)\n",
|
| 943 |
+
" root = tree.getroot()\n",
|
| 944 |
+
" \n",
|
| 945 |
+
" # Check for error in XML\n",
|
| 946 |
+
" error_element = root.find(\".//error\")\n",
|
| 947 |
+
" if error_element is not None:\n",
|
| 948 |
+
" # Remove entries with errors\n",
|
| 949 |
+
" clinvar_raw = clinvar_raw[clinvar_raw[\"ID\"] != id_int]\n",
|
| 950 |
+
" error_count += 1\n",
|
| 951 |
+
" continue\n",
|
| 952 |
+
" \n",
|
| 953 |
+
" # Extract gene information\n",
|
| 954 |
+
" gene_names = []\n",
|
| 955 |
+
" gene_ids = []\n",
|
| 956 |
+
" \n",
|
| 957 |
+
" for gene in root.findall(\".//genes/gene\"):\n",
|
| 958 |
+
" symbol = gene.findtext(\"symbol\")\n",
|
| 959 |
+
" gene_id_data = gene.findtext(\"GeneID\")\n",
|
| 960 |
+
" \n",
|
| 961 |
+
" if symbol:\n",
|
| 962 |
+
" gene_names.append(symbol)\n",
|
| 963 |
+
" if gene_id_data:\n",
|
| 964 |
+
" gene_ids.append(gene_id_data)\n",
|
| 965 |
+
" \n",
|
| 966 |
+
" # Join multiple entries with commas\n",
|
| 967 |
+
" gene_name_str = \", \".join(gene_names) if gene_names else \"\"\n",
|
| 968 |
+
" gene_id_str = \", \".join(gene_ids) if gene_ids else \"\"\n",
|
| 969 |
+
" \n",
|
| 970 |
+
" # Update DataFrame\n",
|
| 971 |
+
" mask = clinvar_raw[\"ID\"] == id_int\n",
|
| 972 |
+
" if mask.any():\n",
|
| 973 |
+
" clinvar_raw.loc[mask, \"GENE\"] = gene_name_str\n",
|
| 974 |
+
" clinvar_raw.loc[mask, \"GENE_ID\"] = gene_id_str\n",
|
| 975 |
+
" processed_count += 1\n",
|
| 976 |
+
" \n",
|
| 977 |
+
" except ET.ParseError as e:\n",
|
| 978 |
+
" print(f\"⚠️ XML parsing error for {clinvar_id}: {e}\")\n",
|
| 979 |
+
" error_count += 1\n",
|
| 980 |
+
" except ValueError as e:\n",
|
| 981 |
+
" print(f\"⚠️ Invalid ClinVar ID {clinvar_id}: {e}\")\n",
|
| 982 |
+
" error_count += 1\n",
|
| 983 |
+
" except Exception as e:\n",
|
| 984 |
+
" print(f\"⚠️ Unexpected error processing {clinvar_id}: {e}\")\n",
|
| 985 |
+
" error_count += 1\n",
|
| 986 |
+
" \n",
|
| 987 |
+
" print(f\"\\n✅ Processing complete:\")\n",
|
| 988 |
+
" print(f\" 📊 Successfully processed: {processed_count}\")\n",
|
| 989 |
+
" print(f\" ❌ Errors encountered: {error_count}\")\n",
|
| 990 |
+
" print(f\" 📋 Final dataset shape: {clinvar_raw.shape}\")\n",
|
| 991 |
+
" \n",
|
| 992 |
+
" except FileNotFoundError:\n",
|
| 993 |
+
" print(\"❌ Error: Clinvar_ID.txt not found\")\n",
|
| 994 |
+
" print(\"Please run the ID extraction step first\")\n",
|
| 995 |
+
" except Exception as e:\n",
|
| 996 |
+
" print(f\"❌ Error during XML processing: {e}\")\n",
|
| 997 |
+
"else:\n",
|
| 998 |
+
" print(\"⚠️ Cannot process XML files - ClinVar data not loaded\")"
|
| 999 |
+
]
|
| 1000 |
+
},
|
| 1001 |
+
{
|
| 1002 |
+
"cell_type": "code",
|
| 1003 |
+
"execution_count": null,
|
| 1004 |
+
"id": "ae0c9d8b-1b12-40a4-82ec-c3452e9dda90",
|
| 1005 |
+
"metadata": {},
|
| 1006 |
+
"outputs": [
|
| 1007 |
+
{
|
| 1008 |
+
"data": {
|
| 1009 |
+
"text/html": [
|
| 1010 |
+
"<div>\n",
|
| 1011 |
+
"<style scoped>\n",
|
| 1012 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 1013 |
+
" vertical-align: middle;\n",
|
| 1014 |
+
" }\n",
|
| 1015 |
+
"\n",
|
| 1016 |
+
" .dataframe tbody tr th {\n",
|
| 1017 |
+
" vertical-align: top;\n",
|
| 1018 |
+
" }\n",
|
| 1019 |
+
"\n",
|
| 1020 |
+
" .dataframe thead th {\n",
|
| 1021 |
+
" text-align: right;\n",
|
| 1022 |
+
" }\n",
|
| 1023 |
+
"</style>\n",
|
| 1024 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 1025 |
+
" <thead>\n",
|
| 1026 |
+
" <tr style=\"text-align: right;\">\n",
|
| 1027 |
+
" <th></th>\n",
|
| 1028 |
+
" <th>CHROM</th>\n",
|
| 1029 |
+
" <th>POS</th>\n",
|
| 1030 |
+
" <th>REF</th>\n",
|
| 1031 |
+
" <th>ALT</th>\n",
|
| 1032 |
+
" <th>LABEL</th>\n",
|
| 1033 |
+
" <th>SOURCE</th>\n",
|
| 1034 |
+
" <th>CONSEQUENCE</th>\n",
|
| 1035 |
+
" <th>ID</th>\n",
|
| 1036 |
+
" <th>REVIEW_STATUS</th>\n",
|
| 1037 |
+
" <th>GENE</th>\n",
|
| 1038 |
+
" <th>split</th>\n",
|
| 1039 |
+
" <th>INT_LABEL</th>\n",
|
| 1040 |
+
" <th>GENE_ID</th>\n",
|
| 1041 |
+
" </tr>\n",
|
| 1042 |
+
" </thead>\n",
|
| 1043 |
+
" <tbody>\n",
|
| 1044 |
+
" <tr>\n",
|
| 1045 |
+
" <th>0</th>\n",
|
| 1046 |
+
" <td>chr1</td>\n",
|
| 1047 |
+
" <td>976215</td>\n",
|
| 1048 |
+
" <td>A</td>\n",
|
| 1049 |
+
" <td>G</td>\n",
|
| 1050 |
+
" <td>Pathogenic</td>\n",
|
| 1051 |
+
" <td>ClinVar</td>\n",
|
| 1052 |
+
" <td>missense_variant</td>\n",
|
| 1053 |
+
" <td>1320032</td>\n",
|
| 1054 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1055 |
+
" <td>PERM1</td>\n",
|
| 1056 |
+
" <td>train</td>\n",
|
| 1057 |
+
" <td>1</td>\n",
|
| 1058 |
+
" <td>84808</td>\n",
|
| 1059 |
+
" </tr>\n",
|
| 1060 |
+
" <tr>\n",
|
| 1061 |
+
" <th>1</th>\n",
|
| 1062 |
+
" <td>chr1</td>\n",
|
| 1063 |
+
" <td>1050449</td>\n",
|
| 1064 |
+
" <td>G</td>\n",
|
| 1065 |
+
" <td>A</td>\n",
|
| 1066 |
+
" <td>Pathogenic</td>\n",
|
| 1067 |
+
" <td>ClinVar</td>\n",
|
| 1068 |
+
" <td>missense_variant</td>\n",
|
| 1069 |
+
" <td>1284257</td>\n",
|
| 1070 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1071 |
+
" <td>AGRN</td>\n",
|
| 1072 |
+
" <td>train</td>\n",
|
| 1073 |
+
" <td>1</td>\n",
|
| 1074 |
+
" <td>375790</td>\n",
|
| 1075 |
+
" </tr>\n",
|
| 1076 |
+
" <tr>\n",
|
| 1077 |
+
" <th>2</th>\n",
|
| 1078 |
+
" <td>chr1</td>\n",
|
| 1079 |
+
" <td>1050575</td>\n",
|
| 1080 |
+
" <td>G</td>\n",
|
| 1081 |
+
" <td>C</td>\n",
|
| 1082 |
+
" <td>Pathogenic</td>\n",
|
| 1083 |
+
" <td>ClinVar</td>\n",
|
| 1084 |
+
" <td>missense_variant</td>\n",
|
| 1085 |
+
" <td>18241</td>\n",
|
| 1086 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1087 |
+
" <td>AGRN</td>\n",
|
| 1088 |
+
" <td>train</td>\n",
|
| 1089 |
+
" <td>1</td>\n",
|
| 1090 |
+
" <td>375790</td>\n",
|
| 1091 |
+
" </tr>\n",
|
| 1092 |
+
" <tr>\n",
|
| 1093 |
+
" <th>3</th>\n",
|
| 1094 |
+
" <td>chr1</td>\n",
|
| 1095 |
+
" <td>1213738</td>\n",
|
| 1096 |
+
" <td>G</td>\n",
|
| 1097 |
+
" <td>A</td>\n",
|
| 1098 |
+
" <td>Pathogenic</td>\n",
|
| 1099 |
+
" <td>ClinVar</td>\n",
|
| 1100 |
+
" <td>missense_variant</td>\n",
|
| 1101 |
+
" <td>96692</td>\n",
|
| 1102 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1103 |
+
" <td>TNFRSF4</td>\n",
|
| 1104 |
+
" <td>train</td>\n",
|
| 1105 |
+
" <td>1</td>\n",
|
| 1106 |
+
" <td>7293</td>\n",
|
| 1107 |
+
" </tr>\n",
|
| 1108 |
+
" <tr>\n",
|
| 1109 |
+
" <th>4</th>\n",
|
| 1110 |
+
" <td>chr1</td>\n",
|
| 1111 |
+
" <td>1232279</td>\n",
|
| 1112 |
+
" <td>A</td>\n",
|
| 1113 |
+
" <td>G</td>\n",
|
| 1114 |
+
" <td>Pathogenic</td>\n",
|
| 1115 |
+
" <td>ClinVar</td>\n",
|
| 1116 |
+
" <td>initiatior_codon_variant,missense_variant</td>\n",
|
| 1117 |
+
" <td>60484</td>\n",
|
| 1118 |
+
" <td>criteria_provided,_multiple_submitters,_no_con...</td>\n",
|
| 1119 |
+
" <td>B3GALT6</td>\n",
|
| 1120 |
+
" <td>train</td>\n",
|
| 1121 |
+
" <td>1</td>\n",
|
| 1122 |
+
" <td>126792</td>\n",
|
| 1123 |
+
" </tr>\n",
|
| 1124 |
+
" <tr>\n",
|
| 1125 |
+
" <th>...</th>\n",
|
| 1126 |
+
" <td>...</td>\n",
|
| 1127 |
+
" <td>...</td>\n",
|
| 1128 |
+
" <td>...</td>\n",
|
| 1129 |
+
" <td>...</td>\n",
|
| 1130 |
+
" <td>...</td>\n",
|
| 1131 |
+
" <td>...</td>\n",
|
| 1132 |
+
" <td>...</td>\n",
|
| 1133 |
+
" <td>...</td>\n",
|
| 1134 |
+
" <td>...</td>\n",
|
| 1135 |
+
" <td>...</td>\n",
|
| 1136 |
+
" <td>...</td>\n",
|
| 1137 |
+
" <td>...</td>\n",
|
| 1138 |
+
" <td>...</td>\n",
|
| 1139 |
+
" </tr>\n",
|
| 1140 |
+
" <tr>\n",
|
| 1141 |
+
" <th>22249</th>\n",
|
| 1142 |
+
" <td>chrY</td>\n",
|
| 1143 |
+
" <td>2787412</td>\n",
|
| 1144 |
+
" <td>C</td>\n",
|
| 1145 |
+
" <td>T</td>\n",
|
| 1146 |
+
" <td>Pathogenic</td>\n",
|
| 1147 |
+
" <td>ClinVar</td>\n",
|
| 1148 |
+
" <td>missense_variant</td>\n",
|
| 1149 |
+
" <td>9747</td>\n",
|
| 1150 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1151 |
+
" <td>SRY</td>\n",
|
| 1152 |
+
" <td>train</td>\n",
|
| 1153 |
+
" <td>1</td>\n",
|
| 1154 |
+
" <td>6736</td>\n",
|
| 1155 |
+
" </tr>\n",
|
| 1156 |
+
" <tr>\n",
|
| 1157 |
+
" <th>22250</th>\n",
|
| 1158 |
+
" <td>chrY</td>\n",
|
| 1159 |
+
" <td>2787426</td>\n",
|
| 1160 |
+
" <td>C</td>\n",
|
| 1161 |
+
" <td>G</td>\n",
|
| 1162 |
+
" <td>Pathogenic</td>\n",
|
| 1163 |
+
" <td>ClinVar</td>\n",
|
| 1164 |
+
" <td>missense_variant</td>\n",
|
| 1165 |
+
" <td>9739</td>\n",
|
| 1166 |
+
" <td>criteria_provided,_single_submitter</td>\n",
|
| 1167 |
+
" <td>SRY</td>\n",
|
| 1168 |
+
" <td>train</td>\n",
|
| 1169 |
+
" <td>1</td>\n",
|
| 1170 |
+
" <td>6736</td>\n",
|
| 1171 |
+
" </tr>\n",
|
| 1172 |
+
" <tr>\n",
|
| 1173 |
+
" <th>22251</th>\n",
|
| 1174 |
+
" <td>chrY</td>\n",
|
| 1175 |
+
" <td>2787515</td>\n",
|
| 1176 |
+
" <td>C</td>\n",
|
| 1177 |
+
" <td>A</td>\n",
|
| 1178 |
+
" <td>Pathogenic</td>\n",
|
| 1179 |
+
" <td>ClinVar</td>\n",
|
| 1180 |
+
" <td>missense_variant</td>\n",
|
| 1181 |
+
" <td>492908</td>\n",
|
| 1182 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1183 |
+
" <td>SRY</td>\n",
|
| 1184 |
+
" <td>train</td>\n",
|
| 1185 |
+
" <td>1</td>\n",
|
| 1186 |
+
" <td>6736</td>\n",
|
| 1187 |
+
" </tr>\n",
|
| 1188 |
+
" <tr>\n",
|
| 1189 |
+
" <th>22252</th>\n",
|
| 1190 |
+
" <td>chrY</td>\n",
|
| 1191 |
+
" <td>2787551</td>\n",
|
| 1192 |
+
" <td>C</td>\n",
|
| 1193 |
+
" <td>T</td>\n",
|
| 1194 |
+
" <td>Pathogenic</td>\n",
|
| 1195 |
+
" <td>ClinVar</td>\n",
|
| 1196 |
+
" <td>missense_variant</td>\n",
|
| 1197 |
+
" <td>9754</td>\n",
|
| 1198 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1199 |
+
" <td>SRY</td>\n",
|
| 1200 |
+
" <td>train</td>\n",
|
| 1201 |
+
" <td>1</td>\n",
|
| 1202 |
+
" <td>6736</td>\n",
|
| 1203 |
+
" </tr>\n",
|
| 1204 |
+
" <tr>\n",
|
| 1205 |
+
" <th>22253</th>\n",
|
| 1206 |
+
" <td>chrY</td>\n",
|
| 1207 |
+
" <td>7063898</td>\n",
|
| 1208 |
+
" <td>A</td>\n",
|
| 1209 |
+
" <td>T</td>\n",
|
| 1210 |
+
" <td>Pathogenic</td>\n",
|
| 1211 |
+
" <td>ClinVar</td>\n",
|
| 1212 |
+
" <td>missense_variant</td>\n",
|
| 1213 |
+
" <td>625467</td>\n",
|
| 1214 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1215 |
+
" <td>LOC126057105, TBL1Y</td>\n",
|
| 1216 |
+
" <td>train</td>\n",
|
| 1217 |
+
" <td>1</td>\n",
|
| 1218 |
+
" <td>126057105, 90665</td>\n",
|
| 1219 |
+
" </tr>\n",
|
| 1220 |
+
" </tbody>\n",
|
| 1221 |
+
"</table>\n",
|
| 1222 |
+
"<p>22150 rows × 13 columns</p>\n",
|
| 1223 |
+
"</div>"
|
| 1224 |
+
],
|
| 1225 |
+
"text/plain": [
|
| 1226 |
+
" CHROM POS REF ALT LABEL SOURCE \\\n",
|
| 1227 |
+
"0 chr1 976215 A G Pathogenic ClinVar \n",
|
| 1228 |
+
"1 chr1 1050449 G A Pathogenic ClinVar \n",
|
| 1229 |
+
"2 chr1 1050575 G C Pathogenic ClinVar \n",
|
| 1230 |
+
"3 chr1 1213738 G A Pathogenic ClinVar \n",
|
| 1231 |
+
"4 chr1 1232279 A G Pathogenic ClinVar \n",
|
| 1232 |
+
"... ... ... .. .. ... ... \n",
|
| 1233 |
+
"22249 chrY 2787412 C T Pathogenic ClinVar \n",
|
| 1234 |
+
"22250 chrY 2787426 C G Pathogenic ClinVar \n",
|
| 1235 |
+
"22251 chrY 2787515 C A Pathogenic ClinVar \n",
|
| 1236 |
+
"22252 chrY 2787551 C T Pathogenic ClinVar \n",
|
| 1237 |
+
"22253 chrY 7063898 A T Pathogenic ClinVar \n",
|
| 1238 |
+
"\n",
|
| 1239 |
+
" CONSEQUENCE ID \\\n",
|
| 1240 |
+
"0 missense_variant 1320032 \n",
|
| 1241 |
+
"1 missense_variant 1284257 \n",
|
| 1242 |
+
"2 missense_variant 18241 \n",
|
| 1243 |
+
"3 missense_variant 96692 \n",
|
| 1244 |
+
"4 initiatior_codon_variant,missense_variant 60484 \n",
|
| 1245 |
+
"... ... ... \n",
|
| 1246 |
+
"22249 missense_variant 9747 \n",
|
| 1247 |
+
"22250 missense_variant 9739 \n",
|
| 1248 |
+
"22251 missense_variant 492908 \n",
|
| 1249 |
+
"22252 missense_variant 9754 \n",
|
| 1250 |
+
"22253 missense_variant 625467 \n",
|
| 1251 |
+
"\n",
|
| 1252 |
+
" REVIEW_STATUS GENE \\\n",
|
| 1253 |
+
"0 no_assertion_criteria_provided PERM1 \n",
|
| 1254 |
+
"1 no_assertion_criteria_provided AGRN \n",
|
| 1255 |
+
"2 no_assertion_criteria_provided AGRN \n",
|
| 1256 |
+
"3 no_assertion_criteria_provided TNFRSF4 \n",
|
| 1257 |
+
"4 criteria_provided,_multiple_submitters,_no_con... B3GALT6 \n",
|
| 1258 |
+
"... ... ... \n",
|
| 1259 |
+
"22249 no_assertion_criteria_provided SRY \n",
|
| 1260 |
+
"22250 criteria_provided,_single_submitter SRY \n",
|
| 1261 |
+
"22251 no_assertion_criteria_provided SRY \n",
|
| 1262 |
+
"22252 no_assertion_criteria_provided SRY \n",
|
| 1263 |
+
"22253 no_assertion_criteria_provided LOC126057105, TBL1Y \n",
|
| 1264 |
+
"\n",
|
| 1265 |
+
" split INT_LABEL GENE_ID \n",
|
| 1266 |
+
"0 train 1 84808 \n",
|
| 1267 |
+
"1 train 1 375790 \n",
|
| 1268 |
+
"2 train 1 375790 \n",
|
| 1269 |
+
"3 train 1 7293 \n",
|
| 1270 |
+
"4 train 1 126792 \n",
|
| 1271 |
+
"... ... ... ... \n",
|
| 1272 |
+
"22249 train 1 6736 \n",
|
| 1273 |
+
"22250 train 1 6736 \n",
|
| 1274 |
+
"22251 train 1 6736 \n",
|
| 1275 |
+
"22252 train 1 6736 \n",
|
| 1276 |
+
"22253 train 1 126057105, 90665 \n",
|
| 1277 |
+
"\n",
|
| 1278 |
+
"[22150 rows x 13 columns]"
|
| 1279 |
+
]
|
| 1280 |
+
},
|
| 1281 |
+
"execution_count": 39,
|
| 1282 |
+
"metadata": {},
|
| 1283 |
+
"output_type": "execute_result"
|
| 1284 |
+
}
|
| 1285 |
+
],
|
| 1286 |
+
"source": [
|
| 1287 |
+
"clinvar_raw\n",
|
| 1288 |
+
"\n",
|
| 1289 |
+
"# Display the dataset with extracted gene information\n",
|
| 1290 |
+
"if clinvar_raw is not None:\n",
|
| 1291 |
+
" print(f\"📊 Dataset after gene extraction: {clinvar_raw.shape}\")\n",
|
| 1292 |
+
" \n",
|
| 1293 |
+
" # Show statistics\n",
|
| 1294 |
+
" gene_filled = (clinvar_raw['GENE'] != '').sum()\n",
|
| 1295 |
+
" gene_id_filled = (clinvar_raw['GENE_ID'] != '').sum()\n",
|
| 1296 |
+
" \n",
|
| 1297 |
+
" print(f\"📋 Entries with gene names: {gene_filled} ({gene_filled/len(clinvar_raw)*100:.1f}%)\")\n",
|
| 1298 |
+
" print(f\"📋 Entries with gene IDs: {gene_id_filled} ({gene_id_filled/len(clinvar_raw)*100:.1f}%)\")\n",
|
| 1299 |
+
" \n",
|
| 1300 |
+
" # Show sample data\n",
|
| 1301 |
+
" display(clinvar_raw.head(10))\n",
|
| 1302 |
+
"else:\n",
|
| 1303 |
+
" print(\"❌ No data to display\")"
|
| 1304 |
+
]
|
| 1305 |
+
},
|
| 1306 |
+
{
|
| 1307 |
+
"cell_type": "markdown",
|
| 1308 |
+
"id": "b76910bd-aa86-4943-a0f2-dcf9756ad81d",
|
| 1309 |
+
"metadata": {},
|
| 1310 |
+
"source": [
|
| 1311 |
+
"## Disease/Phenotype Information Extraction\n",
|
| 1312 |
+
"\n",
|
| 1313 |
+
"This section extracts disease and phenotype information from the ClinVar XML records. Each variant may be associated with multiple diseases, so the data is expanded to create one row per variant-disease combination.\n",
|
| 1314 |
+
"\n",
|
| 1315 |
+
"### Putting in the Disease Name"
|
| 1316 |
+
]
|
| 1317 |
+
},
|
| 1318 |
+
{
|
| 1319 |
+
"cell_type": "code",
|
| 1320 |
+
"execution_count": null,
|
| 1321 |
+
"id": "54ccd972-5804-4d63-9012-5531034d2b60",
|
| 1322 |
+
"metadata": {},
|
| 1323 |
+
"outputs": [],
|
| 1324 |
+
"source": [
|
| 1325 |
+
"# Extract disease/phenotype information from ClinVar XML files\n",
|
| 1326 |
+
"# This creates multiple rows for variants associated with multiple diseases\n",
|
| 1327 |
+
"\n",
|
| 1328 |
+
"if clinvar_raw is not None:\n",
|
| 1329 |
+
" try:\n",
|
| 1330 |
+
" # Load ClinVar IDs\n",
|
| 1331 |
+
" with open(\"Clinvar_ID.txt\", \"r\") as f:\n",
|
| 1332 |
+
" clinvar_ids = [line.strip() for line in f if line.strip()]\n",
|
| 1333 |
+
" \n",
|
| 1334 |
+
" print(f\"📋 Processing {len(clinvar_ids)} ClinVar IDs for disease extraction\")\n",
|
| 1335 |
+
" \n",
|
| 1336 |
+
" # Ensure ID column is integer type\n",
|
| 1337 |
+
" clinvar_raw[\"ID\"] = clinvar_raw[\"ID\"].astype(int)\n",
|
| 1338 |
+
" \n",
|
| 1339 |
+
" # Create new DataFrame to store expanded data\n",
|
| 1340 |
+
" clinvar_data = pd.DataFrame(columns=clinvar_raw.columns.tolist() + [\"Disease\"])\n",
|
| 1341 |
+
" \n",
|
| 1342 |
+
" processed_count = 0\n",
|
| 1343 |
+
" disease_count = 0\n",
|
| 1344 |
+
" \n",
|
| 1345 |
+
" # Process each ClinVar ID\n",
|
| 1346 |
+
" for i, clinvar_id in enumerate(clinvar_ids):\n",
|
| 1347 |
+
" if i % 100 == 0: # Progress indicator\n",
|
| 1348 |
+
" print(f\"📊 Processing disease info {i+1}/{len(clinvar_ids)}...\")\n",
|
| 1349 |
+
" \n",
|
| 1350 |
+
" try:\n",
|
| 1351 |
+
" id_int = int(clinvar_id)\n",
|
| 1352 |
+
" xml_path = f\"data/{clinvar_id}.xml\"\n",
|
| 1353 |
+
" \n",
|
| 1354 |
+
" if not os.path.exists(xml_path):\n",
|
| 1355 |
+
" continue\n",
|
| 1356 |
+
" \n",
|
| 1357 |
+
" # Parse XML\n",
|
| 1358 |
+
" tree = ET.parse(xml_path)\n",
|
| 1359 |
+
" root = tree.getroot()\n",
|
| 1360 |
+
" \n",
|
| 1361 |
+
" # Extract all trait names (diseases/phenotypes)\n",
|
| 1362 |
+
" trait_names = []\n",
|
| 1363 |
+
" for trait in root.findall(\".//trait\"):\n",
|
| 1364 |
+
" trait_name = trait.findtext(\"trait_name\")\n",
|
| 1365 |
+
" if trait_name:\n",
|
| 1366 |
+
" trait_names.append(trait_name)\n",
|
| 1367 |
+
" \n",
|
| 1368 |
+
" # Filter out 'not provided' if other traits exist\n",
|
| 1369 |
+
" filtered_traits = [t for t in trait_names if t.lower() != \"not provided\"]\n",
|
| 1370 |
+
" if not filtered_traits and \"not provided\" in [t.lower() for t in trait_names]:\n",
|
| 1371 |
+
" filtered_traits = [\"not provided\"]\n",
|
| 1372 |
+
" \n",
|
| 1373 |
+
" # If no traits found, use empty string\n",
|
| 1374 |
+
" if not filtered_traits:\n",
|
| 1375 |
+
" filtered_traits = [\"\"]\n",
|
| 1376 |
+
" \n",
|
| 1377 |
+
" # Create one row for each disease/trait\n",
|
| 1378 |
+
" base_row = clinvar_raw[clinvar_raw[\"ID\"] == id_int]\n",
|
| 1379 |
+
" if not base_row.empty:\n",
|
| 1380 |
+
" for disease_name in filtered_traits:\n",
|
| 1381 |
+
" new_row = base_row.copy()\n",
|
| 1382 |
+
" new_row[\"Disease\"] = disease_name\n",
|
| 1383 |
+
" clinvar_data = pd.concat([clinvar_data, new_row], ignore_index=True)\n",
|
| 1384 |
+
" disease_count += 1\n",
|
| 1385 |
+
" processed_count += 1\n",
|
| 1386 |
+
" \n",
|
| 1387 |
+
" except ET.ParseError as e:\n",
|
| 1388 |
+
" print(f\"⚠️ XML parsing error for {clinvar_id}: {e}\")\n",
|
| 1389 |
+
" except Exception as e:\n",
|
| 1390 |
+
" print(f\"⚠️ Error processing {clinvar_id}: {e}\")\n",
|
| 1391 |
+
" \n",
|
| 1392 |
+
" print(f\"\\n✅ Disease extraction complete:\")\n",
|
| 1393 |
+
" print(f\" 📊 Variants processed: {processed_count}\")\n",
|
| 1394 |
+
" print(f\" 🔬 Total variant-disease pairs: {disease_count}\")\n",
|
| 1395 |
+
" print(f\" 📋 Final dataset shape: {clinvar_data.shape}\")\n",
|
| 1396 |
+
" \n",
|
| 1397 |
+
" # Save intermediate results\n",
|
| 1398 |
+
" clinvar_data.to_csv(\"clinvar_with_disease.csv\", sep='\\t', index=False)\n",
|
| 1399 |
+
" print(\"💾 Saved results to clinvar_with_disease.csv\")\n",
|
| 1400 |
+
" \n",
|
| 1401 |
+
" except FileNotFoundError:\n",
|
| 1402 |
+
" print(\"❌ Error: Required files not found\")\n",
|
| 1403 |
+
" print(\"Please ensure Clinvar_ID.txt exists and XML files are downloaded\")\n",
|
| 1404 |
+
" clinvar_data = None\n",
|
| 1405 |
+
" except Exception as e:\n",
|
| 1406 |
+
" print(f\"❌ Error during disease extraction: {e}\")\n",
|
| 1407 |
+
" clinvar_data = None\n",
|
| 1408 |
+
"else:\n",
|
| 1409 |
+
" print(\"⚠️ Cannot extract diseases - ClinVar data not loaded\")\n",
|
| 1410 |
+
" clinvar_data = None"
|
| 1411 |
+
]
|
| 1412 |
+
},
|
| 1413 |
+
{
|
| 1414 |
+
"cell_type": "code",
|
| 1415 |
+
"execution_count": null,
|
| 1416 |
+
"id": "277445cd-72b9-44a4-a257-49cd3202e501",
|
| 1417 |
+
"metadata": {
|
| 1418 |
+
"scrolled": true
|
| 1419 |
+
},
|
| 1420 |
+
"outputs": [
|
| 1421 |
+
{
|
| 1422 |
+
"data": {
|
| 1423 |
+
"text/html": [
|
| 1424 |
+
"<div>\n",
|
| 1425 |
+
"<style scoped>\n",
|
| 1426 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 1427 |
+
" vertical-align: middle;\n",
|
| 1428 |
+
" }\n",
|
| 1429 |
+
"\n",
|
| 1430 |
+
" .dataframe tbody tr th {\n",
|
| 1431 |
+
" vertical-align: top;\n",
|
| 1432 |
+
" }\n",
|
| 1433 |
+
"\n",
|
| 1434 |
+
" .dataframe thead th {\n",
|
| 1435 |
+
" text-align: right;\n",
|
| 1436 |
+
" }\n",
|
| 1437 |
+
"</style>\n",
|
| 1438 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 1439 |
+
" <thead>\n",
|
| 1440 |
+
" <tr style=\"text-align: right;\">\n",
|
| 1441 |
+
" <th></th>\n",
|
| 1442 |
+
" <th>CHROM</th>\n",
|
| 1443 |
+
" <th>POS</th>\n",
|
| 1444 |
+
" <th>REF</th>\n",
|
| 1445 |
+
" <th>ALT</th>\n",
|
| 1446 |
+
" <th>LABEL</th>\n",
|
| 1447 |
+
" <th>SOURCE</th>\n",
|
| 1448 |
+
" <th>CONSEQUENCE</th>\n",
|
| 1449 |
+
" <th>ID</th>\n",
|
| 1450 |
+
" <th>REVIEW_STATUS</th>\n",
|
| 1451 |
+
" <th>GENE</th>\n",
|
| 1452 |
+
" <th>split</th>\n",
|
| 1453 |
+
" <th>INT_LABEL</th>\n",
|
| 1454 |
+
" <th>GENE_ID</th>\n",
|
| 1455 |
+
" <th>Disease</th>\n",
|
| 1456 |
+
" </tr>\n",
|
| 1457 |
+
" </thead>\n",
|
| 1458 |
+
" <tbody>\n",
|
| 1459 |
+
" <tr>\n",
|
| 1460 |
+
" <th>0</th>\n",
|
| 1461 |
+
" <td>chr1</td>\n",
|
| 1462 |
+
" <td>976215</td>\n",
|
| 1463 |
+
" <td>A</td>\n",
|
| 1464 |
+
" <td>G</td>\n",
|
| 1465 |
+
" <td>Pathogenic</td>\n",
|
| 1466 |
+
" <td>ClinVar</td>\n",
|
| 1467 |
+
" <td>missense_variant</td>\n",
|
| 1468 |
+
" <td>1320032</td>\n",
|
| 1469 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1470 |
+
" <td>PERM1</td>\n",
|
| 1471 |
+
" <td>train</td>\n",
|
| 1472 |
+
" <td>1</td>\n",
|
| 1473 |
+
" <td>84808</td>\n",
|
| 1474 |
+
" <td>Renal tubular epithelial cell apoptosis</td>\n",
|
| 1475 |
+
" </tr>\n",
|
| 1476 |
+
" <tr>\n",
|
| 1477 |
+
" <th>1</th>\n",
|
| 1478 |
+
" <td>chr1</td>\n",
|
| 1479 |
+
" <td>976215</td>\n",
|
| 1480 |
+
" <td>A</td>\n",
|
| 1481 |
+
" <td>G</td>\n",
|
| 1482 |
+
" <td>Pathogenic</td>\n",
|
| 1483 |
+
" <td>ClinVar</td>\n",
|
| 1484 |
+
" <td>missense_variant</td>\n",
|
| 1485 |
+
" <td>1320032</td>\n",
|
| 1486 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1487 |
+
" <td>PERM1</td>\n",
|
| 1488 |
+
" <td>train</td>\n",
|
| 1489 |
+
" <td>1</td>\n",
|
| 1490 |
+
" <td>84808</td>\n",
|
| 1491 |
+
" <td>Neutrophil inclusion bodies</td>\n",
|
| 1492 |
+
" </tr>\n",
|
| 1493 |
+
" <tr>\n",
|
| 1494 |
+
" <th>2</th>\n",
|
| 1495 |
+
" <td>chr1</td>\n",
|
| 1496 |
+
" <td>1050449</td>\n",
|
| 1497 |
+
" <td>G</td>\n",
|
| 1498 |
+
" <td>A</td>\n",
|
| 1499 |
+
" <td>Pathogenic</td>\n",
|
| 1500 |
+
" <td>ClinVar</td>\n",
|
| 1501 |
+
" <td>missense_variant</td>\n",
|
| 1502 |
+
" <td>1284257</td>\n",
|
| 1503 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1504 |
+
" <td>AGRN</td>\n",
|
| 1505 |
+
" <td>train</td>\n",
|
| 1506 |
+
" <td>1</td>\n",
|
| 1507 |
+
" <td>375790</td>\n",
|
| 1508 |
+
" <td>Congenital myasthenic syndrome 8</td>\n",
|
| 1509 |
+
" </tr>\n",
|
| 1510 |
+
" <tr>\n",
|
| 1511 |
+
" <th>3</th>\n",
|
| 1512 |
+
" <td>chr1</td>\n",
|
| 1513 |
+
" <td>1050575</td>\n",
|
| 1514 |
+
" <td>G</td>\n",
|
| 1515 |
+
" <td>C</td>\n",
|
| 1516 |
+
" <td>Pathogenic</td>\n",
|
| 1517 |
+
" <td>ClinVar</td>\n",
|
| 1518 |
+
" <td>missense_variant</td>\n",
|
| 1519 |
+
" <td>18241</td>\n",
|
| 1520 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1521 |
+
" <td>AGRN</td>\n",
|
| 1522 |
+
" <td>train</td>\n",
|
| 1523 |
+
" <td>1</td>\n",
|
| 1524 |
+
" <td>375790</td>\n",
|
| 1525 |
+
" <td>Congenital myasthenic syndrome 8</td>\n",
|
| 1526 |
+
" </tr>\n",
|
| 1527 |
+
" <tr>\n",
|
| 1528 |
+
" <th>4</th>\n",
|
| 1529 |
+
" <td>chr1</td>\n",
|
| 1530 |
+
" <td>1213738</td>\n",
|
| 1531 |
+
" <td>G</td>\n",
|
| 1532 |
+
" <td>A</td>\n",
|
| 1533 |
+
" <td>Pathogenic</td>\n",
|
| 1534 |
+
" <td>ClinVar</td>\n",
|
| 1535 |
+
" <td>missense_variant</td>\n",
|
| 1536 |
+
" <td>96692</td>\n",
|
| 1537 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1538 |
+
" <td>TNFRSF4</td>\n",
|
| 1539 |
+
" <td>train</td>\n",
|
| 1540 |
+
" <td>1</td>\n",
|
| 1541 |
+
" <td>7293</td>\n",
|
| 1542 |
+
" <td>Combined immunodeficiency due to OX40 deficiency</td>\n",
|
| 1543 |
+
" </tr>\n",
|
| 1544 |
+
" <tr>\n",
|
| 1545 |
+
" <th>...</th>\n",
|
| 1546 |
+
" <td>...</td>\n",
|
| 1547 |
+
" <td>...</td>\n",
|
| 1548 |
+
" <td>...</td>\n",
|
| 1549 |
+
" <td>...</td>\n",
|
| 1550 |
+
" <td>...</td>\n",
|
| 1551 |
+
" <td>...</td>\n",
|
| 1552 |
+
" <td>...</td>\n",
|
| 1553 |
+
" <td>...</td>\n",
|
| 1554 |
+
" <td>...</td>\n",
|
| 1555 |
+
" <td>...</td>\n",
|
| 1556 |
+
" <td>...</td>\n",
|
| 1557 |
+
" <td>...</td>\n",
|
| 1558 |
+
" <td>...</td>\n",
|
| 1559 |
+
" <td>...</td>\n",
|
| 1560 |
+
" </tr>\n",
|
| 1561 |
+
" <tr>\n",
|
| 1562 |
+
" <th>32680</th>\n",
|
| 1563 |
+
" <td>chrY</td>\n",
|
| 1564 |
+
" <td>2787412</td>\n",
|
| 1565 |
+
" <td>C</td>\n",
|
| 1566 |
+
" <td>T</td>\n",
|
| 1567 |
+
" <td>Pathogenic</td>\n",
|
| 1568 |
+
" <td>ClinVar</td>\n",
|
| 1569 |
+
" <td>missense_variant</td>\n",
|
| 1570 |
+
" <td>9747</td>\n",
|
| 1571 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1572 |
+
" <td>SRY</td>\n",
|
| 1573 |
+
" <td>train</td>\n",
|
| 1574 |
+
" <td>1</td>\n",
|
| 1575 |
+
" <td>6736</td>\n",
|
| 1576 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 1577 |
+
" </tr>\n",
|
| 1578 |
+
" <tr>\n",
|
| 1579 |
+
" <th>32681</th>\n",
|
| 1580 |
+
" <td>chrY</td>\n",
|
| 1581 |
+
" <td>2787426</td>\n",
|
| 1582 |
+
" <td>C</td>\n",
|
| 1583 |
+
" <td>G</td>\n",
|
| 1584 |
+
" <td>Pathogenic</td>\n",
|
| 1585 |
+
" <td>ClinVar</td>\n",
|
| 1586 |
+
" <td>missense_variant</td>\n",
|
| 1587 |
+
" <td>9739</td>\n",
|
| 1588 |
+
" <td>criteria_provided,_single_submitter</td>\n",
|
| 1589 |
+
" <td>SRY</td>\n",
|
| 1590 |
+
" <td>train</td>\n",
|
| 1591 |
+
" <td>1</td>\n",
|
| 1592 |
+
" <td>6736</td>\n",
|
| 1593 |
+
" <td>not provided</td>\n",
|
| 1594 |
+
" </tr>\n",
|
| 1595 |
+
" <tr>\n",
|
| 1596 |
+
" <th>32682</th>\n",
|
| 1597 |
+
" <td>chrY</td>\n",
|
| 1598 |
+
" <td>2787515</td>\n",
|
| 1599 |
+
" <td>C</td>\n",
|
| 1600 |
+
" <td>A</td>\n",
|
| 1601 |
+
" <td>Pathogenic</td>\n",
|
| 1602 |
+
" <td>ClinVar</td>\n",
|
| 1603 |
+
" <td>missense_variant</td>\n",
|
| 1604 |
+
" <td>492908</td>\n",
|
| 1605 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1606 |
+
" <td>SRY</td>\n",
|
| 1607 |
+
" <td>train</td>\n",
|
| 1608 |
+
" <td>1</td>\n",
|
| 1609 |
+
" <td>6736</td>\n",
|
| 1610 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 1611 |
+
" </tr>\n",
|
| 1612 |
+
" <tr>\n",
|
| 1613 |
+
" <th>32683</th>\n",
|
| 1614 |
+
" <td>chrY</td>\n",
|
| 1615 |
+
" <td>2787551</td>\n",
|
| 1616 |
+
" <td>C</td>\n",
|
| 1617 |
+
" <td>T</td>\n",
|
| 1618 |
+
" <td>Pathogenic</td>\n",
|
| 1619 |
+
" <td>ClinVar</td>\n",
|
| 1620 |
+
" <td>missense_variant</td>\n",
|
| 1621 |
+
" <td>9754</td>\n",
|
| 1622 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1623 |
+
" <td>SRY</td>\n",
|
| 1624 |
+
" <td>train</td>\n",
|
| 1625 |
+
" <td>1</td>\n",
|
| 1626 |
+
" <td>6736</td>\n",
|
| 1627 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 1628 |
+
" </tr>\n",
|
| 1629 |
+
" <tr>\n",
|
| 1630 |
+
" <th>32684</th>\n",
|
| 1631 |
+
" <td>chrY</td>\n",
|
| 1632 |
+
" <td>7063898</td>\n",
|
| 1633 |
+
" <td>A</td>\n",
|
| 1634 |
+
" <td>T</td>\n",
|
| 1635 |
+
" <td>Pathogenic</td>\n",
|
| 1636 |
+
" <td>ClinVar</td>\n",
|
| 1637 |
+
" <td>missense_variant</td>\n",
|
| 1638 |
+
" <td>625467</td>\n",
|
| 1639 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 1640 |
+
" <td>LOC126057105, TBL1Y</td>\n",
|
| 1641 |
+
" <td>train</td>\n",
|
| 1642 |
+
" <td>1</td>\n",
|
| 1643 |
+
" <td>126057105, 90665</td>\n",
|
| 1644 |
+
" <td>Deafness, Y-linked 2</td>\n",
|
| 1645 |
+
" </tr>\n",
|
| 1646 |
+
" </tbody>\n",
|
| 1647 |
+
"</table>\n",
|
| 1648 |
+
"<p>32685 rows × 14 columns</p>\n",
|
| 1649 |
+
"</div>"
|
| 1650 |
+
],
|
| 1651 |
+
"text/plain": [
|
| 1652 |
+
" CHROM POS REF ALT LABEL SOURCE CONSEQUENCE ID \\\n",
|
| 1653 |
+
"0 chr1 976215 A G Pathogenic ClinVar missense_variant 1320032 \n",
|
| 1654 |
+
"1 chr1 976215 A G Pathogenic ClinVar missense_variant 1320032 \n",
|
| 1655 |
+
"2 chr1 1050449 G A Pathogenic ClinVar missense_variant 1284257 \n",
|
| 1656 |
+
"3 chr1 1050575 G C Pathogenic ClinVar missense_variant 18241 \n",
|
| 1657 |
+
"4 chr1 1213738 G A Pathogenic ClinVar missense_variant 96692 \n",
|
| 1658 |
+
"... ... ... .. .. ... ... ... ... \n",
|
| 1659 |
+
"32680 chrY 2787412 C T Pathogenic ClinVar missense_variant 9747 \n",
|
| 1660 |
+
"32681 chrY 2787426 C G Pathogenic ClinVar missense_variant 9739 \n",
|
| 1661 |
+
"32682 chrY 2787515 C A Pathogenic ClinVar missense_variant 492908 \n",
|
| 1662 |
+
"32683 chrY 2787551 C T Pathogenic ClinVar missense_variant 9754 \n",
|
| 1663 |
+
"32684 chrY 7063898 A T Pathogenic ClinVar missense_variant 625467 \n",
|
| 1664 |
+
"\n",
|
| 1665 |
+
" REVIEW_STATUS GENE split \\\n",
|
| 1666 |
+
"0 no_assertion_criteria_provided PERM1 train \n",
|
| 1667 |
+
"1 no_assertion_criteria_provided PERM1 train \n",
|
| 1668 |
+
"2 no_assertion_criteria_provided AGRN train \n",
|
| 1669 |
+
"3 no_assertion_criteria_provided AGRN train \n",
|
| 1670 |
+
"4 no_assertion_criteria_provided TNFRSF4 train \n",
|
| 1671 |
+
"... ... ... ... \n",
|
| 1672 |
+
"32680 no_assertion_criteria_provided SRY train \n",
|
| 1673 |
+
"32681 criteria_provided,_single_submitter SRY train \n",
|
| 1674 |
+
"32682 no_assertion_criteria_provided SRY train \n",
|
| 1675 |
+
"32683 no_assertion_criteria_provided SRY train \n",
|
| 1676 |
+
"32684 no_assertion_criteria_provided LOC126057105, TBL1Y train \n",
|
| 1677 |
+
"\n",
|
| 1678 |
+
" INT_LABEL GENE_ID \\\n",
|
| 1679 |
+
"0 1 84808 \n",
|
| 1680 |
+
"1 1 84808 \n",
|
| 1681 |
+
"2 1 375790 \n",
|
| 1682 |
+
"3 1 375790 \n",
|
| 1683 |
+
"4 1 7293 \n",
|
| 1684 |
+
"... ... ... \n",
|
| 1685 |
+
"32680 1 6736 \n",
|
| 1686 |
+
"32681 1 6736 \n",
|
| 1687 |
+
"32682 1 6736 \n",
|
| 1688 |
+
"32683 1 6736 \n",
|
| 1689 |
+
"32684 1 126057105, 90665 \n",
|
| 1690 |
+
"\n",
|
| 1691 |
+
" Disease \n",
|
| 1692 |
+
"0 Renal tubular epithelial cell apoptosis \n",
|
| 1693 |
+
"1 Neutrophil inclusion bodies \n",
|
| 1694 |
+
"2 Congenital myasthenic syndrome 8 \n",
|
| 1695 |
+
"3 Congenital myasthenic syndrome 8 \n",
|
| 1696 |
+
"4 Combined immunodeficiency due to OX40 deficiency \n",
|
| 1697 |
+
"... ... \n",
|
| 1698 |
+
"32680 46,XY sex reversal 1 \n",
|
| 1699 |
+
"32681 not provided \n",
|
| 1700 |
+
"32682 46,XY sex reversal 1 \n",
|
| 1701 |
+
"32683 46,XY sex reversal 1 \n",
|
| 1702 |
+
"32684 Deafness, Y-linked 2 \n",
|
| 1703 |
+
"\n",
|
| 1704 |
+
"[32685 rows x 14 columns]"
|
| 1705 |
+
]
|
| 1706 |
+
},
|
| 1707 |
+
"execution_count": 51,
|
| 1708 |
+
"metadata": {},
|
| 1709 |
+
"output_type": "execute_result"
|
| 1710 |
+
}
|
| 1711 |
+
],
|
| 1712 |
+
"source": [
|
| 1713 |
+
"clinvar_data\n",
|
| 1714 |
+
"\n",
|
| 1715 |
+
"# Display the dataset with disease information\n",
|
| 1716 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 1717 |
+
" print(f\"📊 Dataset with diseases: {clinvar_data.shape}\")\n",
|
| 1718 |
+
" \n",
|
| 1719 |
+
" # Show disease statistics\n",
|
| 1720 |
+
" disease_counts = clinvar_data['Disease'].value_counts()\n",
|
| 1721 |
+
" print(f\"\\n🔬 Disease distribution (top 10):\")\n",
|
| 1722 |
+
" print(disease_counts.head(10))\n",
|
| 1723 |
+
" \n",
|
| 1724 |
+
" # Show sample data\n",
|
| 1725 |
+
" print(\"\\n🔍 Sample data:\")\n",
|
| 1726 |
+
" display(clinvar_data.head())\n",
|
| 1727 |
+
"else:\n",
|
| 1728 |
+
" print(\"❌ No disease data to display\")"
|
| 1729 |
+
]
|
| 1730 |
+
},
|
| 1731 |
+
{
|
| 1732 |
+
"cell_type": "code",
|
| 1733 |
+
"execution_count": null,
|
| 1734 |
+
"id": "c6b1c6dc-33ed-4f57-a385-29816f4c9984",
|
| 1735 |
+
"metadata": {},
|
| 1736 |
+
"outputs": [
|
| 1737 |
+
{
|
| 1738 |
+
"data": {
|
| 1739 |
+
"text/plain": [
|
| 1740 |
+
"np.int64(2749)"
|
| 1741 |
+
]
|
| 1742 |
+
},
|
| 1743 |
+
"execution_count": 53,
|
| 1744 |
+
"metadata": {},
|
| 1745 |
+
"output_type": "execute_result"
|
| 1746 |
+
}
|
| 1747 |
+
],
|
| 1748 |
+
"source": [
|
| 1749 |
+
"# Count entries with 'not provided' disease information\n",
|
| 1750 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 1751 |
+
" not_provided_count = (clinvar_data[\"Disease\"] == \"not provided\").sum()\n",
|
| 1752 |
+
" total_count = len(clinvar_data)\n",
|
| 1753 |
+
" \n",
|
| 1754 |
+
" print(f\"📊 Entries with 'not provided' disease: {not_provided_count}\")\n",
|
| 1755 |
+
" print(f\"📊 Total entries: {total_count}\")\n",
|
| 1756 |
+
" print(f\"📊 Percentage: {not_provided_count/total_count*100:.1f}%\")\n",
|
| 1757 |
+
"else:\n",
|
| 1758 |
+
" print(\"❌ Cannot calculate statistics - data not available\")"
|
| 1759 |
+
]
|
| 1760 |
+
},
|
| 1761 |
+
{
|
| 1762 |
+
"cell_type": "markdown",
|
| 1763 |
+
"id": "8a7513ee-96b2-4c7d-8678-0195eb826aa5",
|
| 1764 |
+
"metadata": {},
|
| 1765 |
+
"source": [
|
| 1766 |
+
"## Gene ID to Gene Name Mapping\n",
|
| 1767 |
+
"\n",
|
| 1768 |
+
"This section converts gene IDs to human-readable gene names using NCBI Entrez utilities.\n",
|
| 1769 |
+
"\n",
|
| 1770 |
+
"**Prerequisites**: NCBI Entrez Direct tools must be installed:\n",
|
| 1771 |
+
"- macOS: `brew install brewsci/bio/edirect`\n",
|
| 1772 |
+
"- Linux: Follow NCBI EDirect installation guide\n",
|
| 1773 |
+
"\n",
|
| 1774 |
+
"The process:\n",
|
| 1775 |
+
"1. Extract unique gene IDs from the dataset\n",
|
| 1776 |
+
"2. Use `esummary` to fetch gene descriptions from NCBI\n",
|
| 1777 |
+
"3. Create a mapping dictionary\n",
|
| 1778 |
+
"4. Apply the mapping to add gene names to the dataset"
|
| 1779 |
+
]
|
| 1780 |
+
},
|
| 1781 |
+
{
|
| 1782 |
+
"cell_type": "code",
|
| 1783 |
+
"execution_count": null,
|
| 1784 |
+
"id": "ee0d3632-d11e-4429-bb50-5eb9ba55d424",
|
| 1785 |
+
"metadata": {},
|
| 1786 |
+
"outputs": [],
|
| 1787 |
+
"source": [
|
| 1788 |
+
"#!/usr/bin/env python3\n",
|
| 1789 |
+
"\n",
|
| 1790 |
+
"import os\n",
|
| 1791 |
+
"import pandas as pd\n",
|
| 1792 |
+
"\n",
|
| 1793 |
+
"# Extract unique gene IDs and create mapping file\n",
|
| 1794 |
+
"# This prepares the gene ID list for NCBI lookup\n",
|
| 1795 |
+
"\n",
|
| 1796 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 1797 |
+
" # Extract all unique gene IDs\n",
|
| 1798 |
+
" all_gene_ids = set()\n",
|
| 1799 |
+
" \n",
|
| 1800 |
+
" for gene_id_str in clinvar_data['GENE_ID'].dropna():\n",
|
| 1801 |
+
" if gene_id_str.strip(): # Skip empty strings\n",
|
| 1802 |
+
" # Split comma-separated IDs\n",
|
| 1803 |
+
" ids = [gid.strip() for gid in gene_id_str.split(',') if gid.strip()]\n",
|
| 1804 |
+
" all_gene_ids.update(ids)\n",
|
| 1805 |
+
" \n",
|
| 1806 |
+
" # Save unique gene IDs to file\n",
|
| 1807 |
+
" with open(\"gene_id.txt\", 'w') as f:\n",
|
| 1808 |
+
" for gene_id in sorted(all_gene_ids):\n",
|
| 1809 |
+
" f.write(f\"{gene_id}\\n\")\n",
|
| 1810 |
+
" \n",
|
| 1811 |
+
" print(f\"✅ Extracted {len(all_gene_ids)} unique gene IDs to gene_id.txt\")\n",
|
| 1812 |
+
" \n",
|
| 1813 |
+
" # Create the shell script for NCBI lookup\n",
|
| 1814 |
+
" script_content = '''#!/bin/bash\n",
|
| 1815 |
+
"\n",
|
| 1816 |
+
"input_file=\"gene_id.txt\"\n",
|
| 1817 |
+
"output_file=\"gene_id_to_name.json\"\n",
|
| 1818 |
+
"\n",
|
| 1819 |
+
"# Check if input file exists\n",
|
| 1820 |
+
"if [ ! -f \"$input_file\" ]; then\n",
|
| 1821 |
+
" echo \"❌ Error: $input_file not found\"\n",
|
| 1822 |
+
" exit 1\n",
|
| 1823 |
+
"fi\n",
|
| 1824 |
+
"\n",
|
| 1825 |
+
"# Check if EDirect tools are available\n",
|
| 1826 |
+
"if ! command -v esummary &> /dev/null; then\n",
|
| 1827 |
+
" echo \"❌ Error: NCBI EDirect tools not found\"\n",
|
| 1828 |
+
" echo \"Please install: brew install brewsci/bio/edirect (macOS)\"\n",
|
| 1829 |
+
" exit 1\n",
|
| 1830 |
+
"fi\n",
|
| 1831 |
+
"\n",
|
| 1832 |
+
"echo \"🚀 Starting gene ID to name mapping...\"\n",
|
| 1833 |
+
"\n",
|
| 1834 |
+
"# Start JSON object\n",
|
| 1835 |
+
"echo \"{\" > \"$output_file\"\n",
|
| 1836 |
+
"\n",
|
| 1837 |
+
"first_entry=true\n",
|
| 1838 |
+
"total_lines=$(wc -l < \"$input_file\")\n",
|
| 1839 |
+
"current_line=0\n",
|
| 1840 |
+
"\n",
|
| 1841 |
+
"while IFS= read -r gene_id; do\n",
|
| 1842 |
+
" # Skip empty lines\n",
|
| 1843 |
+
" [[ -z \"$gene_id\" ]] && continue\n",
|
| 1844 |
+
" \n",
|
| 1845 |
+
" current_line=$((current_line + 1))\n",
|
| 1846 |
+
" \n",
|
| 1847 |
+
" # Progress indicator\n",
|
| 1848 |
+
" if (( current_line % 50 == 0 )); then\n",
|
| 1849 |
+
" echo \"📊 Processing $current_line/$total_lines gene IDs...\"\n",
|
| 1850 |
+
" fi\n",
|
| 1851 |
+
" \n",
|
| 1852 |
+
" # Fetch gene description using Entrez Direct\n",
|
| 1853 |
+
" description=$(esummary -db gene -id \"$gene_id\" 2>/dev/null | xtract -pattern DocumentSummary -element Description)\n",
|
| 1854 |
+
" \n",
|
| 1855 |
+
" # Handle empty description\n",
|
| 1856 |
+
" if [ -z \"$description\" ]; then\n",
|
| 1857 |
+
" description=\"Unknown\"\n",
|
| 1858 |
+
" fi\n",
|
| 1859 |
+
" \n",
|
| 1860 |
+
" # JSON escape quotes and other special characters\n",
|
| 1861 |
+
" description=$(printf '%s' \"$description\" | sed 's/\"/\\\\\"/g')\n",
|
| 1862 |
+
" \n",
|
| 1863 |
+
" # Add comma if not the first entry\n",
|
| 1864 |
+
" if [ \"$first_entry\" = true ]; then\n",
|
| 1865 |
+
" first_entry=false\n",
|
| 1866 |
+
" else\n",
|
| 1867 |
+
" echo \",\" >> \"$output_file\"\n",
|
| 1868 |
+
" fi\n",
|
| 1869 |
+
" \n",
|
| 1870 |
+
" # Append key-value pair\n",
|
| 1871 |
+
" echo \" \\\"$gene_id\\\": \\\"$description\\\"\" >> \"$output_file\"\n",
|
| 1872 |
+
" \n",
|
| 1873 |
+
"done < \"$input_file\"\n",
|
| 1874 |
+
"\n",
|
| 1875 |
+
"# Close JSON object\n",
|
| 1876 |
+
"echo \"\" >> \"$output_file\"\n",
|
| 1877 |
+
"echo \"}\" >> \"$output_file\"\n",
|
| 1878 |
+
"\n",
|
| 1879 |
+
"echo \"✅ Gene ID to name mapping completed\"\n",
|
| 1880 |
+
"echo \"💾 Results saved to $output_file\"\n",
|
| 1881 |
+
"'''\n",
|
| 1882 |
+
" \n",
|
| 1883 |
+
" # Write the script\n",
|
| 1884 |
+
" with open(\"gene_mapping.sh\", 'w') as f:\n",
|
| 1885 |
+
" f.write(script_content)\n",
|
| 1886 |
+
" \n",
|
| 1887 |
+
" # Make executable\n",
|
| 1888 |
+
" os.chmod(\"gene_mapping.sh\", 0o755)\n",
|
| 1889 |
+
" \n",
|
| 1890 |
+
" print(\"✅ Created gene_mapping.sh script\")\n",
|
| 1891 |
+
" print(\"\\n🚀 To run the gene mapping:\")\n",
|
| 1892 |
+
" print(\" ./gene_mapping.sh\")\n",
|
| 1893 |
+
" print(\"\\n⚠️ Note: This requires NCBI EDirect tools to be installed\")\n",
|
| 1894 |
+
" \n",
|
| 1895 |
+
"else:\n",
|
| 1896 |
+
" print(\"⚠️ Cannot create gene mapping - data not available\")"
|
| 1897 |
+
]
|
| 1898 |
+
},
|
| 1899 |
+
{
|
| 1900 |
+
"cell_type": "markdown",
|
| 1901 |
+
"id": "1957ef57-1af8-46a1-8d1b-147f6b423619",
|
| 1902 |
+
"metadata": {},
|
| 1903 |
+
"source": [
|
| 1904 |
+
"## Apply Gene Name Mapping\n",
|
| 1905 |
+
"\n",
|
| 1906 |
+
"Load the gene ID to name mapping and apply it to the dataset to add human-readable gene names.\n",
|
| 1907 |
+
"\n",
|
| 1908 |
+
"Read json and add it to the clinvar_data df"
|
| 1909 |
+
]
|
| 1910 |
+
},
|
| 1911 |
+
{
|
| 1912 |
+
"cell_type": "code",
|
| 1913 |
+
"execution_count": null,
|
| 1914 |
+
"id": "b39be718-c0ae-4aae-b1d8-d0c872947ec2",
|
| 1915 |
+
"metadata": {},
|
| 1916 |
+
"outputs": [],
|
| 1917 |
+
"source": [
|
| 1918 |
+
"import json\n",
|
| 1919 |
+
"\n",
|
| 1920 |
+
"# Load gene ID to name mapping and apply to dataset\n",
|
| 1921 |
+
"\n",
|
| 1922 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 1923 |
+
" try:\n",
|
| 1924 |
+
" # Load gene ID → name mapping\n",
|
| 1925 |
+
" with open(\"gene_id_to_name.json\", \"r\") as f:\n",
|
| 1926 |
+
" gene_id_dict = json.load(f)\n",
|
| 1927 |
+
" \n",
|
| 1928 |
+
" print(f\"✅ Loaded mapping for {len(gene_id_dict)} gene IDs\")\n",
|
| 1929 |
+
" \n",
|
| 1930 |
+
" # Function to convert gene IDs to gene names\n",
|
| 1931 |
+
" def get_gene_names(gene_id_str):\n",
|
| 1932 |
+
" if pd.isna(gene_id_str) or not gene_id_str.strip():\n",
|
| 1933 |
+
" return \"\"\n",
|
| 1934 |
+
" \n",
|
| 1935 |
+
" gene_ids = [gid.strip() for gid in gene_id_str.split(\",\") if gid.strip()]\n",
|
| 1936 |
+
" gene_names = []\n",
|
| 1937 |
+
" \n",
|
| 1938 |
+
" for gid in gene_ids:\n",
|
| 1939 |
+
" gene_name = gene_id_dict.get(gid, f\"Unknown_ID_{gid}\")\n",
|
| 1940 |
+
" gene_names.append(gene_name)\n",
|
| 1941 |
+
" \n",
|
| 1942 |
+
" return \" | \".join(gene_names)\n",
|
| 1943 |
+
" \n",
|
| 1944 |
+
" # Apply mapping to create gene names column\n",
|
| 1945 |
+
" print(\"📊 Applying gene name mapping...\")\n",
|
| 1946 |
+
" clinvar_data[\"GENE_Name\"] = clinvar_data[\"GENE_ID\"].apply(get_gene_names)\n",
|
| 1947 |
+
" \n",
|
| 1948 |
+
" # Statistics\n",
|
| 1949 |
+
" mapped_count = (clinvar_data[\"GENE_Name\"] != \"\").sum()\n",
|
| 1950 |
+
" print(f\"✅ Gene names mapped for {mapped_count} entries ({mapped_count/len(clinvar_data)*100:.1f}%)\")\n",
|
| 1951 |
+
" \n",
|
| 1952 |
+
" # Show sample mappings\n",
|
| 1953 |
+
" sample_data = clinvar_data[clinvar_data[\"GENE_Name\"] != \"\"][[\"GENE_ID\", \"GENE_Name\"]].head()\n",
|
| 1954 |
+
" if not sample_data.empty:\n",
|
| 1955 |
+
" print(\"\\n🔍 Sample gene ID to name mappings:\")\n",
|
| 1956 |
+
" for _, row in sample_data.iterrows():\n",
|
| 1957 |
+
" print(f\" {row['GENE_ID']} → {row['GENE_Name'][:100]}{'...' if len(row['GENE_Name']) > 100 else ''}\")\n",
|
| 1958 |
+
" \n",
|
| 1959 |
+
" except FileNotFoundError:\n",
|
| 1960 |
+
" print(\"❌ Error: gene_id_to_name.json not found\")\n",
|
| 1961 |
+
" print(\"Please run the gene mapping script first: ./gene_mapping.sh\")\n",
|
| 1962 |
+
" # Create empty column as fallback\n",
|
| 1963 |
+
" clinvar_data[\"GENE_Name\"] = \"\"\n",
|
| 1964 |
+
" except json.JSONDecodeError as e:\n",
|
| 1965 |
+
" print(f\"❌ Error parsing JSON mapping file: {e}\")\n",
|
| 1966 |
+
" clinvar_data[\"GENE_Name\"] = \"\"\n",
|
| 1967 |
+
" except Exception as e:\n",
|
| 1968 |
+
" print(f\"❌ Error applying gene mapping: {e}\")\n",
|
| 1969 |
+
" clinvar_data[\"GENE_Name\"] = \"\"\n",
|
| 1970 |
+
"else:\n",
|
| 1971 |
+
" print(\"⚠️ Cannot apply gene mapping - data not available\")"
|
| 1972 |
+
]
|
| 1973 |
+
},
|
| 1974 |
+
{
|
| 1975 |
+
"cell_type": "code",
|
| 1976 |
+
"execution_count": null,
|
| 1977 |
+
"id": "4b7a44c2-7823-47c1-b268-22a1815ffd09",
|
| 1978 |
+
"metadata": {},
|
| 1979 |
+
"outputs": [
|
| 1980 |
+
{
|
| 1981 |
+
"data": {
|
| 1982 |
+
"text/html": [
|
| 1983 |
+
"<div>\n",
|
| 1984 |
+
"<style scoped>\n",
|
| 1985 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
| 1986 |
+
" vertical-align: middle;\n",
|
| 1987 |
+
" }\n",
|
| 1988 |
+
"\n",
|
| 1989 |
+
" .dataframe tbody tr th {\n",
|
| 1990 |
+
" vertical-align: top;\n",
|
| 1991 |
+
" }\n",
|
| 1992 |
+
"\n",
|
| 1993 |
+
" .dataframe thead th {\n",
|
| 1994 |
+
" text-align: right;\n",
|
| 1995 |
+
" }\n",
|
| 1996 |
+
"</style>\n",
|
| 1997 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
| 1998 |
+
" <thead>\n",
|
| 1999 |
+
" <tr style=\"text-align: right;\">\n",
|
| 2000 |
+
" <th></th>\n",
|
| 2001 |
+
" <th>CHROM</th>\n",
|
| 2002 |
+
" <th>POS</th>\n",
|
| 2003 |
+
" <th>REF</th>\n",
|
| 2004 |
+
" <th>ALT</th>\n",
|
| 2005 |
+
" <th>LABEL</th>\n",
|
| 2006 |
+
" <th>SOURCE</th>\n",
|
| 2007 |
+
" <th>CONSEQUENCE</th>\n",
|
| 2008 |
+
" <th>ID</th>\n",
|
| 2009 |
+
" <th>REVIEW_STATUS</th>\n",
|
| 2010 |
+
" <th>GENE</th>\n",
|
| 2011 |
+
" <th>split</th>\n",
|
| 2012 |
+
" <th>INT_LABEL</th>\n",
|
| 2013 |
+
" <th>GENE_ID</th>\n",
|
| 2014 |
+
" <th>Disease</th>\n",
|
| 2015 |
+
" <th>GENE_Name</th>\n",
|
| 2016 |
+
" </tr>\n",
|
| 2017 |
+
" </thead>\n",
|
| 2018 |
+
" <tbody>\n",
|
| 2019 |
+
" <tr>\n",
|
| 2020 |
+
" <th>0</th>\n",
|
| 2021 |
+
" <td>chr1</td>\n",
|
| 2022 |
+
" <td>976215</td>\n",
|
| 2023 |
+
" <td>A</td>\n",
|
| 2024 |
+
" <td>G</td>\n",
|
| 2025 |
+
" <td>Pathogenic</td>\n",
|
| 2026 |
+
" <td>ClinVar</td>\n",
|
| 2027 |
+
" <td>missense_variant</td>\n",
|
| 2028 |
+
" <td>1320032</td>\n",
|
| 2029 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2030 |
+
" <td>PERM1</td>\n",
|
| 2031 |
+
" <td>train</td>\n",
|
| 2032 |
+
" <td>1</td>\n",
|
| 2033 |
+
" <td>84808</td>\n",
|
| 2034 |
+
" <td>Renal tubular epithelial cell apoptosis</td>\n",
|
| 2035 |
+
" <td>PPARGC1 and ESRR induced regulator, muscle 1</td>\n",
|
| 2036 |
+
" </tr>\n",
|
| 2037 |
+
" <tr>\n",
|
| 2038 |
+
" <th>1</th>\n",
|
| 2039 |
+
" <td>chr1</td>\n",
|
| 2040 |
+
" <td>976215</td>\n",
|
| 2041 |
+
" <td>A</td>\n",
|
| 2042 |
+
" <td>G</td>\n",
|
| 2043 |
+
" <td>Pathogenic</td>\n",
|
| 2044 |
+
" <td>ClinVar</td>\n",
|
| 2045 |
+
" <td>missense_variant</td>\n",
|
| 2046 |
+
" <td>1320032</td>\n",
|
| 2047 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2048 |
+
" <td>PERM1</td>\n",
|
| 2049 |
+
" <td>train</td>\n",
|
| 2050 |
+
" <td>1</td>\n",
|
| 2051 |
+
" <td>84808</td>\n",
|
| 2052 |
+
" <td>Neutrophil inclusion bodies</td>\n",
|
| 2053 |
+
" <td>PPARGC1 and ESRR induced regulator, muscle 1</td>\n",
|
| 2054 |
+
" </tr>\n",
|
| 2055 |
+
" <tr>\n",
|
| 2056 |
+
" <th>2</th>\n",
|
| 2057 |
+
" <td>chr1</td>\n",
|
| 2058 |
+
" <td>1050449</td>\n",
|
| 2059 |
+
" <td>G</td>\n",
|
| 2060 |
+
" <td>A</td>\n",
|
| 2061 |
+
" <td>Pathogenic</td>\n",
|
| 2062 |
+
" <td>ClinVar</td>\n",
|
| 2063 |
+
" <td>missense_variant</td>\n",
|
| 2064 |
+
" <td>1284257</td>\n",
|
| 2065 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2066 |
+
" <td>AGRN</td>\n",
|
| 2067 |
+
" <td>train</td>\n",
|
| 2068 |
+
" <td>1</td>\n",
|
| 2069 |
+
" <td>375790</td>\n",
|
| 2070 |
+
" <td>Congenital myasthenic syndrome 8</td>\n",
|
| 2071 |
+
" <td>agrin</td>\n",
|
| 2072 |
+
" </tr>\n",
|
| 2073 |
+
" <tr>\n",
|
| 2074 |
+
" <th>3</th>\n",
|
| 2075 |
+
" <td>chr1</td>\n",
|
| 2076 |
+
" <td>1050575</td>\n",
|
| 2077 |
+
" <td>G</td>\n",
|
| 2078 |
+
" <td>C</td>\n",
|
| 2079 |
+
" <td>Pathogenic</td>\n",
|
| 2080 |
+
" <td>ClinVar</td>\n",
|
| 2081 |
+
" <td>missense_variant</td>\n",
|
| 2082 |
+
" <td>18241</td>\n",
|
| 2083 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2084 |
+
" <td>AGRN</td>\n",
|
| 2085 |
+
" <td>train</td>\n",
|
| 2086 |
+
" <td>1</td>\n",
|
| 2087 |
+
" <td>375790</td>\n",
|
| 2088 |
+
" <td>Congenital myasthenic syndrome 8</td>\n",
|
| 2089 |
+
" <td>agrin</td>\n",
|
| 2090 |
+
" </tr>\n",
|
| 2091 |
+
" <tr>\n",
|
| 2092 |
+
" <th>4</th>\n",
|
| 2093 |
+
" <td>chr1</td>\n",
|
| 2094 |
+
" <td>1213738</td>\n",
|
| 2095 |
+
" <td>G</td>\n",
|
| 2096 |
+
" <td>A</td>\n",
|
| 2097 |
+
" <td>Pathogenic</td>\n",
|
| 2098 |
+
" <td>ClinVar</td>\n",
|
| 2099 |
+
" <td>missense_variant</td>\n",
|
| 2100 |
+
" <td>96692</td>\n",
|
| 2101 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2102 |
+
" <td>TNFRSF4</td>\n",
|
| 2103 |
+
" <td>train</td>\n",
|
| 2104 |
+
" <td>1</td>\n",
|
| 2105 |
+
" <td>7293</td>\n",
|
| 2106 |
+
" <td>Combined immunodeficiency due to OX40 deficiency</td>\n",
|
| 2107 |
+
" <td>TNF receptor superfamily member 4</td>\n",
|
| 2108 |
+
" </tr>\n",
|
| 2109 |
+
" <tr>\n",
|
| 2110 |
+
" <th>...</th>\n",
|
| 2111 |
+
" <td>...</td>\n",
|
| 2112 |
+
" <td>...</td>\n",
|
| 2113 |
+
" <td>...</td>\n",
|
| 2114 |
+
" <td>...</td>\n",
|
| 2115 |
+
" <td>...</td>\n",
|
| 2116 |
+
" <td>...</td>\n",
|
| 2117 |
+
" <td>...</td>\n",
|
| 2118 |
+
" <td>...</td>\n",
|
| 2119 |
+
" <td>...</td>\n",
|
| 2120 |
+
" <td>...</td>\n",
|
| 2121 |
+
" <td>...</td>\n",
|
| 2122 |
+
" <td>...</td>\n",
|
| 2123 |
+
" <td>...</td>\n",
|
| 2124 |
+
" <td>...</td>\n",
|
| 2125 |
+
" <td>...</td>\n",
|
| 2126 |
+
" </tr>\n",
|
| 2127 |
+
" <tr>\n",
|
| 2128 |
+
" <th>32680</th>\n",
|
| 2129 |
+
" <td>chrY</td>\n",
|
| 2130 |
+
" <td>2787412</td>\n",
|
| 2131 |
+
" <td>C</td>\n",
|
| 2132 |
+
" <td>T</td>\n",
|
| 2133 |
+
" <td>Pathogenic</td>\n",
|
| 2134 |
+
" <td>ClinVar</td>\n",
|
| 2135 |
+
" <td>missense_variant</td>\n",
|
| 2136 |
+
" <td>9747</td>\n",
|
| 2137 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2138 |
+
" <td>SRY</td>\n",
|
| 2139 |
+
" <td>train</td>\n",
|
| 2140 |
+
" <td>1</td>\n",
|
| 2141 |
+
" <td>6736</td>\n",
|
| 2142 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 2143 |
+
" <td>sex determining region Y</td>\n",
|
| 2144 |
+
" </tr>\n",
|
| 2145 |
+
" <tr>\n",
|
| 2146 |
+
" <th>32681</th>\n",
|
| 2147 |
+
" <td>chrY</td>\n",
|
| 2148 |
+
" <td>2787426</td>\n",
|
| 2149 |
+
" <td>C</td>\n",
|
| 2150 |
+
" <td>G</td>\n",
|
| 2151 |
+
" <td>Pathogenic</td>\n",
|
| 2152 |
+
" <td>ClinVar</td>\n",
|
| 2153 |
+
" <td>missense_variant</td>\n",
|
| 2154 |
+
" <td>9739</td>\n",
|
| 2155 |
+
" <td>criteria_provided,_single_submitter</td>\n",
|
| 2156 |
+
" <td>SRY</td>\n",
|
| 2157 |
+
" <td>train</td>\n",
|
| 2158 |
+
" <td>1</td>\n",
|
| 2159 |
+
" <td>6736</td>\n",
|
| 2160 |
+
" <td>not provided</td>\n",
|
| 2161 |
+
" <td>sex determining region Y</td>\n",
|
| 2162 |
+
" </tr>\n",
|
| 2163 |
+
" <tr>\n",
|
| 2164 |
+
" <th>32682</th>\n",
|
| 2165 |
+
" <td>chrY</td>\n",
|
| 2166 |
+
" <td>2787515</td>\n",
|
| 2167 |
+
" <td>C</td>\n",
|
| 2168 |
+
" <td>A</td>\n",
|
| 2169 |
+
" <td>Pathogenic</td>\n",
|
| 2170 |
+
" <td>ClinVar</td>\n",
|
| 2171 |
+
" <td>missense_variant</td>\n",
|
| 2172 |
+
" <td>492908</td>\n",
|
| 2173 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2174 |
+
" <td>SRY</td>\n",
|
| 2175 |
+
" <td>train</td>\n",
|
| 2176 |
+
" <td>1</td>\n",
|
| 2177 |
+
" <td>6736</td>\n",
|
| 2178 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 2179 |
+
" <td>sex determining region Y</td>\n",
|
| 2180 |
+
" </tr>\n",
|
| 2181 |
+
" <tr>\n",
|
| 2182 |
+
" <th>32683</th>\n",
|
| 2183 |
+
" <td>chrY</td>\n",
|
| 2184 |
+
" <td>2787551</td>\n",
|
| 2185 |
+
" <td>C</td>\n",
|
| 2186 |
+
" <td>T</td>\n",
|
| 2187 |
+
" <td>Pathogenic</td>\n",
|
| 2188 |
+
" <td>ClinVar</td>\n",
|
| 2189 |
+
" <td>missense_variant</td>\n",
|
| 2190 |
+
" <td>9754</td>\n",
|
| 2191 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2192 |
+
" <td>SRY</td>\n",
|
| 2193 |
+
" <td>train</td>\n",
|
| 2194 |
+
" <td>1</td>\n",
|
| 2195 |
+
" <td>6736</td>\n",
|
| 2196 |
+
" <td>46,XY sex reversal 1</td>\n",
|
| 2197 |
+
" <td>sex determining region Y</td>\n",
|
| 2198 |
+
" </tr>\n",
|
| 2199 |
+
" <tr>\n",
|
| 2200 |
+
" <th>32684</th>\n",
|
| 2201 |
+
" <td>chrY</td>\n",
|
| 2202 |
+
" <td>7063898</td>\n",
|
| 2203 |
+
" <td>A</td>\n",
|
| 2204 |
+
" <td>T</td>\n",
|
| 2205 |
+
" <td>Pathogenic</td>\n",
|
| 2206 |
+
" <td>ClinVar</td>\n",
|
| 2207 |
+
" <td>missense_variant</td>\n",
|
| 2208 |
+
" <td>625467</td>\n",
|
| 2209 |
+
" <td>no_assertion_criteria_provided</td>\n",
|
| 2210 |
+
" <td>LOC126057105, TBL1Y</td>\n",
|
| 2211 |
+
" <td>train</td>\n",
|
| 2212 |
+
" <td>1</td>\n",
|
| 2213 |
+
" <td>126057105, 90665</td>\n",
|
| 2214 |
+
" <td>Deafness, Y-linked 2</td>\n",
|
| 2215 |
+
" <td>P300/CBP strongly-dependent group 1 enhancer G...</td>\n",
|
| 2216 |
+
" </tr>\n",
|
| 2217 |
+
" </tbody>\n",
|
| 2218 |
+
"</table>\n",
|
| 2219 |
+
"<p>32685 rows × 15 columns</p>\n",
|
| 2220 |
+
"</div>"
|
| 2221 |
+
],
|
| 2222 |
+
"text/plain": [
|
| 2223 |
+
" CHROM POS REF ALT LABEL SOURCE CONSEQUENCE ID \\\n",
|
| 2224 |
+
"0 chr1 976215 A G Pathogenic ClinVar missense_variant 1320032 \n",
|
| 2225 |
+
"1 chr1 976215 A G Pathogenic ClinVar missense_variant 1320032 \n",
|
| 2226 |
+
"2 chr1 1050449 G A Pathogenic ClinVar missense_variant 1284257 \n",
|
| 2227 |
+
"3 chr1 1050575 G C Pathogenic ClinVar missense_variant 18241 \n",
|
| 2228 |
+
"4 chr1 1213738 G A Pathogenic ClinVar missense_variant 96692 \n",
|
| 2229 |
+
"... ... ... .. .. ... ... ... ... \n",
|
| 2230 |
+
"32680 chrY 2787412 C T Pathogenic ClinVar missense_variant 9747 \n",
|
| 2231 |
+
"32681 chrY 2787426 C G Pathogenic ClinVar missense_variant 9739 \n",
|
| 2232 |
+
"32682 chrY 2787515 C A Pathogenic ClinVar missense_variant 492908 \n",
|
| 2233 |
+
"32683 chrY 2787551 C T Pathogenic ClinVar missense_variant 9754 \n",
|
| 2234 |
+
"32684 chrY 7063898 A T Pathogenic ClinVar missense_variant 625467 \n",
|
| 2235 |
+
"\n",
|
| 2236 |
+
" REVIEW_STATUS GENE split \\\n",
|
| 2237 |
+
"0 no_assertion_criteria_provided PERM1 train \n",
|
| 2238 |
+
"1 no_assertion_criteria_provided PERM1 train \n",
|
| 2239 |
+
"2 no_assertion_criteria_provided AGRN train \n",
|
| 2240 |
+
"3 no_assertion_criteria_provided AGRN train \n",
|
| 2241 |
+
"4 no_assertion_criteria_provided TNFRSF4 train \n",
|
| 2242 |
+
"... ... ... ... \n",
|
| 2243 |
+
"32680 no_assertion_criteria_provided SRY train \n",
|
| 2244 |
+
"32681 criteria_provided,_single_submitter SRY train \n",
|
| 2245 |
+
"32682 no_assertion_criteria_provided SRY train \n",
|
| 2246 |
+
"32683 no_assertion_criteria_provided SRY train \n",
|
| 2247 |
+
"32684 no_assertion_criteria_provided LOC126057105, TBL1Y train \n",
|
| 2248 |
+
"\n",
|
| 2249 |
+
" INT_LABEL GENE_ID \\\n",
|
| 2250 |
+
"0 1 84808 \n",
|
| 2251 |
+
"1 1 84808 \n",
|
| 2252 |
+
"2 1 375790 \n",
|
| 2253 |
+
"3 1 375790 \n",
|
| 2254 |
+
"4 1 7293 \n",
|
| 2255 |
+
"... ... ... \n",
|
| 2256 |
+
"32680 1 6736 \n",
|
| 2257 |
+
"32681 1 6736 \n",
|
| 2258 |
+
"32682 1 6736 \n",
|
| 2259 |
+
"32683 1 6736 \n",
|
| 2260 |
+
"32684 1 126057105, 90665 \n",
|
| 2261 |
+
"\n",
|
| 2262 |
+
" Disease \\\n",
|
| 2263 |
+
"0 Renal tubular epithelial cell apoptosis \n",
|
| 2264 |
+
"1 Neutrophil inclusion bodies \n",
|
| 2265 |
+
"2 Congenital myasthenic syndrome 8 \n",
|
| 2266 |
+
"3 Congenital myasthenic syndrome 8 \n",
|
| 2267 |
+
"4 Combined immunodeficiency due to OX40 deficiency \n",
|
| 2268 |
+
"... ... \n",
|
| 2269 |
+
"32680 46,XY sex reversal 1 \n",
|
| 2270 |
+
"32681 not provided \n",
|
| 2271 |
+
"32682 46,XY sex reversal 1 \n",
|
| 2272 |
+
"32683 46,XY sex reversal 1 \n",
|
| 2273 |
+
"32684 Deafness, Y-linked 2 \n",
|
| 2274 |
+
"\n",
|
| 2275 |
+
" GENE_Name \n",
|
| 2276 |
+
"0 PPARGC1 and ESRR induced regulator, muscle 1 \n",
|
| 2277 |
+
"1 PPARGC1 and ESRR induced regulator, muscle 1 \n",
|
| 2278 |
+
"2 agrin \n",
|
| 2279 |
+
"3 agrin \n",
|
| 2280 |
+
"4 TNF receptor superfamily member 4 \n",
|
| 2281 |
+
"... ... \n",
|
| 2282 |
+
"32680 sex determining region Y \n",
|
| 2283 |
+
"32681 sex determining region Y \n",
|
| 2284 |
+
"32682 sex determining region Y \n",
|
| 2285 |
+
"32683 sex determining region Y \n",
|
| 2286 |
+
"32684 P300/CBP strongly-dependent group 1 enhancer G... \n",
|
| 2287 |
+
"\n",
|
| 2288 |
+
"[32685 rows x 15 columns]"
|
| 2289 |
+
]
|
| 2290 |
+
},
|
| 2291 |
+
"execution_count": 59,
|
| 2292 |
+
"metadata": {},
|
| 2293 |
+
"output_type": "execute_result"
|
| 2294 |
+
}
|
| 2295 |
+
],
|
| 2296 |
+
"source": [
|
| 2297 |
+
"# Display final dataset with all extracted information\n",
|
| 2298 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 2299 |
+
" print(f\"📊 Final dataset shape: {clinvar_data.shape}\")\n",
|
| 2300 |
+
" print(f\"📋 Columns: {list(clinvar_data.columns)}\")\n",
|
| 2301 |
+
" \n",
|
| 2302 |
+
" # Data completeness statistics\n",
|
| 2303 |
+
" print(\"\\n📈 Data Completeness:\")\n",
|
| 2304 |
+
" for col in ['GENE', 'GENE_ID', 'GENE_Name', 'Disease']:\n",
|
| 2305 |
+
" if col in clinvar_data.columns:\n",
|
| 2306 |
+
" filled_count = (clinvar_data[col] != '').sum()\n",
|
| 2307 |
+
" print(f\" {col}: {filled_count}/{len(clinvar_data)} ({filled_count/len(clinvar_data)*100:.1f}%)\")\n",
|
| 2308 |
+
" \n",
|
| 2309 |
+
" # Sample data\n",
|
| 2310 |
+
" print(\"\\n🔍 Sample data:\")\n",
|
| 2311 |
+
" display(clinvar_data.head())\n",
|
| 2312 |
+
" \n",
|
| 2313 |
+
" # Memory usage\n",
|
| 2314 |
+
" memory_mb = clinvar_data.memory_usage(deep=True).sum() / 1024 / 1024\n",
|
| 2315 |
+
" print(f\"\\n💾 Dataset memory usage: {memory_mb:.1f} MB\")\n",
|
| 2316 |
+
"else:\n",
|
| 2317 |
+
" print(\"❌ No final data to display\")"
|
| 2318 |
+
]
|
| 2319 |
+
},
|
| 2320 |
+
{
|
| 2321 |
+
"cell_type": "code",
|
| 2322 |
+
"execution_count": null,
|
| 2323 |
+
"id": "c545ae83-5cd1-4e29-87fd-69389bdb153f",
|
| 2324 |
+
"metadata": {},
|
| 2325 |
+
"outputs": [
|
| 2326 |
+
{
|
| 2327 |
+
"data": {
|
| 2328 |
+
"text/plain": [
|
| 2329 |
+
"'P300/CBP strongly-dependent group 1 enhancer GRCh37_chrY:6931456-6932655| transducin beta like 1 Y-linked'"
|
| 2330 |
+
]
|
| 2331 |
+
},
|
| 2332 |
+
"execution_count": 60,
|
| 2333 |
+
"metadata": {},
|
| 2334 |
+
"output_type": "execute_result"
|
| 2335 |
+
}
|
| 2336 |
+
],
|
| 2337 |
+
"source": [
|
| 2338 |
+
"# Show example of gene name mapping\n",
|
| 2339 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None and len(clinvar_data) > 32684:\n",
|
| 2340 |
+
" example_gene_name = clinvar_data.iloc[32684]['GENE_Name']\n",
|
| 2341 |
+
" example_gene_id = clinvar_data.iloc[32684]['GENE_ID']\n",
|
| 2342 |
+
" \n",
|
| 2343 |
+
" print(f\"🔍 Example gene mapping for row 32684:\")\n",
|
| 2344 |
+
" print(f\" Gene ID: {example_gene_id}\")\n",
|
| 2345 |
+
" print(f\" Gene Name: {example_gene_name}\")\n",
|
| 2346 |
+
"else:\n",
|
| 2347 |
+
" # Show any available example\n",
|
| 2348 |
+
" if 'clinvar_data' in locals() and clinvar_data is not None and not clinvar_data.empty:\n",
|
| 2349 |
+
" # Find first row with gene name data\n",
|
| 2350 |
+
" example_row = clinvar_data[clinvar_data['GENE_Name'] != ''].iloc[0] if (clinvar_data['GENE_Name'] != '').any() else clinvar_data.iloc[0]\n",
|
| 2351 |
+
" \n",
|
| 2352 |
+
" print(f\"🔍 Example gene mapping:\")\n",
|
| 2353 |
+
" print(f\" Gene ID: {example_row.get('GENE_ID', 'N/A')}\")\n",
|
| 2354 |
+
" print(f\" Gene Name: {example_row.get('GENE_Name', 'N/A')}\")\n",
|
| 2355 |
+
" else:\n",
|
| 2356 |
+
" print(\"❌ No data available for example\")"
|
| 2357 |
+
]
|
| 2358 |
+
},
|
| 2359 |
+
{
|
| 2360 |
+
"cell_type": "code",
|
| 2361 |
+
"execution_count": null,
|
| 2362 |
+
"id": "a214c29d-a4f1-4af6-a914-e6b4a14a1c49",
|
| 2363 |
+
"metadata": {},
|
| 2364 |
+
"outputs": [],
|
| 2365 |
+
"source": [
|
| 2366 |
+
"import os\n",
|
| 2367 |
+
"\n",
|
| 2368 |
+
"# Save the final processed dataset\n",
|
| 2369 |
+
"if 'clinvar_data' in locals() and clinvar_data is not None:\n",
|
| 2370 |
+
" output_file = \"clinvar_with_disease.csv\"\n",
|
| 2371 |
+
" \n",
|
| 2372 |
+
" try:\n",
|
| 2373 |
+
" clinvar_data.to_csv(output_file, index=False)\n",
|
| 2374 |
+
" \n",
|
| 2375 |
+
" print(f\"✅ Final dataset saved to {output_file}\")\n",
|
| 2376 |
+
" print(f\"📊 Saved {len(clinvar_data)} records with {len(clinvar_data.columns)} columns\")\n",
|
| 2377 |
+
" \n",
|
| 2378 |
+
" # File size\n",
|
| 2379 |
+
" file_size = os.path.getsize(output_file) / 1024 / 1024\n",
|
| 2380 |
+
" print(f\"💾 File size: {file_size:.1f} MB\")\n",
|
| 2381 |
+
" \n",
|
| 2382 |
+
" # Summary of what was accomplished\n",
|
| 2383 |
+
" print(\"\\n🎯 Processing Summary:\")\n",
|
| 2384 |
+
" print(f\" ✓ Extracted ClinVar coding variants\")\n",
|
| 2385 |
+
" print(f\" ✓ Parsed XML records for gene information\")\n",
|
| 2386 |
+
" print(f\" ✓ Mapped diseases/phenotypes\")\n",
|
| 2387 |
+
" print(f\" ✓ Added human-readable gene names\")\n",
|
| 2388 |
+
" print(f\" ✓ Created comprehensive dataset\")\n",
|
| 2389 |
+
" \n",
|
| 2390 |
+
" except Exception as e:\n",
|
| 2391 |
+
" print(f\"❌ Error saving dataset: {e}\")\n",
|
| 2392 |
+
"else:\n",
|
| 2393 |
+
" print(\"⚠️ No data available to save\")"
|
| 2394 |
+
]
|
| 2395 |
+
},
|
| 2396 |
+
{
|
| 2397 |
+
"cell_type": "code",
|
| 2398 |
+
"execution_count": null,
|
| 2399 |
+
"id": "b6c4c1f4-4b87-4624-8f8a-c568e40b2e63",
|
| 2400 |
+
"metadata": {},
|
| 2401 |
+
"outputs": [],
|
| 2402 |
+
"source": [
|
| 2403 |
+
"import os\n",
|
| 2404 |
+
"import shutil\n",
|
| 2405 |
+
"\n",
|
| 2406 |
+
"# Optional: Clean up temporary XML data directory\n",
|
| 2407 |
+
"# Uncomment the following lines if you want to remove the XML files to save space\n",
|
| 2408 |
+
"\n",
|
| 2409 |
+
"if os.path.exists(\"data\") and os.path.isdir(\"data\"):\n",
|
| 2410 |
+
" # Count files before cleanup\n",
|
| 2411 |
+
" xml_files = [f for f in os.listdir(\"data\") if f.endswith('.xml')]\n",
|
| 2412 |
+
" \n",
|
| 2413 |
+
" print(f\"🗂️ Found {len(xml_files)} XML files in data directory\")\n",
|
| 2414 |
+
" \n",
|
| 2415 |
+
" # Uncomment to actually remove the directory\n",
|
| 2416 |
+
" # shutil.rmtree(\"data\")\n",
|
| 2417 |
+
" # print(\"🗑️ Removed temporary XML data directory\")\n",
|
| 2418 |
+
" \n",
|
| 2419 |
+
" print(\"ℹ️ XML files preserved. Uncomment the cleanup code to remove them.\")\n",
|
| 2420 |
+
"else:\n",
|
| 2421 |
+
" print(\"ℹ️ No XML data directory found to clean up\")"
|
| 2422 |
+
]
|
| 2423 |
+
},
|
| 2424 |
+
{
|
| 2425 |
+
"cell_type": "code",
|
| 2426 |
+
"execution_count": null,
|
| 2427 |
+
"id": "c08beea6-6ff7-4900-a8b8-8a719db36189",
|
| 2428 |
+
"metadata": {},
|
| 2429 |
+
"outputs": [],
|
| 2430 |
+
"source": [
|
| 2431 |
+
"## Processing Complete ✅\n",
|
| 2432 |
+
"\n",
|
| 2433 |
+
"The ClinVar coding variants have been successfully processed with the following enhancements:\n",
|
| 2434 |
+
"\n",
|
| 2435 |
+
"### Generated Files:\n",
|
| 2436 |
+
"- `clinvar_coding_raw.csv` - Raw ClinVar entries extracted from VEP data\n",
|
| 2437 |
+
"- `Clinvar_ID.txt` - List of ClinVar IDs for processing\n",
|
| 2438 |
+
"- `gene_id.txt` - Unique gene IDs for name mapping\n",
|
| 2439 |
+
"- `gene_id_to_name.json` - Gene ID to name mapping dictionary\n",
|
| 2440 |
+
"- `clinvar_with_disease.csv` - **Final comprehensive dataset**\n",
|
| 2441 |
+
"\n",
|
| 2442 |
+
"### Dataset Features:\n",
|
| 2443 |
+
"- **Variant Information**: Genomic coordinates, alleles, and annotations\n",
|
| 2444 |
+
"- **Gene Data**: Symbols, IDs, and human-readable names\n",
|
| 2445 |
+
"- **Disease/Phenotype**: Associated conditions and clinical significance\n",
|
| 2446 |
+
"- **Expanded Format**: One row per variant-disease combination\n",
|
| 2447 |
+
"\n",
|
| 2448 |
+
"### Next Steps:\n",
|
| 2449 |
+
"1. **Quality Control**: Review the data for completeness and accuracy\n",
|
| 2450 |
+
"2. **Analysis**: Use the dataset for downstream genetic analysis\n",
|
| 2451 |
+
"3. **Integration**: Combine with other datasets as needed\n",
|
| 2452 |
+
"4. **Documentation**: Update metadata and create data dictionary\n",
|
| 2453 |
+
"\n",
|
| 2454 |
+
"### File Cleanup:\n",
|
| 2455 |
+
"- XML files in `data/` directory can be removed to save space\n",
|
| 2456 |
+
"- Intermediate files can be archived or removed as needed"
|
| 2457 |
+
]
|
| 2458 |
+
}
|
| 2459 |
+
],
|
| 2460 |
+
"metadata": {
|
| 2461 |
+
"kernelspec": {
|
| 2462 |
+
"display_name": "Python 3 (ipykernel)",
|
| 2463 |
+
"language": "python",
|
| 2464 |
+
"name": "python3"
|
| 2465 |
+
},
|
| 2466 |
+
"language_info": {
|
| 2467 |
+
"codemirror_mode": {
|
| 2468 |
+
"name": "ipython",
|
| 2469 |
+
"version": 3
|
| 2470 |
+
},
|
| 2471 |
+
"file_extension": ".py",
|
| 2472 |
+
"mimetype": "text/x-python",
|
| 2473 |
+
"name": "python",
|
| 2474 |
+
"nbconvert_exporter": "python",
|
| 2475 |
+
"pygments_lexer": "ipython3",
|
| 2476 |
+
"version": "3.12.9"
|
| 2477 |
+
}
|
| 2478 |
+
},
|
| 2479 |
+
"nbformat": 4,
|
| 2480 |
+
"nbformat_minor": 5
|
| 2481 |
+
}
|
BioReason/pyproject.toml
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=42", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "bioreason"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "Bio-related Reasoning with Language Models"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.11"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"Programming Language :: Python :: 3.11",
|
| 14 |
+
"License :: OSI Approved :: MIT License",
|
| 15 |
+
"Operating System :: OS Independent",
|
| 16 |
+
]
|
| 17 |
+
dependencies = [
|
| 18 |
+
"torch",
|
| 19 |
+
"torchvision",
|
| 20 |
+
"transformers",
|
| 21 |
+
"accelerate",
|
| 22 |
+
"qwen-vl-utils",
|
| 23 |
+
"jupyter",
|
| 24 |
+
"datasets",
|
| 25 |
+
"peft",
|
| 26 |
+
"pytorch_lightning",
|
| 27 |
+
"wandb",
|
| 28 |
+
"trl[vllm]",
|
| 29 |
+
"bitsandbytes",
|
| 30 |
+
"deepspeed",
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
[project.optional-dependencies]
|
| 34 |
+
dev = [
|
| 35 |
+
"pytest",
|
| 36 |
+
"black",
|
| 37 |
+
"isort",
|
| 38 |
+
"mypy",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
[tool.setuptools]
|
| 42 |
+
packages = ["bioreason"]
|
| 43 |
+
|
| 44 |
+
[tool.black]
|
| 45 |
+
line-length = 88
|
| 46 |
+
target-version = ["py311"]
|
| 47 |
+
|
| 48 |
+
[tool.isort]
|
| 49 |
+
profile = "black"
|
| 50 |
+
line_length = 88
|
| 51 |
+
|
| 52 |
+
[tool.mypy]
|
| 53 |
+
python_version = "3.11"
|
| 54 |
+
warn_return_any = true
|
| 55 |
+
warn_unused_configs = true
|
| 56 |
+
disallow_untyped_defs = true
|
| 57 |
+
disallow_incomplete_defs = true
|
BioReason/reason.py
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import pathlib
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
from typing import List, Dict, Optional
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.utils.data import DataLoader, Dataset
|
| 14 |
+
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
AutoModelForCausalLM,
|
| 19 |
+
AutoModelForMaskedLM,
|
| 20 |
+
AutoProcessor,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset, DatasetDict
|
| 24 |
+
|
| 25 |
+
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
|
| 26 |
+
from transformers import BitsAndBytesConfig
|
| 27 |
+
|
| 28 |
+
import pytorch_lightning as pl
|
| 29 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 30 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 31 |
+
|
| 32 |
+
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 33 |
+
#from unsloth import FastLanguageModel, is_bfloat16_supported
|
| 34 |
+
|
| 35 |
+
from bioreason.models.dna_llm import DNALLMModel
|
| 36 |
+
from bioreason.dna_modules import NucleotideDNAModule
|
| 37 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 38 |
+
from bioreason.trainer import DNALLMGRPOTrainer, DNALLMGRPOConfig
|
| 39 |
+
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
|
| 40 |
+
register_evo2_tokenizer()
|
| 41 |
+
|
| 42 |
+
# Custom TrainerCallback to override the saving mechanism
|
| 43 |
+
from transformers import TrainerCallback, TrainerState, TrainerControl
|
| 44 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 45 |
+
|
| 46 |
+
class SaveWithPyTorchCallback(TrainerCallback):
|
| 47 |
+
"""Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""
|
| 48 |
+
def on_save(self, args, state, control, **kwargs):
|
| 49 |
+
# Get the checkpoint folder
|
| 50 |
+
checkpoint_folder = os.path.join(
|
| 51 |
+
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
| 52 |
+
)
|
| 53 |
+
os.makedirs(checkpoint_folder, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
# Save with PyTorch instead of safetensors
|
| 56 |
+
checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
|
| 57 |
+
model = kwargs.get("model")
|
| 58 |
+
|
| 59 |
+
# Get model unwrapped from accelerator etc.
|
| 60 |
+
unwrapped_model = model.module if hasattr(model, "module") else model
|
| 61 |
+
|
| 62 |
+
# Save using PyTorch directly
|
| 63 |
+
torch.save(unwrapped_model.state_dict(), checkpoint_path)
|
| 64 |
+
|
| 65 |
+
# DNALLMModel doesn't have a direct config attribute, so we need to save
|
| 66 |
+
# the configs of its sub-models
|
| 67 |
+
if hasattr(unwrapped_model, "text_model"):
|
| 68 |
+
if hasattr(unwrapped_model.text_model, "config"):
|
| 69 |
+
unwrapped_model.text_model.config.save_pretrained(checkpoint_folder)
|
| 70 |
+
# Handle PEFT models which might have base_model
|
| 71 |
+
elif hasattr(unwrapped_model.text_model, "base_model") and hasattr(unwrapped_model.text_model.base_model, "config"):
|
| 72 |
+
unwrapped_model.text_model.base_model.config.save_pretrained(checkpoint_folder)
|
| 73 |
+
|
| 74 |
+
# Print info about what's being saved
|
| 75 |
+
print(f"Saved model checkpoint to {checkpoint_folder}")
|
| 76 |
+
lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
|
| 77 |
+
print(f"Checkpoint contains {len(lora_params)} LoRA parameters")
|
| 78 |
+
|
| 79 |
+
# Signal that we've saved
|
| 80 |
+
control.should_save = False
|
| 81 |
+
return control
|
| 82 |
+
|
| 83 |
+
def _get_target_modules(model: DNALLMModel):
|
| 84 |
+
# Apply LoRA to all linear layers in the text model
|
| 85 |
+
target_modules = []
|
| 86 |
+
|
| 87 |
+
# Get all unique linear layer names
|
| 88 |
+
seen_names = set()
|
| 89 |
+
for name, module in model.text.named_modules():
|
| 90 |
+
if isinstance(module, torch.nn.Linear):
|
| 91 |
+
names = name.split(".")
|
| 92 |
+
target_name = names[-1] # Use the last part of the name
|
| 93 |
+
|
| 94 |
+
# Skip output head but include all other linear layers
|
| 95 |
+
if target_name != "lm_head" and target_name not in seen_names:
|
| 96 |
+
target_modules.append(target_name)
|
| 97 |
+
seen_names.add(target_name)
|
| 98 |
+
|
| 99 |
+
# Add attention-specific layers
|
| 100 |
+
attention_patterns = [
|
| 101 |
+
"q_proj",
|
| 102 |
+
"k_proj",
|
| 103 |
+
"v_proj",
|
| 104 |
+
"out_proj",
|
| 105 |
+
"query",
|
| 106 |
+
"key",
|
| 107 |
+
"value",
|
| 108 |
+
]
|
| 109 |
+
for pattern in attention_patterns:
|
| 110 |
+
if pattern not in seen_names:
|
| 111 |
+
target_modules.append(pattern)
|
| 112 |
+
|
| 113 |
+
# Return all unique layer names to apply LoRA to all layers
|
| 114 |
+
return list(target_modules)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def extract_xml_answer(text: str) -> str:
|
| 118 |
+
# answer = text.split("<answer>")[-1]
|
| 119 |
+
# answer = answer.split("</answer>")[0]
|
| 120 |
+
answer = text.split("</think>")[-1]
|
| 121 |
+
return answer.strip()
|
| 122 |
+
|
| 123 |
+
def extract_hash_answer(text: str) -> str | None:
|
| 124 |
+
if "####" not in text:
|
| 125 |
+
return None
|
| 126 |
+
return text.split("####")[1].strip()
|
| 127 |
+
|
| 128 |
+
def get_kegg_questions() -> Dataset:
|
| 129 |
+
data = load_dataset('wanglab/kegg', 'default') # type: ignore
|
| 130 |
+
example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
|
| 131 |
+
num_dna_sequences = 2 # TODO: Change to 2!
|
| 132 |
+
|
| 133 |
+
data = data.map(lambda x: { # type: ignore
|
| 134 |
+
'prompt': [
|
| 135 |
+
|
| 136 |
+
{
|
| 137 |
+
'role': 'user',
|
| 138 |
+
'content': [
|
| 139 |
+
*({'type': 'dna', 'text': None} for _ in range(num_dna_sequences)),
|
| 140 |
+
{'type': 'text', 'text': x['question']},
|
| 141 |
+
],
|
| 142 |
+
},
|
| 143 |
+
],
|
| 144 |
+
'dna_sequences': [x['reference_sequence'], x['variant_sequence']],
|
| 145 |
+
'answer': x['answer'],
|
| 146 |
+
}) # type: ignore
|
| 147 |
+
|
| 148 |
+
return data
|
| 149 |
+
|
| 150 |
+
# uncomment middle messages for 1-shot prompting
|
| 151 |
+
def get_gsm8k_questions(question_prompt: str) -> Dataset:
|
| 152 |
+
data = load_dataset('openai/gsm8k', 'main') # type: ignore
|
| 153 |
+
|
| 154 |
+
example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
|
| 155 |
+
data = data.map(lambda x: { # type: ignore
|
| 156 |
+
'prompt': [
|
| 157 |
+
|
| 158 |
+
{
|
| 159 |
+
'role': 'user',
|
| 160 |
+
'content': [
|
| 161 |
+
*({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))),
|
| 162 |
+
{'type': 'text', 'text': 'Give me a short introduction to large language model.'}
|
| 163 |
+
]
|
| 164 |
+
},
|
| 165 |
+
],
|
| 166 |
+
'dna_sequences': [dna for dna in example_dna_sequences],
|
| 167 |
+
'answer': extract_hash_answer(x['answer']),
|
| 168 |
+
}) # type: ignore
|
| 169 |
+
|
| 170 |
+
return data # type: ignore
|
| 171 |
+
|
| 172 |
+
def get_gsm8k_questions_old(question_prompt: str) -> Dataset:
|
| 173 |
+
data = load_dataset('openai/gsm8k', 'main') # type: ignore
|
| 174 |
+
|
| 175 |
+
example_dna_sequences = ["ATCTACATGCAT", "CAGCAGCTACAG", "CATCACATCGACATCGAC"]
|
| 176 |
+
data = data.map(lambda x: { # type: ignore
|
| 177 |
+
'prompt': [
|
| 178 |
+
{
|
| 179 |
+
'role': 'user',
|
| 180 |
+
'content': [
|
| 181 |
+
*({'type': 'dna', 'text': None} for _ in range(len(example_dna_sequences))),
|
| 182 |
+
{'type': 'text', 'text': question_prompt.format(Question=x['question'])}
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
],
|
| 186 |
+
'dna_sequences': [dna for dna in example_dna_sequences],
|
| 187 |
+
'answer': extract_hash_answer(x['answer']),
|
| 188 |
+
}) # type: ignore
|
| 189 |
+
|
| 190 |
+
return data # type: ignore
|
| 191 |
+
|
| 192 |
+
# Reward functions
|
| 193 |
+
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
| 194 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 195 |
+
q = prompts[0][-1]['content']
|
| 196 |
+
extracted_responses = [extract_xml_answer(r) for r in responses]
|
| 197 |
+
# extracted_responses = [r.lower().replace("answer:", "").strip() for r in extracted_responses]
|
| 198 |
+
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
|
| 199 |
+
return [2.0 if a.lower() in r.lower() else 0.0 for r, a in zip(extracted_responses, answer[0])]
|
| 200 |
+
|
| 201 |
+
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
|
| 202 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 203 |
+
extracted_responses = [extract_xml_answer(r) for r in responses]
|
| 204 |
+
return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]
|
| 205 |
+
|
| 206 |
+
def strict_format_reward_func(completions, **kwargs) -> list[float]:
|
| 207 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 208 |
+
pattern = r"^<think>\n.*?\n</think>\n.*?\n$"
|
| 209 |
+
responses = [completion[0]["content"] for completion in completions]
|
| 210 |
+
matches = [re.match(pattern, r) for r in responses]
|
| 211 |
+
return [0.5 if match else 0.0 for match in matches]
|
| 212 |
+
|
| 213 |
+
def soft_format_reward_func(completions, **kwargs) -> list[float]:
|
| 214 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 215 |
+
pattern = r"<think>.*?</think>\s*.*?"
|
| 216 |
+
responses = [completion[0]["content"] for completion in completions]
|
| 217 |
+
matches = [re.match(pattern, r) for r in responses]
|
| 218 |
+
return [0.5 if match else 0.0 for match in matches]
|
| 219 |
+
|
| 220 |
+
def count_xml(text) -> float:
|
| 221 |
+
count = 0.0
|
| 222 |
+
if text.count("<think>\n") == 1:
|
| 223 |
+
count += 0.125
|
| 224 |
+
if text.count("\n</think>\n") == 1:
|
| 225 |
+
count += 0.125
|
| 226 |
+
return count
|
| 227 |
+
|
| 228 |
+
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
|
| 229 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 230 |
+
return [count_xml(c) for c in contents]
|
| 231 |
+
|
| 232 |
+
# Format into conversation
|
| 233 |
+
def make_conversation(example):
|
| 234 |
+
return {
|
| 235 |
+
"prompt": [
|
| 236 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 237 |
+
{"role": "user", "content": example["problem"]},
|
| 238 |
+
],
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
def make_conversation_image(example):
|
| 242 |
+
return {
|
| 243 |
+
"prompt": [
|
| 244 |
+
{
|
| 245 |
+
"role": "user",
|
| 246 |
+
"content": [
|
| 247 |
+
{"type": "image"},
|
| 248 |
+
],
|
| 249 |
+
},
|
| 250 |
+
],
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
@dataclass
|
| 254 |
+
# class GRPOModelConfig(ModelConfig):
|
| 255 |
+
|
| 256 |
+
# # "HuggingFaceTB/SmolLM-135M-Instruct"
|
| 257 |
+
# # "Qwen/Qwen2.5-0.5B-Instruct"
|
| 258 |
+
# model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for weights initialization."})
|
| 259 |
+
# dna_model_name_or_path: str = field(default="InstaDeepAI/nucleotide-transformer-v2-100m-multi-species", metadata={"help": "Model checkpoint for weights initialization."})
|
| 260 |
+
# cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
|
| 261 |
+
# max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
|
| 262 |
+
# max_length_dna: int = field(default=800, metadata={"help": "Maximum length of DNA sequences, in groups of 6 nucleotides."})
|
| 263 |
+
# sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
|
| 264 |
+
# lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
|
| 265 |
+
# lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
|
| 266 |
+
# lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
|
| 267 |
+
# lora_modules_to_save: Optional[list[str]] = field(
|
| 268 |
+
# default="embed_tokens",
|
| 269 |
+
# metadata={"help": "Model layers to unfreeze & train."},
|
| 270 |
+
# )
|
| 271 |
+
# freeze_dna_modules: bool = False
|
| 272 |
+
|
| 273 |
+
class GRPOModelConfig(ModelConfig):
|
| 274 |
+
|
| 275 |
+
model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for LLM weights initialization."})
|
| 276 |
+
protein_model_name_or_path: str = field(default="esm2_t33_650M_UR50D", metadata={"help": "Model checkpoint for ESM-2 protein weights initialization."})
|
| 277 |
+
cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
|
| 278 |
+
max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
|
| 279 |
+
max_length_protein: int = field(default=800, metadata={"help": "Maximum length of protein sequences (number of amino acids)."})
|
| 280 |
+
sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
|
| 281 |
+
lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
|
| 282 |
+
lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
|
| 283 |
+
lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
|
| 284 |
+
lora_modules_to_save: Optional[list[str]] = field(
|
| 285 |
+
default_factory=lambda: ["embed_tokens", "lm_head"],
|
| 286 |
+
metadata={"help": "Model layers to unfreeze & train with LoRA."},
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# Updated: Renamed `freeze_dna_modules` to `freeze_protein_model`
|
| 290 |
+
freeze_protein_model: bool = field(default=True, metadata={"help": "Whether to freeze the ESM-2 protein model during training."})
|
| 291 |
+
num_query_tokens: int = field(default=32, metadata={"help": "The number of query tokens used by the Q-Former to summarize protein features. These tokens will be injected into the LLM input."})
|
| 292 |
+
# New: Parameters for the projector layer
|
| 293 |
+
projector_hidden_size: int = field(default=1280, metadata={"help": "Hidden size of the projector layer. It should match the ESM-2's output hidden size."})
|
| 294 |
+
projector_output_size: int = field(default=1024, metadata={"help": "Output size of the projector layer. It should match the LLM's hidden size."})
|
| 295 |
+
|
| 296 |
+
# New: Parameter to control projector training
|
| 297 |
+
freeze_projector: bool = field(default=False, metadata={"help": "Whether to freeze the projector layer during training."})
|
| 298 |
+
|
| 299 |
+
@dataclass
|
| 300 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 301 |
+
"""
|
| 302 |
+
Script arguments for the GRPO training script.
|
| 303 |
+
"""
|
| 304 |
+
dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
|
| 305 |
+
data_file_paths: str = field(
|
| 306 |
+
default=None,
|
| 307 |
+
metadata={"help": "Paths to data files, separated by ':'"},
|
| 308 |
+
)
|
| 309 |
+
arrow_cache_dir: str = field(
|
| 310 |
+
default=None,
|
| 311 |
+
metadata={"help": "Path to arrow cache directory"},
|
| 312 |
+
)
|
| 313 |
+
val_split_ratio: float = field(
|
| 314 |
+
default=0.0,
|
| 315 |
+
metadata={"help": "Ratio of validation split, default 0.0"},
|
| 316 |
+
)
|
| 317 |
+
reward_funcs: list[str] = field(
|
| 318 |
+
#default_factory=lambda: ["accuracy", "format"],
|
| 319 |
+
default_factory=lambda: ["xmlcount", "soft_format", "strict_format", "less_than_4", "correctness"],
|
| 320 |
+
#metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"},
|
| 321 |
+
metadata={"help": "List of reward functions. Possible values: 'accuracy', 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"},
|
| 322 |
+
)
|
| 323 |
+
# max_pixels: Optional[int] = field(
|
| 324 |
+
# default=12845056,
|
| 325 |
+
# metadata={"help": "Maximum number of pixels for the image (for QwenVL)"},
|
| 326 |
+
# )
|
| 327 |
+
# min_pixels: Optional[int] = field(
|
| 328 |
+
# default=3136,
|
| 329 |
+
# metadata={"help": "Minimum number of pixels for the image (for QwenVL)"},
|
| 330 |
+
# )
|
| 331 |
+
# task_type: Optional[str] = field(
|
| 332 |
+
# default=None,
|
| 333 |
+
# metadata={"help": "Choose task type: 'default', 'gui', ..."},
|
| 334 |
+
# )
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
reward_funcs_registry = {
|
| 339 |
+
# "accuracy": accuracy_reward,
|
| 340 |
+
# "format": format_reward,
|
| 341 |
+
"xmlcount": xmlcount_reward_func,
|
| 342 |
+
"soft_format": soft_format_reward_func,
|
| 343 |
+
"strict_format": strict_format_reward_func,
|
| 344 |
+
"less_than_4": less_than_4_reward_func,
|
| 345 |
+
"correctness": correctness_reward_func,
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
def get_vlm_module(model_name_or_path):
|
| 349 |
+
if any(mini_name in model_name_or_path.lower() for mini_name in ["qwen", "smol"]):
|
| 350 |
+
return NucleotideDNAModule
|
| 351 |
+
else:
|
| 352 |
+
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 353 |
+
|
| 354 |
+
def _get_target_modules(model):
|
| 355 |
+
# Apply LoRA to all linear layers in the text model
|
| 356 |
+
target_modules = []
|
| 357 |
+
|
| 358 |
+
# Get all unique linear layer names
|
| 359 |
+
seen_names = set()
|
| 360 |
+
for name, module in model.text_model.named_modules():
|
| 361 |
+
if isinstance(module, torch.nn.Linear):
|
| 362 |
+
names = name.split(".")
|
| 363 |
+
target_name = names[-1] # Use the last part of the name
|
| 364 |
+
|
| 365 |
+
# Skip output head but include all other linear layers
|
| 366 |
+
if target_name != "lm_head" and target_name not in seen_names:
|
| 367 |
+
target_modules.append(target_name)
|
| 368 |
+
seen_names.add(target_name)
|
| 369 |
+
|
| 370 |
+
# Add attention-specific layers
|
| 371 |
+
attention_patterns = [
|
| 372 |
+
"q_proj",
|
| 373 |
+
"k_proj",
|
| 374 |
+
"v_proj",
|
| 375 |
+
"out_proj",
|
| 376 |
+
"query",
|
| 377 |
+
"key",
|
| 378 |
+
"value",
|
| 379 |
+
]
|
| 380 |
+
for pattern in attention_patterns:
|
| 381 |
+
if pattern not in seen_names:
|
| 382 |
+
target_modules.append(pattern)
|
| 383 |
+
|
| 384 |
+
# Return all unique layer names to apply LoRA to all layers
|
| 385 |
+
return list(target_modules)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def _prep_for_training(model, training_args, dna_model_finetune: bool = False) -> LoraConfig:
|
| 389 |
+
"""
|
| 390 |
+
Load and configure the DNALLMModel.
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
# Freeze DNA encoder parameters
|
| 394 |
+
if dna_model_finetune:
|
| 395 |
+
pass
|
| 396 |
+
else:
|
| 397 |
+
for param in model.dna_model.parameters():
|
| 398 |
+
param.requires_grad = False
|
| 399 |
+
|
| 400 |
+
target_modules = _get_target_modules(model)
|
| 401 |
+
|
| 402 |
+
lora_config = LoraConfig(
|
| 403 |
+
r=training_args.lora_r,
|
| 404 |
+
lora_alpha=training_args.lora_alpha,
|
| 405 |
+
lora_dropout=training_args.lora_dropout,
|
| 406 |
+
target_modules=target_modules,
|
| 407 |
+
init_lora_weights="gaussian",
|
| 408 |
+
bias="none",
|
| 409 |
+
task_type="CAUSAL_LM",
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
# Prepare text model for training
|
| 413 |
+
model.text_model = prepare_model_for_kbit_training(model.text_model)
|
| 414 |
+
model.text_model = get_peft_model(model.text_model, lora_config)
|
| 415 |
+
|
| 416 |
+
# Make projection layer trainable
|
| 417 |
+
for param in model.dna_projection.parameters():
|
| 418 |
+
param.requires_grad = True
|
| 419 |
+
|
| 420 |
+
return lora_config
|
| 421 |
+
|
| 422 |
+
def main(script_args, training_args, model_args):
|
| 423 |
+
|
| 424 |
+
print(training_args.output_dir)
|
| 425 |
+
#pl.seed_everything(args.seed)
|
| 426 |
+
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 427 |
+
torch.cuda.empty_cache()
|
| 428 |
+
torch.set_float32_matmul_precision("medium")
|
| 429 |
+
|
| 430 |
+
# Initialize model
|
| 431 |
+
# Load tokenizer for target text
|
| 432 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
|
| 433 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 434 |
+
|
| 435 |
+
# Load model
|
| 436 |
+
model = DNALLMModel(
|
| 437 |
+
text_model_name=model_args.model_name_or_path,
|
| 438 |
+
dna_model_name=model_args.dna_model_name_or_path,
|
| 439 |
+
cache_dir=model_args.cache_dir,
|
| 440 |
+
max_length_text=model_args.max_length_text,
|
| 441 |
+
max_length_dna=model_args.max_length_dna,
|
| 442 |
+
text_model_finetune=True,
|
| 443 |
+
dna_model_finetune=not model_args.freeze_dna_modules,
|
| 444 |
+
debug=False,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# load checkpoint
|
| 448 |
+
if model_args.sft_checkpoint is not None:
|
| 449 |
+
print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
|
| 450 |
+
|
| 451 |
+
# Determine if it's a directory (PEFT format) or file (PyTorch state dict)
|
| 452 |
+
is_directory = os.path.isdir(model_args.sft_checkpoint)
|
| 453 |
+
|
| 454 |
+
if is_directory:
|
| 455 |
+
# It's a PEFT checkpoint directory - load properly with PEFT
|
| 456 |
+
from peft import PeftModel
|
| 457 |
+
|
| 458 |
+
# First initialize the text model with PEFT
|
| 459 |
+
print("Loading as PEFT checkpoint directory")
|
| 460 |
+
model.text_model = PeftModel.from_pretrained(
|
| 461 |
+
model.text_model,
|
| 462 |
+
model_args.sft_checkpoint,
|
| 463 |
+
is_trainable=True
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
# Verify loaded adapters
|
| 467 |
+
print("Loaded LoRA adapters:", model.text_model.active_adapter)
|
| 468 |
+
|
| 469 |
+
# Optional: Merge weights into base model
|
| 470 |
+
print("Merging SFT LoRA weights into base model...")
|
| 471 |
+
model.text_model = model.text_model.merge_and_unload()
|
| 472 |
+
print("Successfully merged SFT knowledge into base model")
|
| 473 |
+
|
| 474 |
+
else:
|
| 475 |
+
# It's a PyTorch state dict file
|
| 476 |
+
print("Loading as PyTorch state dict file")
|
| 477 |
+
checkpoint = torch.load(model_args.sft_checkpoint)
|
| 478 |
+
|
| 479 |
+
# replace model.text_model with text_model for all in state dict
|
| 480 |
+
def new_key(k):
|
| 481 |
+
if k.startswith("=model."): return k[6:]
|
| 482 |
+
elif k.startswith("_forward_module."): return k[len("_forward_module."):]
|
| 483 |
+
else: return k
|
| 484 |
+
|
| 485 |
+
if "state_dict" in checkpoint:
|
| 486 |
+
magic = {new_key(k): v for k, v in checkpoint["state_dict"].items()}
|
| 487 |
+
elif "module" in checkpoint:
|
| 488 |
+
magic = {new_key(k): v for k, v in checkpoint["module"].items()}
|
| 489 |
+
elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
|
| 490 |
+
# Direct state dict - the checkpoint itself is the state dict
|
| 491 |
+
print("Detected direct state dict format")
|
| 492 |
+
magic = {new_key(k): v for k, v in checkpoint.items()}
|
| 493 |
+
else:
|
| 494 |
+
raise ValueError(f"Unsupported checkpoint format: {model_args.sft_checkpoint}")
|
| 495 |
+
|
| 496 |
+
# Handle prefix mapping for different model architectures
|
| 497 |
+
lora_prefix = False
|
| 498 |
+
for key in magic.keys():
|
| 499 |
+
if "lora" in key:
|
| 500 |
+
lora_prefix = True
|
| 501 |
+
break
|
| 502 |
+
|
| 503 |
+
if lora_prefix:
|
| 504 |
+
print("Detected LoRA weights in state dict")
|
| 505 |
+
# First prepare model for LoRA training
|
| 506 |
+
_prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules)
|
| 507 |
+
|
| 508 |
+
# Print some diagnostic info about the keys
|
| 509 |
+
model_keys = set(model.state_dict().keys())
|
| 510 |
+
checkpoint_keys = set(magic.keys())
|
| 511 |
+
print(f"Model has {len(model_keys)} keys")
|
| 512 |
+
print(f"Checkpoint has {len(checkpoint_keys)} keys")
|
| 513 |
+
|
| 514 |
+
# Try to map LoRA keys more intelligently
|
| 515 |
+
new_magic = {}
|
| 516 |
+
for k, v in magic.items():
|
| 517 |
+
# Try different prefix mappings based on common patterns
|
| 518 |
+
if "base_model.model" in k and k not in model_keys:
|
| 519 |
+
new_k = k.replace("text_model.base_model.model", "text_model")
|
| 520 |
+
if new_k in model_keys:
|
| 521 |
+
new_magic[new_k] = v
|
| 522 |
+
continue
|
| 523 |
+
|
| 524 |
+
# Try removing common prefixes
|
| 525 |
+
if k.startswith("text_model.") and k not in model_keys:
|
| 526 |
+
new_k = "text_model.base_model.model." + k[len("text_model."):]
|
| 527 |
+
if new_k in model_keys:
|
| 528 |
+
new_magic[new_k] = v
|
| 529 |
+
continue
|
| 530 |
+
|
| 531 |
+
# Keep original key if no mapping found
|
| 532 |
+
new_magic[k] = v
|
| 533 |
+
|
| 534 |
+
# Include missing target modules in diagnostic info
|
| 535 |
+
magic = new_magic
|
| 536 |
+
print(f"After key mapping: {len(magic)} keys")
|
| 537 |
+
|
| 538 |
+
# Then load weights, allowing missing/extra keys
|
| 539 |
+
result = model.load_state_dict(magic, strict=False)
|
| 540 |
+
|
| 541 |
+
if len(result.unexpected_keys) > 0:
|
| 542 |
+
print(f"Sample unexpected keys: {result.unexpected_keys[:5]}")
|
| 543 |
+
if len(result.missing_keys) > 0:
|
| 544 |
+
print(f"Sample missing keys: {result.missing_keys[:5]}")
|
| 545 |
+
|
| 546 |
+
print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
|
| 547 |
+
else:
|
| 548 |
+
print("Standard weights detected - remapping keys")
|
| 549 |
+
# Map keys to model structure
|
| 550 |
+
magic = {k.replace("text_model", "text_model.base_model.model"): v for k, v in magic.items()}
|
| 551 |
+
magic = {k.replace("dna_model", "dna_model"): v for k, v in magic.items()}
|
| 552 |
+
|
| 553 |
+
# Fix the shared memory tensors issue by making a copy of weights
|
| 554 |
+
for key in list(magic.keys()):
|
| 555 |
+
if 'lm_head.weight' in key:
|
| 556 |
+
magic[key] = magic[key].clone()
|
| 557 |
+
|
| 558 |
+
# Load weights before setting up LoRA
|
| 559 |
+
result = model.load_state_dict(magic, strict=False)
|
| 560 |
+
print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
|
| 561 |
+
|
| 562 |
+
# Now prepare for LoRA training
|
| 563 |
+
_prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules)
|
| 564 |
+
else:
|
| 565 |
+
# No checkpoint, just prepare for training
|
| 566 |
+
_prep_for_training(model, model_args, dna_model_finetune=model_args.freeze_dna_modules)
|
| 567 |
+
|
| 568 |
+
# Get reward functions
|
| 569 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 570 |
+
# reward_funcs = [
|
| 571 |
+
# xmlcount_reward_func,
|
| 572 |
+
# soft_format_reward_func,
|
| 573 |
+
# strict_format_reward_func,
|
| 574 |
+
# int_reward_func,
|
| 575 |
+
# correctness_reward_func,
|
| 576 |
+
# ]
|
| 577 |
+
print("reward_funcs:", reward_funcs)
|
| 578 |
+
|
| 579 |
+
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 580 |
+
print("using vlm module:", vlm_module_cls.__name__)
|
| 581 |
+
question_prompt = vlm_module_cls.get_question_template()
|
| 582 |
+
|
| 583 |
+
|
| 584 |
+
dataset = get_kegg_questions()
|
| 585 |
+
|
| 586 |
+
#dataset = get_gsm8k_questions(question_prompt)
|
| 587 |
+
|
| 588 |
+
print(dataset)
|
| 589 |
+
|
| 590 |
+
#print('ITEM ONE OF THE DATASET', dataset['train'][0])
|
| 591 |
+
|
| 592 |
+
# Custom callback to handle saving with PyTorch's native mechanism
|
| 593 |
+
custom_save_callback = SaveWithPyTorchCallback()
|
| 594 |
+
|
| 595 |
+
# Initialize the GRPO trainer with custom callback
|
| 596 |
+
trainer = DNALLMGRPOTrainer(
|
| 597 |
+
model=model,
|
| 598 |
+
reward_funcs=reward_funcs,
|
| 599 |
+
args=training_args,
|
| 600 |
+
dna_module=vlm_module_cls(),
|
| 601 |
+
train_dataset=dataset['train'],
|
| 602 |
+
eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
|
| 603 |
+
peft_config=get_peft_config(model_args),
|
| 604 |
+
attn_implementation=model_args.attn_implementation,
|
| 605 |
+
torch_dtype=model_args.torch_dtype,
|
| 606 |
+
callbacks=[custom_save_callback], # Add our custom callback
|
| 607 |
+
)
|
| 608 |
+
|
| 609 |
+
# Set the trainer to save in PyTorch format instead of safetensors
|
| 610 |
+
training_args.save_safetensors = False
|
| 611 |
+
|
| 612 |
+
# Train and push the model to the Hub
|
| 613 |
+
# if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 614 |
+
# trainer.train(resume_from_checkpoint=True)
|
| 615 |
+
# else:
|
| 616 |
+
# trainer.train()
|
| 617 |
+
|
| 618 |
+
# Train and push the model to the Hub
|
| 619 |
+
trainer.train()
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
if __name__ == "__main__":
|
| 623 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
| 624 |
+
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
| 625 |
+
parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig))
|
| 626 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 627 |
+
|
| 628 |
+
# Ensure we use PyTorch's save mechanism instead of safetensors
|
| 629 |
+
training_args.save_safetensors = False
|
| 630 |
+
|
| 631 |
+
main(script_args, training_args, model_args)
|
| 632 |
+
|
| 633 |
+
# parser.add_argument("--wandb_project", type=str, default="dna-text-finetune")
|
| 634 |
+
# parser.add_argument("--wandb_entity", type=str, default="adibvafa")
|
| 635 |
+
|
| 636 |
+
# args = parser.parse_args()
|
BioReason/reason_protein.py
ADDED
|
@@ -0,0 +1,696 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import pathlib
|
| 5 |
+
from argparse import ArgumentParser
|
| 6 |
+
from typing import List, Dict, Optional
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch import nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from torch.optim import AdamW
|
| 13 |
+
from torch.utils.data import DataLoader, Dataset
|
| 14 |
+
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
|
| 15 |
+
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
AutoModelForCausalLM,
|
| 19 |
+
AutoModelForMaskedLM,
|
| 20 |
+
AutoProcessor,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset, DatasetDict
|
| 24 |
+
|
| 25 |
+
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
|
| 26 |
+
from transformers import BitsAndBytesConfig
|
| 27 |
+
|
| 28 |
+
import pytorch_lightning as pl
|
| 29 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 30 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 31 |
+
|
| 32 |
+
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
| 33 |
+
#from unsloth import FastLanguageModel, is_bfloat16_supported
|
| 34 |
+
|
| 35 |
+
from bioreason.models.dna_llm import DNALLMModel
|
| 36 |
+
from bioreason.models.protein_llm import ProteinLLMModel
|
| 37 |
+
from bioreason.dna_modules import NucleotideDNAModule
|
| 38 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 39 |
+
from bioreason.trainer import DNALLMGRPOTrainer, DNALLMGRPOConfig
|
| 40 |
+
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
|
| 41 |
+
register_evo2_tokenizer()
|
| 42 |
+
|
| 43 |
+
# Custom TrainerCallback to override the saving mechanism
|
| 44 |
+
from transformers import TrainerCallback, TrainerState, TrainerControl
|
| 45 |
+
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
| 46 |
+
|
| 47 |
+
class SaveWithPyTorchCallback(TrainerCallback):
|
| 48 |
+
"""Custom callback to save models with PyTorch's native save mechanism instead of safetensors"""
|
| 49 |
+
def on_save(self, args, state, control, **kwargs):
|
| 50 |
+
# Get the checkpoint folder
|
| 51 |
+
checkpoint_folder = os.path.join(
|
| 52 |
+
args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
|
| 53 |
+
)
|
| 54 |
+
os.makedirs(checkpoint_folder, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
# Save with PyTorch instead of safetensors
|
| 57 |
+
checkpoint_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
|
| 58 |
+
model = kwargs.get("model")
|
| 59 |
+
|
| 60 |
+
# Get model unwrapped from accelerator etc.
|
| 61 |
+
unwrapped_model = model.module if hasattr(model, "module") else model
|
| 62 |
+
|
| 63 |
+
# Save using PyTorch directly
|
| 64 |
+
torch.save(unwrapped_model.state_dict(), checkpoint_path)
|
| 65 |
+
|
| 66 |
+
# DNALLMModel doesn't have a direct config attribute, so we need to save
|
| 67 |
+
# the configs of its sub-models
|
| 68 |
+
if hasattr(unwrapped_model, "text_model"):
|
| 69 |
+
if hasattr(unwrapped_model.text_model, "config"):
|
| 70 |
+
unwrapped_model.text_model.config.save_pretrained(checkpoint_folder)
|
| 71 |
+
# Handle PEFT models which might have base_model
|
| 72 |
+
elif hasattr(unwrapped_model.text_model, "base_model") and hasattr(unwrapped_model.text_model.base_model, "config"):
|
| 73 |
+
unwrapped_model.text_model.base_model.config.save_pretrained(checkpoint_folder)
|
| 74 |
+
|
| 75 |
+
# Print info about what's being saved
|
| 76 |
+
print(f"Saved model checkpoint to {checkpoint_folder}")
|
| 77 |
+
lora_params = [k for k in unwrapped_model.state_dict().keys() if "lora" in k]
|
| 78 |
+
print(f"Checkpoint contains {len(lora_params)} LoRA parameters")
|
| 79 |
+
|
| 80 |
+
# Signal that we've saved
|
| 81 |
+
control.should_save = False
|
| 82 |
+
return control
|
| 83 |
+
|
| 84 |
+
def _get_target_modules(model: ProteinLLMModel):
|
| 85 |
+
# Apply LoRA to all linear layers in the text model
|
| 86 |
+
target_modules = []
|
| 87 |
+
|
| 88 |
+
# Get all unique linear layer names
|
| 89 |
+
seen_names = set()
|
| 90 |
+
for name, module in model.text.named_modules():
|
| 91 |
+
if isinstance(module, torch.nn.Linear):
|
| 92 |
+
names = name.split(".")
|
| 93 |
+
target_name = names[-1] # Use the last part of the name
|
| 94 |
+
|
| 95 |
+
# Skip output head but include all other linear layers
|
| 96 |
+
if target_name != "lm_head" and target_name not in seen_names:
|
| 97 |
+
target_modules.append(target_name)
|
| 98 |
+
seen_names.add(target_name)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Add attention-specific layers commonly found in transformers
|
| 102 |
+
attention_patterns = [
|
| 103 |
+
"q_proj",
|
| 104 |
+
"k_proj",
|
| 105 |
+
"v_proj",
|
| 106 |
+
"o_proj",
|
| 107 |
+
"out_proj",
|
| 108 |
+
"query",
|
| 109 |
+
"key",
|
| 110 |
+
"value",
|
| 111 |
+
"gate_proj",
|
| 112 |
+
"up_proj",
|
| 113 |
+
"down_proj",
|
| 114 |
+
]
|
| 115 |
+
for pattern in attention_patterns:
|
| 116 |
+
if pattern not in seen_names:
|
| 117 |
+
target_modules.append(pattern)
|
| 118 |
+
|
| 119 |
+
# Return all unique layer names to apply LoRA to all layers
|
| 120 |
+
return list(target_modules)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def extract_xml_answer(text: str) -> str:
|
| 124 |
+
# answer = text.split("<answer>")[-1]
|
| 125 |
+
# answer = answer.split("</answer>")[0]
|
| 126 |
+
answer = text.split("</think>")[-1]
|
| 127 |
+
return answer.strip()
|
| 128 |
+
|
| 129 |
+
def extract_hash_answer(text: str) -> str | None:
|
| 130 |
+
if "####" not in text:
|
| 131 |
+
return None
|
| 132 |
+
return text.split("####")[1].strip()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def get_kegg_questions() -> Dataset:
|
| 136 |
+
data = load_dataset('wanglab/kegg', 'default') # type: ignore
|
| 137 |
+
# 修改为蛋白质序列示例
|
| 138 |
+
example_protein_sequences = ["MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
|
| 139 |
+
"MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRISSKLLERGKTHYPPHTMVGTGVLVTKMRVAGQEPDVQGPHAGIVVQGAGDAPVVVKPVVEMLNRMVVVVSGSAAPVVVNNNNNGAAAAAAA",
|
| 140 |
+
"MSQVQVQVQNQALNTLVKQLGRVLLQGKGRPPLQGFRIIEQNGGDSPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPPP"]
|
| 141 |
+
num_protein_sequences = 2
|
| 142 |
+
|
| 143 |
+
data = data.map(lambda x: { # type: ignore
|
| 144 |
+
'prompt': [
|
| 145 |
+
{
|
| 146 |
+
'role': 'user',
|
| 147 |
+
'content': [
|
| 148 |
+
*({'type': 'protein', 'text': None} for _ in range(num_protein_sequences)),
|
| 149 |
+
{'type': 'text', 'text': x['question']},
|
| 150 |
+
],
|
| 151 |
+
},
|
| 152 |
+
],
|
| 153 |
+
'protein_sequences': [example_protein_sequences[0], example_protein_sequences[1]], # 使用蛋白质序列
|
| 154 |
+
'answer': x['answer'],
|
| 155 |
+
}) # type: ignore
|
| 156 |
+
|
| 157 |
+
return data
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# uncomment middle messages for 1-shot prompting
|
| 161 |
+
|
| 162 |
+
# Reward functions
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
| 167 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 168 |
+
q = prompts[0][-1]['content']
|
| 169 |
+
extracted_responses = [extract_xml_answer(r) for r in responses]
|
| 170 |
+
# extracted_responses = [r.lower().replace("answer:", "").strip() for r in extracted_responses]
|
| 171 |
+
print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
|
| 172 |
+
return [2.0 if a.lower() in r.lower() else 0.0 for r, a in zip(extracted_responses, answer[0])]
|
| 173 |
+
|
| 174 |
+
def less_than_4_reward_func(completions, **kwargs) -> list[float]:
|
| 175 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 176 |
+
extracted_responses = [extract_xml_answer(r) for r in responses]
|
| 177 |
+
return [0.5 if len(r.split(' ')) <= 4 else 0.0 for r in extracted_responses]
|
| 178 |
+
|
| 179 |
+
def strict_format_reward_func(completions, **kwargs) -> list[float]:
|
| 180 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 181 |
+
pattern = r"^<think>\n.*?\n</think>\n.*?\n$"
|
| 182 |
+
responses = [completion[0]["content"] for completion in completions]
|
| 183 |
+
matches = [re.match(pattern, r) for r in responses]
|
| 184 |
+
return [0.5 if match else 0.0 for match in matches]
|
| 185 |
+
|
| 186 |
+
def soft_format_reward_func(completions, **kwargs) -> list[float]:
|
| 187 |
+
"""Reward function that checks if the completion has a specific format."""
|
| 188 |
+
pattern = r"<think>.*?</think>\s*.*?"
|
| 189 |
+
responses = [completion[0]["content"] for completion in completions]
|
| 190 |
+
matches = [re.match(pattern, r) for r in responses]
|
| 191 |
+
return [0.5 if match else 0.0 for match in matches]
|
| 192 |
+
|
| 193 |
+
def count_xml(text) -> float:
|
| 194 |
+
count = 0.0
|
| 195 |
+
if text.count("<think>\n") == 1:
|
| 196 |
+
count += 0.125
|
| 197 |
+
if text.count("\n</think>\n") == 1:
|
| 198 |
+
count += 0.125
|
| 199 |
+
return count
|
| 200 |
+
|
| 201 |
+
def xmlcount_reward_func(completions, **kwargs) -> list[float]:
|
| 202 |
+
contents = [completion[0]["content"] for completion in completions]
|
| 203 |
+
return [count_xml(c) for c in contents]
|
| 204 |
+
|
| 205 |
+
#补充奖励函数
|
| 206 |
+
def repeatness_reward(s: str):
|
| 207 |
+
"""计算文本重复度,返回值越高表示重复度越低"""
|
| 208 |
+
def ranks(l):
|
| 209 |
+
index = {v: i for i, v in enumerate(sorted(set(l)))}
|
| 210 |
+
return [index[v] for v in l]
|
| 211 |
+
|
| 212 |
+
def suffixArray(s):
|
| 213 |
+
line = ranks(s)
|
| 214 |
+
n, k, ans, sa = len(s), 1, line, [0] * len(s)
|
| 215 |
+
while k < n - 1:
|
| 216 |
+
line = ranks(list(zip_longest(line, islice(line, k, None), fillvalue=-1)))
|
| 217 |
+
ans, k = line, k << 1
|
| 218 |
+
for i, k in enumerate(ans):
|
| 219 |
+
sa[k] = i
|
| 220 |
+
return ans, sa
|
| 221 |
+
|
| 222 |
+
def lcp(arr, suffixArr, inv_suff):
|
| 223 |
+
n, ans, k = len(arr), [0] * len(arr), 0
|
| 224 |
+
for i in range(n):
|
| 225 |
+
if inv_suff[i] == n - 1:
|
| 226 |
+
k = 0
|
| 227 |
+
continue
|
| 228 |
+
j = suffixArr[inv_suff[i] + 1]
|
| 229 |
+
while i + k < n and j + k < n and arr[i + k] == arr[j + k]:
|
| 230 |
+
k += 1
|
| 231 |
+
ans[inv_suff[i]] = k
|
| 232 |
+
if k > 0:
|
| 233 |
+
k -= 1
|
| 234 |
+
return ans
|
| 235 |
+
|
| 236 |
+
arr = [ord(i) for i in s]
|
| 237 |
+
n = len(arr)
|
| 238 |
+
if n <= 1:
|
| 239 |
+
return 0
|
| 240 |
+
c, sa = suffixArray(arr)
|
| 241 |
+
cnt = sum(lcp(arr, sa, c))
|
| 242 |
+
return 1 - cnt * 2 / (n * (n + 1))
|
| 243 |
+
|
| 244 |
+
def format_reward(predict_str: str) -> float:
|
| 245 |
+
"""
|
| 246 |
+
格式奖励函数,严格要求输出格式为:
|
| 247 |
+
<think>...</think><answer>...</answer>
|
| 248 |
+
中间不能有多余内容
|
| 249 |
+
"""
|
| 250 |
+
pattern = r'^<think>.*?</think>\s*<answer>\s*.*?\s*</answer>$'
|
| 251 |
+
return 1.0 if re.fullmatch(pattern, predict_str.strip(), re.DOTALL) else 0.0
|
| 252 |
+
|
| 253 |
+
def acc_reward(predict_str: str, ground_truth) -> float:
|
| 254 |
+
"""
|
| 255 |
+
准确率奖励函数
|
| 256 |
+
要求<answer>中内容与ground_truth完全一致
|
| 257 |
+
"""
|
| 258 |
+
match = re.search(r'<answer>\s*([^<]*?)\s*</answer>', predict_str)
|
| 259 |
+
if not match:
|
| 260 |
+
return 0.0
|
| 261 |
+
answer_content = match.group(1).strip()
|
| 262 |
+
|
| 263 |
+
# 处理不同类型的ground_truth
|
| 264 |
+
if isinstance(ground_truth, str):
|
| 265 |
+
return 1.0 if answer_content == ground_truth else 0.0
|
| 266 |
+
elif isinstance(ground_truth, (int, float)):
|
| 267 |
+
try:
|
| 268 |
+
# 尝试将答案转换为数字进行比较
|
| 269 |
+
return 1.0 if float(answer_content) == float(ground_truth) else 0.0
|
| 270 |
+
except ValueError:
|
| 271 |
+
# 如果转换失败,尝试字符串比较
|
| 272 |
+
return 1.0 if answer_content == str(ground_truth) else 0.0
|
| 273 |
+
else:
|
| 274 |
+
# 其他类型,转换为字符串比较
|
| 275 |
+
return 1.0 if answer_content == str(ground_truth) else 0.0
|
| 276 |
+
|
| 277 |
+
# 包装函数以适配现有的奖励函数接口
|
| 278 |
+
def repeatness_reward_func(completions, **kwargs) -> list[float]:
|
| 279 |
+
"""重复度奖励函数包装器"""
|
| 280 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 281 |
+
return [repeatness_reward(r) for r in responses]
|
| 282 |
+
|
| 283 |
+
def format_reward_func(completions, **kwargs) -> list[float]:
|
| 284 |
+
"""格式奖励函数包装器"""
|
| 285 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 286 |
+
return [format_reward(r) for r in responses]
|
| 287 |
+
|
| 288 |
+
def acc_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
|
| 289 |
+
"""准确率奖励函数包装器"""
|
| 290 |
+
responses = [completion[0]['content'] for completion in completions]
|
| 291 |
+
|
| 292 |
+
# 调试信息
|
| 293 |
+
print(f"DEBUG acc_reward_func - answer type: {type(answer)}, answer: {answer}")
|
| 294 |
+
|
| 295 |
+
# 根据现有代码的模式,answer可能是一个嵌套结构
|
| 296 |
+
try:
|
| 297 |
+
if isinstance(answer, list) and len(answer) > 0:
|
| 298 |
+
# 如果answer[0]是一个列表,说明是批次数据
|
| 299 |
+
if isinstance(answer[0], list):
|
| 300 |
+
ground_truths = answer[0]
|
| 301 |
+
else:
|
| 302 |
+
# 如果answer[0]是单个值,为所有响应使用相同的真实答案
|
| 303 |
+
ground_truths = [answer[0]] * len(responses)
|
| 304 |
+
else:
|
| 305 |
+
# 如果answer不是期望的格式,返回全0
|
| 306 |
+
print(f"DEBUG: Unexpected answer format, returning zeros")
|
| 307 |
+
return [0.0] * len(responses)
|
| 308 |
+
except (IndexError, TypeError) as e:
|
| 309 |
+
print(f"DEBUG: Error processing answer: {e}, returning zeros")
|
| 310 |
+
return [0.0] * len(responses)
|
| 311 |
+
|
| 312 |
+
print(f"DEBUG: ground_truths: {ground_truths}")
|
| 313 |
+
|
| 314 |
+
# 确保responses和ground_truths长度一致
|
| 315 |
+
rewards = []
|
| 316 |
+
for i, response in enumerate(responses):
|
| 317 |
+
if i < len(ground_truths):
|
| 318 |
+
reward = acc_reward(response, ground_truths[i])
|
| 319 |
+
print(f"DEBUG: response {i}: '{response[:100]}...', ground_truth: '{ground_truths[i]}', reward: {reward}")
|
| 320 |
+
else:
|
| 321 |
+
# 如果ground_truths不够长,使用第一个值
|
| 322 |
+
reward = acc_reward(response, ground_truths[0] if ground_truths else "")
|
| 323 |
+
print(f"DEBUG: response {i} (fallback): reward: {reward}")
|
| 324 |
+
rewards.append(reward)
|
| 325 |
+
|
| 326 |
+
return rewards
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
#
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Format into conversation
|
| 334 |
+
def make_conversation(example):
|
| 335 |
+
return {
|
| 336 |
+
"prompt": [
|
| 337 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 338 |
+
{"role": "user", "content": example["problem"]},
|
| 339 |
+
],
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
def make_conversation_image(example):
|
| 343 |
+
return {
|
| 344 |
+
"prompt": [
|
| 345 |
+
{
|
| 346 |
+
"role": "user",
|
| 347 |
+
"content": [
|
| 348 |
+
{"type": "image"},
|
| 349 |
+
],
|
| 350 |
+
},
|
| 351 |
+
],
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
@dataclass
|
| 355 |
+
|
| 356 |
+
class GRPOModelConfig(ModelConfig):
|
| 357 |
+
|
| 358 |
+
model_name_or_path: str = field(default="Qwen/Qwen3-0.6B", metadata={"help": "Model checkpoint for LLM weights initialization."})
|
| 359 |
+
protein_model_name_or_path: str = field(default="esm2_t33_650M_UR50D", metadata={"help": "Model checkpoint for ESM-2 protein weights initialization."})
|
| 360 |
+
cache_dir: str = field(default=None, metadata={"help": "Path to model cache directory."})
|
| 361 |
+
max_length_text: int = field(default=800, metadata={"help": "Maximum length of text sequences."})
|
| 362 |
+
max_length_protein: int = field(default=800, metadata={"help": "Maximum length of protein sequences (number of amino acids)."})
|
| 363 |
+
sft_checkpoint: str = field(default=None, metadata={"help": "Path to the checkpoint for SFT."})
|
| 364 |
+
lora_r: int = field(default=32, metadata={"help": "LoRA R value."})
|
| 365 |
+
lora_alpha: int = field(default=64, metadata={"help": "LoRA alpha."})
|
| 366 |
+
lora_dropout: float = field(default=0.05, metadata={"help": "LoRA dropout."})
|
| 367 |
+
lora_modules_to_save: Optional[list[str]] = field(
|
| 368 |
+
default_factory=lambda: ["embed_tokens", "lm_head"],
|
| 369 |
+
metadata={"help": "Model layers to unfreeze & train with LoRA."},
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
# Updated: Renamed `freeze_dna_modules` to `freeze_protein_model`
|
| 373 |
+
freeze_protein_model: bool = field(default=True, metadata={"help": "Whether to freeze the ESM-2 protein model during training."})
|
| 374 |
+
|
| 375 |
+
num_query_tokens: int = field(default=32, metadata={"help": "Number of query tokens for QFormer."})
|
| 376 |
+
qformer_num_layers: int = field(default=6, metadata={"help": "Number of layers in QFormer."})
|
| 377 |
+
qformer_num_heads: int = field(default=8, metadata={"help": "Number of attention heads in QFormer."})
|
| 378 |
+
qformer_dropout: float = field(default=0.1, metadata={"help": "Dropout rate for QFormer."})
|
| 379 |
+
|
| 380 |
+
@dataclass
|
| 381 |
+
class GRPOScriptArguments(ScriptArguments):
|
| 382 |
+
"""
|
| 383 |
+
Script arguments for the GRPO training script.
|
| 384 |
+
"""
|
| 385 |
+
dataset_name: str = field(default="wanglab/kegg", metadata={"help": "Dataset name with default."})
|
| 386 |
+
data_file_paths: str = field(
|
| 387 |
+
default=None,
|
| 388 |
+
metadata={"help": "Paths to data files, separated by ':'"},
|
| 389 |
+
)
|
| 390 |
+
arrow_cache_dir: str = field(
|
| 391 |
+
default=None,
|
| 392 |
+
metadata={"help": "Path to arrow cache directory"},
|
| 393 |
+
)
|
| 394 |
+
val_split_ratio: float = field(
|
| 395 |
+
default=0.0,
|
| 396 |
+
metadata={"help": "Ratio of validation split, default 0.0"},
|
| 397 |
+
)
|
| 398 |
+
reward_funcs: list[str] = field(
|
| 399 |
+
# 更新默认奖励函数列表,包含新的三个函数
|
| 400 |
+
default_factory=lambda: ["repeatness", "format", "acc", "xmlcount", "soft_format"],
|
| 401 |
+
metadata={"help": "List of reward functions. Possible values: 'repeatness', 'format', 'acc', 'xmlcount', 'soft_format', 'strict_format', 'less_than_4', 'correctness'"},
|
| 402 |
+
)
|
| 403 |
+
|
| 404 |
+
reward_funcs_registry = {
|
| 405 |
+
# "accuracy": accuracy_reward,
|
| 406 |
+
# "format": format_reward,
|
| 407 |
+
"repeatness": repeatness_reward_func,
|
| 408 |
+
"format": format_reward_func,
|
| 409 |
+
"acc": acc_reward_func,
|
| 410 |
+
"xmlcount": xmlcount_reward_func,
|
| 411 |
+
"soft_format": soft_format_reward_func,
|
| 412 |
+
"strict_format": strict_format_reward_func,
|
| 413 |
+
"less_than_4": less_than_4_reward_func,
|
| 414 |
+
"correctness": correctness_reward_func,
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
def get_vlm_module(model_name_or_path):
|
| 418 |
+
if any(mini_name in model_name_or_path.lower() for mini_name in ["qwen", "smol"]):
|
| 419 |
+
# 如果你有专门的蛋白质模块,使用它
|
| 420 |
+
try:
|
| 421 |
+
from bioreason.protein_modules import ProteinModule
|
| 422 |
+
return ProteinModule
|
| 423 |
+
except ImportError:
|
| 424 |
+
# 如果没有专门的蛋白质模块,检查DNAModule是否兼容
|
| 425 |
+
print("Warning: Using NucleotideDNAModule for protein processing. Consider creating a dedicated ProteinModule.")
|
| 426 |
+
return NucleotideDNAModule
|
| 427 |
+
else:
|
| 428 |
+
raise ValueError(f"Unsupported model: {model_name_or_path}")
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
def _prep_for_training(model: ProteinLLMModel, model_args, protein_model_finetune: bool = False) -> LoraConfig:
|
| 432 |
+
"""
|
| 433 |
+
准备ProteinLLMModel进行训练。
|
| 434 |
+
"""
|
| 435 |
+
# Freeze protein encoder parameters if not finetuning
|
| 436 |
+
if not protein_model_finetune:
|
| 437 |
+
for param in model.protein_model.parameters():
|
| 438 |
+
param.requires_grad = False
|
| 439 |
+
print("Frozen protein model parameters")
|
| 440 |
+
else:
|
| 441 |
+
print("Protein model parameters will be finetuned")
|
| 442 |
+
|
| 443 |
+
# Get target modules for LoRA
|
| 444 |
+
target_modules = _get_target_modules(model)
|
| 445 |
+
print(f"LoRA target modules: {target_modules}")
|
| 446 |
+
|
| 447 |
+
lora_config = LoraConfig(
|
| 448 |
+
r=model_args.lora_r,
|
| 449 |
+
lora_alpha=model_args.lora_alpha,
|
| 450 |
+
lora_dropout=model_args.lora_dropout,
|
| 451 |
+
target_modules=target_modules,
|
| 452 |
+
init_lora_weights="gaussian",
|
| 453 |
+
bias="none",
|
| 454 |
+
task_type="CAUSAL_LM",
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# Prepare text model for training
|
| 458 |
+
model.text_model = prepare_model_for_kbit_training(model.text_model)
|
| 459 |
+
model.text_model = get_peft_model(model.text_model, lora_config)
|
| 460 |
+
|
| 461 |
+
# Make QFormer projection layer trainable
|
| 462 |
+
for param in model.protein_projection.parameters():
|
| 463 |
+
param.requires_grad = True
|
| 464 |
+
print("QFormer projection layer set as trainable")
|
| 465 |
+
|
| 466 |
+
# Print trainable parameters info
|
| 467 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 468 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 469 |
+
print(f"Trainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
|
| 470 |
+
|
| 471 |
+
return lora_config
|
| 472 |
+
|
| 473 |
+
######################################################################
|
| 474 |
+
######################################################################
|
| 475 |
+
|
| 476 |
+
def main(script_args, training_args, model_args):
|
| 477 |
+
|
| 478 |
+
print(training_args.output_dir)
|
| 479 |
+
#pl.seed_everything(args.seed)
|
| 480 |
+
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
| 481 |
+
torch.cuda.empty_cache()
|
| 482 |
+
torch.set_float32_matmul_precision("medium")
|
| 483 |
+
|
| 484 |
+
# Initialize model
|
| 485 |
+
# Load tokenizer for target text
|
| 486 |
+
# tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
|
| 487 |
+
# tokenizer.pad_token = tokenizer.eos_token
|
| 488 |
+
|
| 489 |
+
# Load model
|
| 490 |
+
# model = ProteinLLMModel(
|
| 491 |
+
# text_model_name=model_args.model_name_or_path,
|
| 492 |
+
# dna_model_name=model_args.dna_model_name_or_path,
|
| 493 |
+
# cache_dir=model_args.cache_dir,
|
| 494 |
+
# max_length_text=model_args.max_length_text,
|
| 495 |
+
# max_length_dna=model_args.max_length_dna,
|
| 496 |
+
# text_model_finetune=True,
|
| 497 |
+
# dna_model_finetune=not model_args.freeze_dna_modules,
|
| 498 |
+
# debug=False,
|
| 499 |
+
# )
|
| 500 |
+
print("Initializing ProteinLLMModel...")
|
| 501 |
+
model = ProteinLLMModel(
|
| 502 |
+
text_model_name=model_args.model_name_or_path,
|
| 503 |
+
protein_model_name=model_args.protein_model_name_or_path,
|
| 504 |
+
biomedbert_model_name=getattr(model_args, 'biomedbert_model_name',
|
| 505 |
+
"microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"),
|
| 506 |
+
cache_dir=model_args.cache_dir,
|
| 507 |
+
max_length_text=model_args.max_length_text,
|
| 508 |
+
max_length_protein=model_args.max_length_protein,
|
| 509 |
+
text_model_finetune=True,
|
| 510 |
+
protein_model_finetune=not model_args.freeze_protein_modules,
|
| 511 |
+
biomedbert_finetune=getattr(model_args, 'biomedbert_finetune', True), # 新增:控制BiomedBERT微调
|
| 512 |
+
# Q-Former相关参数(简化了,因为直接使用BiomedBERT)
|
| 513 |
+
qformer_num_query_tokens=getattr(model_args, 'qformer_num_query_tokens', 8), # 重命名为qformer_num_query_tokens
|
| 514 |
+
)
|
| 515 |
+
# load checkpoint
|
| 516 |
+
if model_args.sft_checkpoint is not None:
|
| 517 |
+
print(f"Loading SFT checkpoint from {model_args.sft_checkpoint}")
|
| 518 |
+
|
| 519 |
+
# Determine if it's a directory (PEFT format) or file (PyTorch state dict)
|
| 520 |
+
is_directory = os.path.isdir(model_args.sft_checkpoint)
|
| 521 |
+
|
| 522 |
+
if is_directory:
|
| 523 |
+
# It's a PEFT checkpoint directory - load properly with PEFT
|
| 524 |
+
from peft import PeftModel
|
| 525 |
+
|
| 526 |
+
# First initialize the text model with PEFT
|
| 527 |
+
print("Loading as PEFT checkpoint directory")
|
| 528 |
+
model.text_model = PeftModel.from_pretrained(
|
| 529 |
+
model.text_model,
|
| 530 |
+
model_args.sft_checkpoint,
|
| 531 |
+
is_trainable=True
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Verify loaded adapters
|
| 535 |
+
print("Loaded LoRA adapters:", model.text_model.active_adapter)
|
| 536 |
+
|
| 537 |
+
# Optional: Merge weights into base model
|
| 538 |
+
print("Merging SFT LoRA weights into base model...")
|
| 539 |
+
model.text_model = model.text_model.merge_and_unload()
|
| 540 |
+
print("Successfully merged SFT knowledge into base model")
|
| 541 |
+
|
| 542 |
+
else:
|
| 543 |
+
# It's a PyTorch state dict file
|
| 544 |
+
print("Loading as PyTorch state dict file")
|
| 545 |
+
checkpoint = torch.load(model_args.sft_checkpoint, map_location="cpu")
|
| 546 |
+
|
| 547 |
+
# replace model.text_model with text_model for all in state dict
|
| 548 |
+
def new_key(k):
|
| 549 |
+
if k.startswith("=model."): return k[6:]
|
| 550 |
+
elif k.startswith("_forward_module."): return k[len("_forward_module."):]
|
| 551 |
+
else: return k
|
| 552 |
+
|
| 553 |
+
if "state_dict" in checkpoint:
|
| 554 |
+
magic = {new_key(k): v for k, v in checkpoint["state_dict"].items()}
|
| 555 |
+
elif "module" in checkpoint:
|
| 556 |
+
magic = {new_key(k): v for k, v in checkpoint["module"].items()}
|
| 557 |
+
elif isinstance(checkpoint, dict) and all(isinstance(k, str) for k in checkpoint.keys()):
|
| 558 |
+
# Direct state dict - the checkpoint itself is the state dict
|
| 559 |
+
print("Detected direct state dict format")
|
| 560 |
+
magic = {new_key(k): v for k, v in checkpoint.items()}
|
| 561 |
+
else:
|
| 562 |
+
raise ValueError(f"Unsupported checkpoint format: {model_args.sft_checkpoint}")
|
| 563 |
+
|
| 564 |
+
# Handle prefix mapping for different model architectures
|
| 565 |
+
lora_prefix = any("lora" in key for key in state_dict.keys())
|
| 566 |
+
|
| 567 |
+
if lora_prefix:
|
| 568 |
+
print("Detected LoRA weights in state dict")
|
| 569 |
+
# First prepare model for LoRA training
|
| 570 |
+
_prep_for_training(model, model_args, protein_model_finetune=model_args.freeze_protein_modules)
|
| 571 |
+
|
| 572 |
+
# Print diagnostic info
|
| 573 |
+
model_keys = set(model.state_dict().keys())
|
| 574 |
+
checkpoint_keys = set(state_dict.keys())
|
| 575 |
+
print(f"Model has {len(model_keys)} keys")
|
| 576 |
+
print(f"Checkpoint has {len(checkpoint_keys)} keys")
|
| 577 |
+
|
| 578 |
+
# Intelligent key mapping for different prefixes
|
| 579 |
+
new_state_dict = {}
|
| 580 |
+
for k, v in state_dict.items():
|
| 581 |
+
# Handle different common prefix patterns
|
| 582 |
+
if "base_model.model" in k and k not in model_keys:
|
| 583 |
+
new_k = k.replace("text_model.base_model.model", "text_model")
|
| 584 |
+
if new_k in model_keys:
|
| 585 |
+
new_state_dict[new_k] = v
|
| 586 |
+
continue
|
| 587 |
+
|
| 588 |
+
# Try removing/adding prefixes
|
| 589 |
+
if k.startswith("text_model.") and k not in model_keys:
|
| 590 |
+
new_k = "text_model.base_model.model." + k[len("text_model."):]
|
| 591 |
+
if new_k in model_keys:
|
| 592 |
+
new_state_dict[new_k] = v
|
| 593 |
+
continue
|
| 594 |
+
|
| 595 |
+
# Keep original key
|
| 596 |
+
new_state_dict[k] = v
|
| 597 |
+
|
| 598 |
+
state_dict = new_state_dict
|
| 599 |
+
print(f"After key mapping: {len(state_dict)} keys")
|
| 600 |
+
|
| 601 |
+
# Load state dict with missing/unexpected keys allowed
|
| 602 |
+
result = model.load_state_dict(state_dict, strict=False)
|
| 603 |
+
|
| 604 |
+
if len(result.unexpected_keys) > 0:
|
| 605 |
+
print(f"Sample unexpected keys: {result.unexpected_keys[:5]}")
|
| 606 |
+
if len(result.missing_keys) > 0:
|
| 607 |
+
print(f"Sample missing keys: {result.missing_keys[:5]}")
|
| 608 |
+
|
| 609 |
+
print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
|
| 610 |
+
else:
|
| 611 |
+
print("Standard weights detected - loading before LoRA setup")
|
| 612 |
+
|
| 613 |
+
# Handle shared memory issue for embedding weights
|
| 614 |
+
for key in list(state_dict.keys()):
|
| 615 |
+
if 'lm_head.weight' in key:
|
| 616 |
+
state_dict[key] = state_dict[key].clone()
|
| 617 |
+
|
| 618 |
+
# Load weights before setting up LoRA
|
| 619 |
+
result = model.load_state_dict(state_dict, strict=False)
|
| 620 |
+
print(f"Loaded checkpoint with {len(result.missing_keys)} missing keys and {len(result.unexpected_keys)} unexpected keys")
|
| 621 |
+
|
| 622 |
+
# Now prepare for LoRA training
|
| 623 |
+
_prep_for_training(model, model_args, protein_model_finetune=model_args.freeze_protein_modules)
|
| 624 |
+
else:
|
| 625 |
+
# No checkpoint, just prepare for training
|
| 626 |
+
_prep_for_training(model, model_args, protein_model_finetune=not model_args.freeze_protein_model)
|
| 627 |
+
|
| 628 |
+
# Get reward functions
|
| 629 |
+
reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
|
| 630 |
+
# reward_funcs = [
|
| 631 |
+
# xmlcount_reward_func,
|
| 632 |
+
# soft_format_reward_func,
|
| 633 |
+
# strict_format_reward_func,
|
| 634 |
+
# int_reward_func,
|
| 635 |
+
# correctness_reward_func,
|
| 636 |
+
# ]
|
| 637 |
+
print("reward_funcs:", [func.__name__ for func in reward_funcs])
|
| 638 |
+
|
| 639 |
+
vlm_module_cls = get_vlm_module(model_args.model_name_or_path)
|
| 640 |
+
print("using vlm module:", vlm_module_cls.__name__)
|
| 641 |
+
question_prompt = vlm_module_cls.get_question_template()
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
dataset = get_kegg_questions()
|
| 645 |
+
|
| 646 |
+
#dataset = get_gsm8k_questions(question_prompt)
|
| 647 |
+
|
| 648 |
+
print(dataset)
|
| 649 |
+
|
| 650 |
+
#print('ITEM ONE OF THE DATASET', dataset['train'][0])
|
| 651 |
+
|
| 652 |
+
# Custom callback to handle saving with PyTorch's native mechanism
|
| 653 |
+
custom_save_callback = SaveWithPyTorchCallback()
|
| 654 |
+
|
| 655 |
+
# Initialize the GRPO trainer with custom callback
|
| 656 |
+
trainer = DNALLMGRPOTrainer(
|
| 657 |
+
model=model,
|
| 658 |
+
reward_funcs=reward_funcs,
|
| 659 |
+
args=training_args,
|
| 660 |
+
dna_module=vlm_module_cls(),
|
| 661 |
+
train_dataset=dataset['train'],
|
| 662 |
+
eval_dataset=dataset['val'] if training_args.eval_strategy != "no" else None,
|
| 663 |
+
peft_config=get_peft_config(model_args),
|
| 664 |
+
attn_implementation=model_args.attn_implementation,
|
| 665 |
+
torch_dtype=model_args.torch_dtype,
|
| 666 |
+
callbacks=[custom_save_callback], # Add our custom callback
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
# Set the trainer to save in PyTorch format instead of safetensors
|
| 670 |
+
training_args.save_safetensors = False
|
| 671 |
+
|
| 672 |
+
# Train and push the model to the Hub
|
| 673 |
+
# if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
|
| 674 |
+
# trainer.train(resume_from_checkpoint=True)
|
| 675 |
+
# else:
|
| 676 |
+
# trainer.train()
|
| 677 |
+
|
| 678 |
+
# Train and push the model to the Hub
|
| 679 |
+
trainer.train()
|
| 680 |
+
|
| 681 |
+
|
| 682 |
+
if __name__ == "__main__":
|
| 683 |
+
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
| 684 |
+
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
|
| 685 |
+
parser = TrlParser((GRPOScriptArguments, DNALLMGRPOConfig, GRPOModelConfig))
|
| 686 |
+
script_args, training_args, model_args = parser.parse_args_and_config()
|
| 687 |
+
|
| 688 |
+
# Ensure we use PyTorch's save mechanism instead of safetensors
|
| 689 |
+
training_args.save_safetensors = False
|
| 690 |
+
|
| 691 |
+
main(script_args, training_args, model_args)
|
| 692 |
+
|
| 693 |
+
# parser.add_argument("--wandb_project", type=str, default="dna-text-finetune")
|
| 694 |
+
# parser.add_argument("--wandb_entity", type=str, default="adibvafa")
|
| 695 |
+
|
| 696 |
+
# args = parser.parse_args()
|
BioReason/requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
transformers
|
| 4 |
+
accelerate
|
| 5 |
+
qwen-vl-utils
|
| 6 |
+
jupyter
|
| 7 |
+
datasets
|
| 8 |
+
peft
|
| 9 |
+
pytorch_lightning
|
| 10 |
+
wandb
|
| 11 |
+
trl[vllm]
|
| 12 |
+
bitsandbytes
|
| 13 |
+
deepspeed
|
BioReason/sh_reason.sh
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=Qwen3_1.7B_SFT_RL # Name of the job
|
| 3 |
+
#SBATCH --gres=gpu:4 # Number of GPUs
|
| 4 |
+
#SBATCH -p a100 # Partition
|
| 5 |
+
#SBATCH -c 12 # Number of cores
|
| 6 |
+
#SBATCH --time=12:00:00 # Time limit
|
| 7 |
+
#SBATCH --mem=128gb # Memory limit
|
| 8 |
+
#SBATCH --output=Qwen3_1.7B_SFT_RL_a100-%j.out # Output file
|
| 9 |
+
#SBATCH --error=Qwen3_1.7B_SFT_RL_a100-%j.err # Error file
|
| 10 |
+
|
| 11 |
+
## Environment Setup
|
| 12 |
+
echo "CUDA_HOME: $CUDA_HOME"
|
| 13 |
+
echo "PATH: $PATH"
|
| 14 |
+
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
| 15 |
+
echo "which python: $(which python)"
|
| 16 |
+
|
| 17 |
+
## Configuration Variables
|
| 18 |
+
# Change these to match your setup
|
| 19 |
+
SFT_CHECKPOINT=SFT_CHECKPOINT # Change to the checkpoint of the SFT model
|
| 20 |
+
CACHE_DIR=CACHE_DIR # Change to the directory where the model weights are cached
|
| 21 |
+
OUTPUT_DIR=OUTPUT_DIR # Change to the directory where the model will be saved
|
| 22 |
+
CONDA_ENV=CONDA_ENV # Change to the conda environment
|
| 23 |
+
|
| 24 |
+
## Setup Environment
|
| 25 |
+
conda activate $CONDA_ENV # Change to the conda environment
|
| 26 |
+
cd .../BioReason/ # Change to the directory containing the script
|
| 27 |
+
nvidia-smi # Check GPU status
|
| 28 |
+
|
| 29 |
+
## Dependencies
|
| 30 |
+
# You might need to install this on a gpu session
|
| 31 |
+
# pip install trl[vllm]
|
| 32 |
+
|
| 33 |
+
## =============================================================================
|
| 34 |
+
## Reinforcement Learning Training with DeepSpeed
|
| 35 |
+
## =============================================================================
|
| 36 |
+
|
| 37 |
+
# Run with DeepSpeed ZeRO Stage 2
|
| 38 |
+
srun deepspeed --num_gpus=4 --num_nodes=1 \
|
| 39 |
+
reason.py \
|
| 40 |
+
--deepspeed grpo_trainer_lora_model/ds_config_stage2.json \
|
| 41 |
+
--num_generations 4 \
|
| 42 |
+
--per_device_train_batch_size 2 \
|
| 43 |
+
--bf16 true \
|
| 44 |
+
--ddp_find_unused_parameters false \
|
| 45 |
+
--sft_checkpoint $SFT_CHECKPOINT \
|
| 46 |
+
--model_name_or_path Qwen/Qwen3-1.7B \
|
| 47 |
+
--dna_model_name_or_path InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 48 |
+
--cache_dir $CACHE_DIR \
|
| 49 |
+
--output_dir $OUTPUT_DIR \
|
| 50 |
+
--save_strategy "steps" \
|
| 51 |
+
--save_steps 100 \
|
| 52 |
+
--save_total_limit 2 \
|
| 53 |
+
--use_vllm true \
|
| 54 |
+
--temperature 0.6 \
|
| 55 |
+
--top_p 0.95 \
|
| 56 |
+
--top_k 20 \
|
| 57 |
+
--num_train_epochs 1
|
BioReason/sh_train_dna_only.sh
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=train_dna # Name of the job
|
| 3 |
+
#SBATCH --time=8:00:00 # Time limit
|
| 4 |
+
#SBATCH --partition=gpu_batch # Partition
|
| 5 |
+
#SBATCH --gpus=1 # Number of GPUs
|
| 6 |
+
#SBATCH --ntasks=1 # Number of tasks
|
| 7 |
+
#SBATCH --cpus-per-task=6 # Number of cores
|
| 8 |
+
#SBATCH --mem=128gb # Memory limit
|
| 9 |
+
#SBATCH --output=train_dna_%j_%x.out # Output file
|
| 10 |
+
#SBATCH --error=train_dna_%j_%x.err # Error file
|
| 11 |
+
|
| 12 |
+
## Environment Setup
|
| 13 |
+
echo "CUDA_HOME: $CUDA_HOME"
|
| 14 |
+
echo "PATH: $PATH"
|
| 15 |
+
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
| 16 |
+
echo "which python: $(which python)"
|
| 17 |
+
|
| 18 |
+
## Configuration Variables
|
| 19 |
+
# Change these to match your setup
|
| 20 |
+
CONDA_ENV=CONDA_ENV # Change to your conda environment name
|
| 21 |
+
CACHE_DIR=CACHE_DIR # Change to your HuggingFace cache directory
|
| 22 |
+
WANDB_PROJECT=WANDB_PROJECT # Change to your W&B project name
|
| 23 |
+
|
| 24 |
+
## Setup Environment
|
| 25 |
+
conda activate $CONDA_ENV # Change to your conda environment
|
| 26 |
+
cd .../BioReason/ # Change to the directory containing the script
|
| 27 |
+
nvidia-smi # Check GPU status
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## =============================================================================
|
| 31 |
+
## KEGG Dataset Training (DNA-only models)
|
| 32 |
+
## =============================================================================
|
| 33 |
+
|
| 34 |
+
# NT-500M on KEGG
|
| 35 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 36 |
+
--cache_dir $CACHE_DIR \
|
| 37 |
+
--wandb_project $WANDB_PROJECT \
|
| 38 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 39 |
+
--strategy ddp \
|
| 40 |
+
--max_epochs 5 \
|
| 41 |
+
--num_gpus 1 \
|
| 42 |
+
--batch_size 1 \
|
| 43 |
+
--max_length_dna 2048 \
|
| 44 |
+
--truncate_dna_per_side 1024 \
|
| 45 |
+
--train_just_classifier True \
|
| 46 |
+
--learning_rate 3e-4 \
|
| 47 |
+
--dataset_type kegg \
|
| 48 |
+
--merge_val_test_set True
|
| 49 |
+
|
| 50 |
+
# EVO2-1B on KEGG
|
| 51 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 52 |
+
--cache_dir $CACHE_DIR \
|
| 53 |
+
--wandb_project $WANDB_PROJECT \
|
| 54 |
+
--dna_model_name evo2_1b_base \
|
| 55 |
+
--strategy ddp \
|
| 56 |
+
--max_epochs 5 \
|
| 57 |
+
--num_gpus 1 \
|
| 58 |
+
--batch_size 1 \
|
| 59 |
+
--max_length_dna 2048 \
|
| 60 |
+
--truncate_dna_per_side 1024 \
|
| 61 |
+
--train_just_classifier True \
|
| 62 |
+
--dna_is_evo2 True \
|
| 63 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 64 |
+
--learning_rate 3e-4 \
|
| 65 |
+
--dataset_type kegg \
|
| 66 |
+
--merge_val_test_set True
|
| 67 |
+
|
| 68 |
+
## =============================================================================
|
| 69 |
+
## Variant Effect Prediction (VEP) Training
|
| 70 |
+
## =============================================================================
|
| 71 |
+
|
| 72 |
+
# NT-500M on VEP
|
| 73 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 74 |
+
--cache_dir $CACHE_DIR \
|
| 75 |
+
--wandb_project $WANDB_PROJECT \
|
| 76 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 77 |
+
--strategy ddp \
|
| 78 |
+
--max_epochs 3 \
|
| 79 |
+
--num_gpus 1 \
|
| 80 |
+
--batch_size 2 \
|
| 81 |
+
--max_length_dna 2048 \
|
| 82 |
+
--truncate_dna_per_side 1024 \
|
| 83 |
+
--train_just_classifier True \
|
| 84 |
+
--learning_rate 3e-4 \
|
| 85 |
+
--dataset_type variant_effect_coding
|
| 86 |
+
|
| 87 |
+
# EVO2-1B on VEP
|
| 88 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 89 |
+
--cache_dir $CACHE_DIR \
|
| 90 |
+
--wandb_project $WANDB_PROJECT \
|
| 91 |
+
--dna_model_name evo2_1b_base \
|
| 92 |
+
--strategy ddp \
|
| 93 |
+
--max_epochs 3 \
|
| 94 |
+
--num_gpus 1 \
|
| 95 |
+
--batch_size 2 \
|
| 96 |
+
--max_length_dna 2048 \
|
| 97 |
+
--truncate_dna_per_side 1024 \
|
| 98 |
+
--train_just_classifier True \
|
| 99 |
+
--dna_is_evo2 True \
|
| 100 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 101 |
+
--learning_rate 3e-4 \
|
| 102 |
+
--dataset_type variant_effect_coding
|
| 103 |
+
|
| 104 |
+
## =============================================================================
|
| 105 |
+
## Variant Effect Prediction Non-SNV Training
|
| 106 |
+
## =============================================================================
|
| 107 |
+
|
| 108 |
+
# NT-500M on VEP Non-SNV
|
| 109 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 110 |
+
--cache_dir $CACHE_DIR \
|
| 111 |
+
--wandb_project $WANDB_PROJECT \
|
| 112 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 113 |
+
--strategy ddp \
|
| 114 |
+
--max_epochs 3 \
|
| 115 |
+
--num_gpus 1 \
|
| 116 |
+
--batch_size 2 \
|
| 117 |
+
--max_length_dna 2048 \
|
| 118 |
+
--truncate_dna_per_side 1024 \
|
| 119 |
+
--train_just_classifier True \
|
| 120 |
+
--learning_rate 3e-4 \
|
| 121 |
+
--dataset_type variant_effect_non_snv
|
| 122 |
+
|
| 123 |
+
# EVO2-1B on VEP Non-SNV
|
| 124 |
+
stdbuf -oL -eL srun python train_dna_only.py \
|
| 125 |
+
--cache_dir $CACHE_DIR \
|
| 126 |
+
--wandb_project $WANDB_PROJECT \
|
| 127 |
+
--dna_model_name evo2_1b_base \
|
| 128 |
+
--strategy ddp \
|
| 129 |
+
--max_epochs 3 \
|
| 130 |
+
--num_gpus 1 \
|
| 131 |
+
--batch_size 2 \
|
| 132 |
+
--max_length_dna 2048 \
|
| 133 |
+
--truncate_dna_per_side 1024 \
|
| 134 |
+
--train_just_classifier True \
|
| 135 |
+
--dna_is_evo2 True \
|
| 136 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 137 |
+
--learning_rate 3e-4 \
|
| 138 |
+
--dataset_type variant_effect_non_snv
|
BioReason/sh_train_dna_qwen.sh
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH --job-name=train_dna_qwen # Name of the job
|
| 3 |
+
#SBATCH --time=12:00:00 # Time limit
|
| 4 |
+
#SBATCH --partition=gpu_batch # Partition
|
| 5 |
+
#SBATCH --gpus=1 # Number of GPUs
|
| 6 |
+
#SBATCH --ntasks=1 # Number of tasks
|
| 7 |
+
#SBATCH --cpus-per-task=8 # Number of cores
|
| 8 |
+
#SBATCH --mem=128gb # Memory limit
|
| 9 |
+
#SBATCH --output=train_dna_qwen_%j_%x.out # Output file
|
| 10 |
+
#SBATCH --error=train_dna_qwen_%j_%x.err # Error file
|
| 11 |
+
|
| 12 |
+
## Environment Setup
|
| 13 |
+
echo "CUDA_HOME: $CUDA_HOME"
|
| 14 |
+
echo "PATH: $PATH"
|
| 15 |
+
echo "LD_LIBRARY_PATH: $LD_LIBRARY_PATH"
|
| 16 |
+
echo "which python: $(which python)"
|
| 17 |
+
|
| 18 |
+
## Configuration Variables
|
| 19 |
+
# Change these to match your setup
|
| 20 |
+
CONDA_ENV=CONDA_ENV # Change to your conda environment name
|
| 21 |
+
CACHE_DIR=CACHE_DIR # Change to your HuggingFace cache directory
|
| 22 |
+
OUTPUT_DIR=OUTPUT_DIR # Change to your output/log directory
|
| 23 |
+
WANDB_PROJECT=WANDB_PROJECT # Change to your W&B project name
|
| 24 |
+
|
| 25 |
+
## Setup Environment
|
| 26 |
+
conda activate $CONDA_ENV # Change to your conda environment
|
| 27 |
+
cd .../BioReason/ # Change to the directory containing the script
|
| 28 |
+
nvidia-smi # Check GPU status
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
## =============================================================================
|
| 32 |
+
## KEGG Dataset Training
|
| 33 |
+
## =============================================================================
|
| 34 |
+
|
| 35 |
+
# NT-500M + Qwen3-1.7B on KEGG
|
| 36 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 37 |
+
--cache_dir $CACHE_DIR \
|
| 38 |
+
--wandb_project $WANDB_PROJECT \
|
| 39 |
+
--text_model_name Qwen/Qwen3-1.7B \
|
| 40 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 41 |
+
--strategy deepspeed_stage_2 \
|
| 42 |
+
--max_epochs 5 \
|
| 43 |
+
--num_gpus 1 \
|
| 44 |
+
--batch_size 1 \
|
| 45 |
+
--model_type dna-llm \
|
| 46 |
+
--dataset_type kegg \
|
| 47 |
+
--merge_val_test_set True \
|
| 48 |
+
--return_answer_in_batch True
|
| 49 |
+
|
| 50 |
+
# EVO2-1B + Qwen3-1.7B on KEGG
|
| 51 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 52 |
+
--cache_dir $CACHE_DIR \
|
| 53 |
+
--wandb_project $WANDB_PROJECT \
|
| 54 |
+
--text_model_name Qwen/Qwen3-1.7B \
|
| 55 |
+
--dna_model_name evo2_1b_base \
|
| 56 |
+
--strategy deepspeed_stage_2 \
|
| 57 |
+
--max_epochs 5 \
|
| 58 |
+
--num_gpus 1 \
|
| 59 |
+
--batch_size 1 \
|
| 60 |
+
--model_type dna-llm \
|
| 61 |
+
--dataset_type kegg \
|
| 62 |
+
--max_length_dna 2048 \
|
| 63 |
+
--truncate_dna_per_side 1024 \
|
| 64 |
+
--dna_is_evo2 True \
|
| 65 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 66 |
+
--merge_val_test_set True \
|
| 67 |
+
--return_answer_in_batch True
|
| 68 |
+
|
| 69 |
+
# Qwen3-4B on KEGG (LLM-only)
|
| 70 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 71 |
+
--cache_dir $CACHE_DIR \
|
| 72 |
+
--wandb_project $WANDB_PROJECT \
|
| 73 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 74 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 75 |
+
--strategy deepspeed_stage_2 \
|
| 76 |
+
--max_epochs 5 \
|
| 77 |
+
--num_gpus 1 \
|
| 78 |
+
--batch_size 1 \
|
| 79 |
+
--model_type llm \
|
| 80 |
+
--dataset_type kegg \
|
| 81 |
+
--max_length_dna 4 \
|
| 82 |
+
--max_length_text 8192 \
|
| 83 |
+
--truncate_dna_per_side 1024 \
|
| 84 |
+
--merge_val_test_set True \
|
| 85 |
+
--return_answer_in_batch True
|
| 86 |
+
|
| 87 |
+
## =============================================================================
|
| 88 |
+
## Variant Effect Prediction (VEP) Training
|
| 89 |
+
## =============================================================================
|
| 90 |
+
|
| 91 |
+
# NT-500M + Qwen3-4B on VEP
|
| 92 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 93 |
+
--cache_dir $CACHE_DIR \
|
| 94 |
+
--wandb_project $WANDB_PROJECT \
|
| 95 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 96 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 97 |
+
--strategy deepspeed_stage_2 \
|
| 98 |
+
--max_epochs 3 \
|
| 99 |
+
--num_gpus 1 \
|
| 100 |
+
--batch_size 2 \
|
| 101 |
+
--model_type dna-llm \
|
| 102 |
+
--dataset_type variant_effect_coding \
|
| 103 |
+
--return_answer_in_batch True
|
| 104 |
+
|
| 105 |
+
# EVO2-1B + Qwen3-1.7B on VEP
|
| 106 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 107 |
+
--cache_dir $CACHE_DIR \
|
| 108 |
+
--wandb_project $WANDB_PROJECT \
|
| 109 |
+
--text_model_name Qwen/Qwen3-1.7B \
|
| 110 |
+
--dna_model_name evo2_1b_base \
|
| 111 |
+
--strategy deepspeed_stage_2 \
|
| 112 |
+
--max_epochs 3 \
|
| 113 |
+
--num_gpus 1 \
|
| 114 |
+
--batch_size 2 \
|
| 115 |
+
--model_type dna-llm \
|
| 116 |
+
--dataset_type variant_effect_coding \
|
| 117 |
+
--max_length_dna 2048 \
|
| 118 |
+
--truncate_dna_per_side 1024 \
|
| 119 |
+
--dna_is_evo2 True \
|
| 120 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 121 |
+
--return_answer_in_batch True
|
| 122 |
+
|
| 123 |
+
# Qwen3-4B on VEP (LLM-only) - Testing max length text
|
| 124 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 125 |
+
--cache_dir $CACHE_DIR \
|
| 126 |
+
--wandb_project $WANDB_PROJECT \
|
| 127 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 128 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 129 |
+
--strategy deepspeed_stage_2 \
|
| 130 |
+
--max_epochs 3 \
|
| 131 |
+
--num_gpus 1 \
|
| 132 |
+
--batch_size 2 \
|
| 133 |
+
--model_type llm \
|
| 134 |
+
--dataset_type variant_effect_coding \
|
| 135 |
+
--max_length_dna 4 \
|
| 136 |
+
--max_length_text 4096 \
|
| 137 |
+
--truncate_dna_per_side 1024 \
|
| 138 |
+
--return_answer_in_batch True
|
| 139 |
+
|
| 140 |
+
## =============================================================================
|
| 141 |
+
## Variant Effect Prediction Non-SNV Training
|
| 142 |
+
## =============================================================================
|
| 143 |
+
|
| 144 |
+
# NT-500M + Qwen3-4B on VEP Non-SNV
|
| 145 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 146 |
+
--cache_dir $CACHE_DIR \
|
| 147 |
+
--wandb_project $WANDB_PROJECT \
|
| 148 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 149 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 150 |
+
--strategy deepspeed_stage_2 \
|
| 151 |
+
--max_epochs 1 \
|
| 152 |
+
--num_gpus 1 \
|
| 153 |
+
--batch_size 2 \
|
| 154 |
+
--model_type dna-llm \
|
| 155 |
+
--dataset_type variant_effect_non_snv \
|
| 156 |
+
--return_answer_in_batch True
|
| 157 |
+
|
| 158 |
+
# EVO2-1B + Qwen3-4B on VEP Non-SNV
|
| 159 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 160 |
+
--cache_dir $CACHE_DIR \
|
| 161 |
+
--wandb_project $WANDB_PROJECT \
|
| 162 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 163 |
+
--dna_model_name evo2_1b_base \
|
| 164 |
+
--strategy deepspeed_stage_2 \
|
| 165 |
+
--max_epochs 3 \
|
| 166 |
+
--num_gpus 1 \
|
| 167 |
+
--batch_size 2 \
|
| 168 |
+
--model_type dna-llm \
|
| 169 |
+
--dataset_type variant_effect_non_snv \
|
| 170 |
+
--max_length_dna 2048 \
|
| 171 |
+
--truncate_dna_per_side 1024 \
|
| 172 |
+
--dna_is_evo2 True \
|
| 173 |
+
--dna_embedding_layer blocks.20.mlp.l3 \
|
| 174 |
+
--return_answer_in_batch True
|
| 175 |
+
|
| 176 |
+
# Qwen3-4B on VEP Non-SNV (LLM-only) - Testing max length text
|
| 177 |
+
stdbuf -oL -eL srun python train_dna_qwen.py \
|
| 178 |
+
--cache_dir $CACHE_DIR \
|
| 179 |
+
--wandb_project $WANDB_PROJECT \
|
| 180 |
+
--text_model_name Qwen/Qwen3-4B \
|
| 181 |
+
--dna_model_name InstaDeepAI/nucleotide-transformer-v2-500m-multi-species \
|
| 182 |
+
--strategy deepspeed_stage_2 \
|
| 183 |
+
--max_epochs 1 \
|
| 184 |
+
--num_gpus 1 \
|
| 185 |
+
--batch_size 2 \
|
| 186 |
+
--model_type llm \
|
| 187 |
+
--dataset_type variant_effect_non_snv \
|
| 188 |
+
--max_length_dna 4 \
|
| 189 |
+
--max_length_text 4096 \
|
| 190 |
+
--truncate_dna_per_side 1024 \
|
| 191 |
+
--return_answer_in_batch True
|
BioReason/train_dna_only.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import argparse
|
| 4 |
+
import torch
|
| 5 |
+
import wandb
|
| 6 |
+
from torch.optim import AdamW
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from transformers import get_cosine_schedule_with_warmup, AutoTokenizer
|
| 9 |
+
from datasets import load_dataset, concatenate_datasets
|
| 10 |
+
import pytorch_lightning as pl
|
| 11 |
+
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 12 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 13 |
+
from pytorch_lightning.strategies import DeepSpeedStrategy
|
| 14 |
+
from bioreason.models.dna_only import DNAClassifierModel
|
| 15 |
+
from bioreason.dataset.utils import truncate_dna
|
| 16 |
+
from bioreason.dataset.kegg import dna_collate_fn
|
| 17 |
+
from bioreason.dataset.variant_effect import clean_variant_effect_example
|
| 18 |
+
from bioreason.models.evo2_tokenizer import Evo2Tokenizer, register_evo2_tokenizer
|
| 19 |
+
register_evo2_tokenizer()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DNAClassifierModelTrainer(pl.LightningModule):
|
| 23 |
+
"""
|
| 24 |
+
PyTorch Lightning module for training the DNA classifier.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, args):
|
| 28 |
+
"""
|
| 29 |
+
Initialize the DNAClassifierModelTrainer.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
args: Command line arguments
|
| 33 |
+
"""
|
| 34 |
+
super().__init__()
|
| 35 |
+
self.save_hyperparameters(args)
|
| 36 |
+
|
| 37 |
+
# Load dataset and labels
|
| 38 |
+
self.dataset, self.labels = self.load_dataset()
|
| 39 |
+
self.label2id = {label: i for i, label in enumerate(self.labels)}
|
| 40 |
+
|
| 41 |
+
# Load model
|
| 42 |
+
self.dna_model = DNAClassifierModel(
|
| 43 |
+
dna_model_name=self.hparams.dna_model_name,
|
| 44 |
+
cache_dir=self.hparams.cache_dir,
|
| 45 |
+
max_length_dna=self.hparams.max_length_dna,
|
| 46 |
+
num_classes=len(self.labels),
|
| 47 |
+
dna_is_evo2=self.hparams.dna_is_evo2,
|
| 48 |
+
dna_embedding_layer=self.hparams.dna_embedding_layer,
|
| 49 |
+
train_just_classifier=self.hparams.train_just_classifier,
|
| 50 |
+
)
|
| 51 |
+
self.dna_tokenizer = self.dna_model.dna_tokenizer
|
| 52 |
+
|
| 53 |
+
# Set the training mode for the classifier and pooler
|
| 54 |
+
self.dna_model.pooler.train()
|
| 55 |
+
self.dna_model.classifier.train()
|
| 56 |
+
|
| 57 |
+
# Freeze the DNA model parameters
|
| 58 |
+
if self.hparams.dna_is_evo2:
|
| 59 |
+
self.dna_model_params = self.dna_model.dna_model.model.parameters()
|
| 60 |
+
else:
|
| 61 |
+
self.dna_model_params = self.dna_model.dna_model.parameters()
|
| 62 |
+
|
| 63 |
+
if self.hparams.train_just_classifier:
|
| 64 |
+
for param in self.dna_model_params:
|
| 65 |
+
param.requires_grad = False
|
| 66 |
+
|
| 67 |
+
def _step(self, prefix, batch_idx, batch):
|
| 68 |
+
"""
|
| 69 |
+
Performs a single training/validation step.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
batch: Dictionary containing the batch data
|
| 73 |
+
prefix: String indicating the step type ('train' or 'val')
|
| 74 |
+
|
| 75 |
+
Returns:
|
| 76 |
+
torch.Tensor: The computed loss for this batch
|
| 77 |
+
"""
|
| 78 |
+
ref_ids = batch["ref_ids"].to(self.device)
|
| 79 |
+
alt_ids = batch["alt_ids"].to(self.device)
|
| 80 |
+
ref_attention_mask = batch["ref_attention_mask"].to(self.device)
|
| 81 |
+
alt_attention_mask = batch["alt_attention_mask"].to(self.device)
|
| 82 |
+
labels = batch["labels"].to(self.device)
|
| 83 |
+
|
| 84 |
+
# Forward pass
|
| 85 |
+
logits = self.dna_model(ref_ids=ref_ids, alt_ids=alt_ids, ref_attention_mask=ref_attention_mask, alt_attention_mask=alt_attention_mask)
|
| 86 |
+
|
| 87 |
+
# Calculate loss
|
| 88 |
+
loss_fn = torch.nn.CrossEntropyLoss()
|
| 89 |
+
loss = loss_fn(logits, labels)
|
| 90 |
+
|
| 91 |
+
# Calculate accuracy
|
| 92 |
+
preds = torch.argmax(logits, dim=1)
|
| 93 |
+
acc = (preds == labels).float().mean()
|
| 94 |
+
|
| 95 |
+
# Calculate F1 score, precision, and recall for binary classification
|
| 96 |
+
# Assuming label 1 is positive and label 0 is negative as mentioned
|
| 97 |
+
true_positives = ((preds == 1) & (labels == 1)).float().sum()
|
| 98 |
+
false_positives = ((preds == 1) & (labels == 0)).float().sum()
|
| 99 |
+
false_negatives = ((preds == 0) & (labels == 1)).float().sum()
|
| 100 |
+
|
| 101 |
+
# Calculate precision, recall, and F1 score
|
| 102 |
+
precision = true_positives / (true_positives + false_positives + 1e-8) # add small epsilon to avoid division by zero
|
| 103 |
+
recall = true_positives / (true_positives + false_negatives + 1e-8)
|
| 104 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-8)
|
| 105 |
+
|
| 106 |
+
# Logging metrics
|
| 107 |
+
self.log(
|
| 108 |
+
f"{prefix}_loss",
|
| 109 |
+
loss,
|
| 110 |
+
on_step=True,
|
| 111 |
+
on_epoch=False,
|
| 112 |
+
prog_bar=True,
|
| 113 |
+
logger=True,
|
| 114 |
+
)
|
| 115 |
+
self.log(
|
| 116 |
+
f"{prefix}_acc",
|
| 117 |
+
acc,
|
| 118 |
+
on_step=True,
|
| 119 |
+
on_epoch=False,
|
| 120 |
+
prog_bar=True,
|
| 121 |
+
logger=True,
|
| 122 |
+
)
|
| 123 |
+
self.log(
|
| 124 |
+
f"{prefix}_loss_epoch",
|
| 125 |
+
loss,
|
| 126 |
+
on_step=False,
|
| 127 |
+
on_epoch=True,
|
| 128 |
+
prog_bar=True,
|
| 129 |
+
logger=True,
|
| 130 |
+
sync_dist=True,
|
| 131 |
+
)
|
| 132 |
+
self.log(
|
| 133 |
+
f"{prefix}_acc_epoch",
|
| 134 |
+
acc,
|
| 135 |
+
on_step=False,
|
| 136 |
+
on_epoch=True,
|
| 137 |
+
prog_bar=True,
|
| 138 |
+
logger=True,
|
| 139 |
+
sync_dist=True,
|
| 140 |
+
)
|
| 141 |
+
self.log(
|
| 142 |
+
f"{prefix}_precision",
|
| 143 |
+
precision,
|
| 144 |
+
on_step=True,
|
| 145 |
+
on_epoch=False,
|
| 146 |
+
prog_bar=True,
|
| 147 |
+
logger=True,
|
| 148 |
+
)
|
| 149 |
+
self.log(
|
| 150 |
+
f"{prefix}_precision_epoch",
|
| 151 |
+
precision,
|
| 152 |
+
on_step=False,
|
| 153 |
+
on_epoch=True,
|
| 154 |
+
prog_bar=True,
|
| 155 |
+
logger=True,
|
| 156 |
+
sync_dist=True,
|
| 157 |
+
)
|
| 158 |
+
self.log(
|
| 159 |
+
f"{prefix}_recall",
|
| 160 |
+
recall,
|
| 161 |
+
on_step=True,
|
| 162 |
+
on_epoch=False,
|
| 163 |
+
prog_bar=True,
|
| 164 |
+
logger=True,
|
| 165 |
+
)
|
| 166 |
+
self.log(
|
| 167 |
+
f"{prefix}_recall_epoch",
|
| 168 |
+
recall,
|
| 169 |
+
on_step=False,
|
| 170 |
+
on_epoch=True,
|
| 171 |
+
prog_bar=True,
|
| 172 |
+
logger=True,
|
| 173 |
+
sync_dist=True,
|
| 174 |
+
)
|
| 175 |
+
self.log(
|
| 176 |
+
f"{prefix}_f1",
|
| 177 |
+
f1,
|
| 178 |
+
on_step=True,
|
| 179 |
+
on_epoch=False,
|
| 180 |
+
prog_bar=True,
|
| 181 |
+
logger=True,
|
| 182 |
+
)
|
| 183 |
+
self.log(
|
| 184 |
+
f"{prefix}_f1_epoch",
|
| 185 |
+
f1,
|
| 186 |
+
on_step=False,
|
| 187 |
+
on_epoch=True,
|
| 188 |
+
prog_bar=True,
|
| 189 |
+
logger=True,
|
| 190 |
+
sync_dist=True,
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
if (prefix == "test") or (prefix == "train" and (self.global_step % 1000 == 0)) or (prefix == "val" and (batch_idx % 100 == 0)):
|
| 194 |
+
wandb_logger = self.logger.experiment
|
| 195 |
+
|
| 196 |
+
pred_label = self.labels[preds[0]]
|
| 197 |
+
true_label = self.labels[labels[0]]
|
| 198 |
+
timestamp = time.time()
|
| 199 |
+
step_id = f"gen_{self.global_step}-{timestamp}"
|
| 200 |
+
|
| 201 |
+
wandb_logger.log(
|
| 202 |
+
{
|
| 203 |
+
step_id: wandb.Table(
|
| 204 |
+
columns=["timestamp", "prefix", "pred_label", "true_label"],
|
| 205 |
+
data=[[timestamp, prefix, pred_label, true_label]],
|
| 206 |
+
)
|
| 207 |
+
}
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
print(f"Example {prefix} {batch_idx} {self.global_step}: Prediction: {pred_label}, Target: {true_label}")
|
| 211 |
+
|
| 212 |
+
return loss
|
| 213 |
+
|
| 214 |
+
def training_step(self, batch, batch_idx):
|
| 215 |
+
"""Perform a training step."""
|
| 216 |
+
return self._step(prefix="train", batch_idx=batch_idx, batch=batch)
|
| 217 |
+
|
| 218 |
+
def validation_step(self, batch, batch_idx):
|
| 219 |
+
"""Perform a validation step."""
|
| 220 |
+
return self._step(prefix="val", batch_idx=batch_idx, batch=batch)
|
| 221 |
+
|
| 222 |
+
def test_step(self, batch, batch_idx):
|
| 223 |
+
"""Perform a test step."""
|
| 224 |
+
return self._step(prefix="test", batch_idx=batch_idx, batch=batch)
|
| 225 |
+
|
| 226 |
+
def configure_optimizers(self):
|
| 227 |
+
"""Configure optimizers and learning rate schedulers."""
|
| 228 |
+
# Only include parameters that require gradients
|
| 229 |
+
classifier_params = [
|
| 230 |
+
{
|
| 231 |
+
"params": self.dna_model.classifier.parameters(),
|
| 232 |
+
"lr": self.hparams.learning_rate,
|
| 233 |
+
},
|
| 234 |
+
{
|
| 235 |
+
"params": self.dna_model.pooler.parameters(),
|
| 236 |
+
"lr": self.hparams.learning_rate,
|
| 237 |
+
}
|
| 238 |
+
]
|
| 239 |
+
dna_model_params = [
|
| 240 |
+
{
|
| 241 |
+
"params": self.dna_model_params,
|
| 242 |
+
"lr": self.hparams.learning_rate * 0.1,
|
| 243 |
+
},
|
| 244 |
+
]
|
| 245 |
+
|
| 246 |
+
if self.hparams.train_just_classifier:
|
| 247 |
+
# Only train classifier parameters
|
| 248 |
+
optimizer = AdamW(
|
| 249 |
+
classifier_params,
|
| 250 |
+
weight_decay=self.hparams.weight_decay,
|
| 251 |
+
)
|
| 252 |
+
else:
|
| 253 |
+
# Train both DNA model and classifier with different learning rates
|
| 254 |
+
optimizer = AdamW(
|
| 255 |
+
classifier_params + dna_model_params,
|
| 256 |
+
weight_decay=self.hparams.weight_decay,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# Get total steps from trainer's estimated stepping batches
|
| 260 |
+
total_steps = self.trainer.estimated_stepping_batches
|
| 261 |
+
warmup_steps = int(0.1 * total_steps)
|
| 262 |
+
|
| 263 |
+
# Create scheduler
|
| 264 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 265 |
+
optimizer,
|
| 266 |
+
num_warmup_steps=warmup_steps,
|
| 267 |
+
num_training_steps=total_steps,
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
| 271 |
+
|
| 272 |
+
def load_dataset(self):
|
| 273 |
+
"""Load the dataset based on the dataset type."""
|
| 274 |
+
if self.hparams.dataset_type == "kegg":
|
| 275 |
+
dataset = load_dataset(self.hparams.kegg_data_dir_huggingface)
|
| 276 |
+
|
| 277 |
+
if self.hparams.truncate_dna_per_side:
|
| 278 |
+
dataset = dataset.map(
|
| 279 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
labels = []
|
| 283 |
+
for split, data in dataset.items():
|
| 284 |
+
labels.extend(data["answer"])
|
| 285 |
+
labels = list(set(labels))
|
| 286 |
+
|
| 287 |
+
elif self.hparams.dataset_type == "variant_effect_coding":
|
| 288 |
+
dataset = load_dataset("wanglab/bioR_tasks", "variant_effect_coding")
|
| 289 |
+
dataset = dataset.map(clean_variant_effect_example)
|
| 290 |
+
|
| 291 |
+
if self.hparams.truncate_dna_per_side:
|
| 292 |
+
dataset = dataset.map(
|
| 293 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
labels = []
|
| 297 |
+
for split, data in dataset.items():
|
| 298 |
+
labels.extend(data["answer"])
|
| 299 |
+
labels = sorted(list(set(labels)))
|
| 300 |
+
|
| 301 |
+
elif self.hparams.dataset_type == "variant_effect_non_snv":
|
| 302 |
+
dataset = load_dataset("wanglab/bioR_tasks", "task5_variant_effect_non_snv")
|
| 303 |
+
dataset = dataset.rename_column("mutated_sequence", "variant_sequence")
|
| 304 |
+
dataset = dataset.map(clean_variant_effect_example)
|
| 305 |
+
|
| 306 |
+
if self.hparams.truncate_dna_per_side:
|
| 307 |
+
dataset = dataset.map(
|
| 308 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
labels = []
|
| 312 |
+
for split, data in dataset.items():
|
| 313 |
+
labels.extend(data["answer"])
|
| 314 |
+
labels = sorted(list(set(labels)))
|
| 315 |
+
|
| 316 |
+
else:
|
| 317 |
+
raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
|
| 318 |
+
|
| 319 |
+
print(f"Dataset:\n{dataset}\nLabels:\n{labels}\nNumber of labels:{len(labels)}")
|
| 320 |
+
return dataset, labels
|
| 321 |
+
|
| 322 |
+
def train_dataloader(self):
|
| 323 |
+
"""Create and return the training DataLoader."""
|
| 324 |
+
if self.hparams.dataset_type == "kegg":
|
| 325 |
+
train_dataset = self.dataset["train"]
|
| 326 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 327 |
+
|
| 328 |
+
elif self.hparams.dataset_type == "variant_effect_coding":
|
| 329 |
+
train_dataset = self.dataset["train"]
|
| 330 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 331 |
+
|
| 332 |
+
elif self.hparams.dataset_type == "variant_effect_non_snv":
|
| 333 |
+
train_dataset = self.dataset["train"]
|
| 334 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 335 |
+
|
| 336 |
+
else:
|
| 337 |
+
raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
|
| 338 |
+
|
| 339 |
+
return DataLoader(
|
| 340 |
+
train_dataset,
|
| 341 |
+
batch_size=self.hparams.batch_size,
|
| 342 |
+
shuffle=True,
|
| 343 |
+
collate_fn=collate_fn,
|
| 344 |
+
num_workers=self.hparams.num_workers,
|
| 345 |
+
persistent_workers=True,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
def val_dataloader(self):
|
| 349 |
+
"""Create and return the training DataLoader."""
|
| 350 |
+
if self.hparams.dataset_type == "kegg":
|
| 351 |
+
|
| 352 |
+
if self.hparams.merge_val_test_set:
|
| 353 |
+
val_dataset = concatenate_datasets([self.dataset['test'], self.dataset['val']])
|
| 354 |
+
else:
|
| 355 |
+
val_dataset = self.dataset["val"]
|
| 356 |
+
|
| 357 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 358 |
+
|
| 359 |
+
elif self.hparams.dataset_type == "variant_effect_coding":
|
| 360 |
+
val_dataset = self.dataset["test"]
|
| 361 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 362 |
+
|
| 363 |
+
elif self.hparams.dataset_type == "variant_effect_non_snv":
|
| 364 |
+
val_dataset = self.dataset["test"]
|
| 365 |
+
collate_fn = lambda b: dna_collate_fn(b, dna_tokenizer=self.dna_tokenizer, label2id=self.label2id, max_length=self.hparams.max_length_dna)
|
| 366 |
+
|
| 367 |
+
else:
|
| 368 |
+
raise ValueError(f"Invalid dataset type: {self.hparams.dataset_type}")
|
| 369 |
+
|
| 370 |
+
return DataLoader(
|
| 371 |
+
val_dataset,
|
| 372 |
+
batch_size=self.hparams.batch_size,
|
| 373 |
+
shuffle=False,
|
| 374 |
+
collate_fn=collate_fn,
|
| 375 |
+
num_workers=self.hparams.num_workers,
|
| 376 |
+
persistent_workers=True,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
def test_dataloader(self):
|
| 380 |
+
"""Create and return the test DataLoader."""
|
| 381 |
+
return self.val_dataloader()
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def main(args):
|
| 385 |
+
"""Main function to run the training process."""
|
| 386 |
+
# Set random seed and environment variables
|
| 387 |
+
pl.seed_everything(args.seed)
|
| 388 |
+
torch.cuda.empty_cache()
|
| 389 |
+
torch.set_float32_matmul_precision("medium")
|
| 390 |
+
|
| 391 |
+
# Initialize model
|
| 392 |
+
model = DNAClassifierModelTrainer(args)
|
| 393 |
+
|
| 394 |
+
# Setup directories
|
| 395 |
+
run_name = f"{args.wandb_project}-{args.dataset_type}-{args.dna_model_name.split('/')[-1]}"
|
| 396 |
+
args.checkpoint_dir = f"{args.checkpoint_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}"
|
| 397 |
+
args.output_dir = f"{args.output_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}"
|
| 398 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 399 |
+
os.makedirs(args.checkpoint_dir, exist_ok=True)
|
| 400 |
+
|
| 401 |
+
# Setup callbacks
|
| 402 |
+
callbacks = [
|
| 403 |
+
ModelCheckpoint(
|
| 404 |
+
dirpath=args.checkpoint_dir,
|
| 405 |
+
filename=f"{run_name}-" + "{epoch:02d}-{val_loss_epoch:.4f}",
|
| 406 |
+
save_top_k=2,
|
| 407 |
+
monitor="val_acc_epoch",
|
| 408 |
+
mode="max",
|
| 409 |
+
save_last=True,
|
| 410 |
+
),
|
| 411 |
+
LearningRateMonitor(logging_interval="step"),
|
| 412 |
+
]
|
| 413 |
+
|
| 414 |
+
# Setup logger
|
| 415 |
+
is_resuming = args.ckpt_path is not None
|
| 416 |
+
logger = WandbLogger(
|
| 417 |
+
project=args.wandb_project,
|
| 418 |
+
entity=args.wandb_entity,
|
| 419 |
+
save_dir=args.log_dir,
|
| 420 |
+
name=run_name,
|
| 421 |
+
resume="allow" if is_resuming else None, # Allow resuming existing run
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Initialize trainer
|
| 425 |
+
trainer = pl.Trainer(
|
| 426 |
+
max_epochs=args.max_epochs,
|
| 427 |
+
accelerator="gpu",
|
| 428 |
+
devices=args.num_gpus,
|
| 429 |
+
strategy=(
|
| 430 |
+
"ddp"
|
| 431 |
+
if args.strategy == "ddp"
|
| 432 |
+
else DeepSpeedStrategy(stage=2, offload_optimizer=False, allgather_bucket_size=5e8, reduce_bucket_size=5e8)
|
| 433 |
+
),
|
| 434 |
+
precision="bf16-mixed",
|
| 435 |
+
callbacks=callbacks,
|
| 436 |
+
logger=logger,
|
| 437 |
+
deterministic=False,
|
| 438 |
+
enable_checkpointing=True,
|
| 439 |
+
enable_progress_bar=True,
|
| 440 |
+
enable_model_summary=True,
|
| 441 |
+
log_every_n_steps=5,
|
| 442 |
+
accumulate_grad_batches=args.gradient_accumulation_steps,
|
| 443 |
+
gradient_clip_val=1.0,
|
| 444 |
+
val_check_interval=1 / 3,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Train model
|
| 448 |
+
trainer.fit(model, ckpt_path=args.ckpt_path)
|
| 449 |
+
trainer.test(model, ckpt_path=args.ckpt_path if args.ckpt_path else "best")
|
| 450 |
+
|
| 451 |
+
# Save final model
|
| 452 |
+
final_model_path = os.path.join(args.output_dir, "final_model")
|
| 453 |
+
torch.save(model.dna_model.state_dict(), final_model_path)
|
| 454 |
+
print(f"Final model saved to {final_model_path}")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
if __name__ == "__main__":
|
| 458 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 459 |
+
parser = argparse.ArgumentParser(description="Train DNA Classifier")
|
| 460 |
+
|
| 461 |
+
# Model parameters
|
| 462 |
+
parser.add_argument(
|
| 463 |
+
"--dna_model_name",
|
| 464 |
+
type=str,
|
| 465 |
+
default="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species",
|
| 466 |
+
)
|
| 467 |
+
parser.add_argument("--cache_dir", type=str, default="/model-weights")
|
| 468 |
+
parser.add_argument("--max_length_dna", type=int, default=1024)
|
| 469 |
+
parser.add_argument("--dna_is_evo2", type=bool, default=False)
|
| 470 |
+
parser.add_argument("--dna_embedding_layer", type=str, default=None)
|
| 471 |
+
|
| 472 |
+
# Training parameters
|
| 473 |
+
parser.add_argument("--strategy", type=str, default="ddp")
|
| 474 |
+
parser.add_argument("--batch_size", type=int, default=8)
|
| 475 |
+
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
| 476 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
| 477 |
+
parser.add_argument("--max_epochs", type=int, default=5)
|
| 478 |
+
parser.add_argument("--max_steps", type=int, default=-1)
|
| 479 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
|
| 480 |
+
parser.add_argument("--num_workers", type=int, default=4)
|
| 481 |
+
parser.add_argument("--num_gpus", type=int, default=1)
|
| 482 |
+
parser.add_argument("--train_just_classifier", type=bool, default=True)
|
| 483 |
+
parser.add_argument("--dataset_type", type=str, choices=["kegg", "variant_effect_coding", "variant_effect_non_snv"], default="kegg")
|
| 484 |
+
parser.add_argument("--kegg_data_dir_huggingface", type=str, default="wanglab/kegg")
|
| 485 |
+
parser.add_argument("--truncate_dna_per_side", type=int, default=0)
|
| 486 |
+
|
| 487 |
+
# Output parameters
|
| 488 |
+
parser.add_argument("--output_dir", type=str, default="dna_classifier_output")
|
| 489 |
+
parser.add_argument(
|
| 490 |
+
"--checkpoint_dir", type=str, default="checkpoints"
|
| 491 |
+
)
|
| 492 |
+
parser.add_argument("--ckpt_path", type=str, default=None)
|
| 493 |
+
parser.add_argument("--log_dir", type=str, default="logs")
|
| 494 |
+
parser.add_argument("--wandb_project", type=str, default="dna-only-nt-500m")
|
| 495 |
+
parser.add_argument("--wandb_entity", type=str, default="adibvafa")
|
| 496 |
+
parser.add_argument("--merge_val_test_set", type=bool, default=True)
|
| 497 |
+
|
| 498 |
+
# Other parameters
|
| 499 |
+
parser.add_argument("--seed", type=int, default=23)
|
| 500 |
+
|
| 501 |
+
args = parser.parse_args()
|
| 502 |
+
main(args)
|
BioReason/train_dna_qwen.py
ADDED
|
@@ -0,0 +1,1064 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import gc
|
| 3 |
+
import io
|
| 4 |
+
import multiprocessing
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import traceback
|
| 8 |
+
from argparse import ArgumentParser
|
| 9 |
+
from functools import partial
|
| 10 |
+
from typing import *
|
| 11 |
+
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
import wandb
|
| 15 |
+
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
| 16 |
+
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
|
| 17 |
+
from torch.optim import AdamW
|
| 18 |
+
from torch.utils.data import DataLoader
|
| 19 |
+
from transformers import AutoTokenizer, get_cosine_schedule_with_warmup
|
| 20 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 21 |
+
|
| 22 |
+
import pytorch_lightning as pl
|
| 23 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
| 24 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 25 |
+
from pytorch_lightning.strategies import DeepSpeedStrategy
|
| 26 |
+
|
| 27 |
+
from bioreason.dataset.kegg import get_format_kegg_function, qwen_dna_collate_fn
|
| 28 |
+
from bioreason.dataset.utils import truncate_dna
|
| 29 |
+
from bioreason.dataset.variant_effect import (
|
| 30 |
+
clean_variant_effect_example,
|
| 31 |
+
clean_variant_effect_non_snv_example,
|
| 32 |
+
get_format_variant_effect_function,
|
| 33 |
+
)
|
| 34 |
+
from bioreason.models.dl.processing_dl import DLProcessor
|
| 35 |
+
from bioreason.models.dna_llm import DNALLMModel
|
| 36 |
+
from bioreason.models.evo2_tokenizer import register_evo2_tokenizer
|
| 37 |
+
|
| 38 |
+
register_evo2_tokenizer()
|
| 39 |
+
|
| 40 |
+
# Set start method to 'spawn' for CUDA compatibility with multiprocessing
|
| 41 |
+
torch.multiprocessing.set_sharing_strategy("file_system")
|
| 42 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DNALLMFineTuner(pl.LightningModule):
|
| 46 |
+
"""
|
| 47 |
+
PyTorch Lightning module for fine-tuning DNA-LLM models.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(self, hparams):
|
| 51 |
+
"""
|
| 52 |
+
Initialize the DNALLMFineTuner.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
hparams: Hyperparameters for the model and training
|
| 56 |
+
"""
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.save_hyperparameters(hparams)
|
| 59 |
+
|
| 60 |
+
self.text_model_name = self.hparams.text_model_name
|
| 61 |
+
self.dna_model_name = self.hparams.dna_model_name
|
| 62 |
+
self.cache_dir = self.hparams.cache_dir
|
| 63 |
+
self.learning_rate = self.hparams.learning_rate
|
| 64 |
+
self.weight_decay = self.hparams.weight_decay
|
| 65 |
+
self.text_model_finetune = self.hparams.text_model_finetune
|
| 66 |
+
self.dna_model_finetune = self.hparams.dna_model_finetune
|
| 67 |
+
self.lora_rank = self.hparams.lora_rank
|
| 68 |
+
self.lora_alpha = self.hparams.lora_alpha
|
| 69 |
+
self.lora_dropout = self.hparams.lora_dropout
|
| 70 |
+
self.max_length_dna = self.hparams.max_length_dna
|
| 71 |
+
self.max_length_text = self.hparams.max_length_text
|
| 72 |
+
self.dna_is_evo2 = self.hparams.dna_is_evo2
|
| 73 |
+
self.dna_embedding_layer = self.hparams.dna_embedding_layer
|
| 74 |
+
self.return_answer_in_batch = self.hparams.return_answer_in_batch
|
| 75 |
+
self.merge_val_test_set = self.hparams.merge_val_test_set
|
| 76 |
+
|
| 77 |
+
# Store dataset configuration
|
| 78 |
+
self.dataset_type = self.hparams.dataset_type
|
| 79 |
+
|
| 80 |
+
# Load model
|
| 81 |
+
self.model = DNALLMModel(
|
| 82 |
+
text_model_name=self.text_model_name,
|
| 83 |
+
dna_model_name=self.dna_model_name,
|
| 84 |
+
cache_dir=self.cache_dir,
|
| 85 |
+
max_length_dna=self.max_length_dna,
|
| 86 |
+
max_length_text=self.max_length_text,
|
| 87 |
+
text_model_finetune=self.text_model_finetune,
|
| 88 |
+
dna_model_finetune=self.dna_model_finetune,
|
| 89 |
+
dna_is_evo2=self.dna_is_evo2,
|
| 90 |
+
dna_embedding_layer=self.dna_embedding_layer,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
self.text_model = self.model.text_model
|
| 94 |
+
self.dna_model = self.model.dna_model
|
| 95 |
+
self.dna_projection = self.model.dna_projection
|
| 96 |
+
|
| 97 |
+
# Load tokenizer for target text
|
| 98 |
+
self.tokenizer = self.model.text_tokenizer
|
| 99 |
+
|
| 100 |
+
# Prepare model for training
|
| 101 |
+
self.lora_config = self._prep_for_training()
|
| 102 |
+
|
| 103 |
+
def _get_target_modules(self):
|
| 104 |
+
# Apply LoRA to all linear layers in the text model
|
| 105 |
+
target_modules = []
|
| 106 |
+
|
| 107 |
+
# Get all unique linear layer names
|
| 108 |
+
seen_names = set()
|
| 109 |
+
for name, module in self.text_model.named_modules():
|
| 110 |
+
if isinstance(module, torch.nn.Linear):
|
| 111 |
+
names = name.split(".")
|
| 112 |
+
target_name = names[-1] # Use the last part of the name
|
| 113 |
+
|
| 114 |
+
# Skip output head but include all other linear layers
|
| 115 |
+
if target_name != "lm_head" and target_name not in seen_names:
|
| 116 |
+
target_modules.append(target_name)
|
| 117 |
+
seen_names.add(target_name)
|
| 118 |
+
|
| 119 |
+
# Add attention-specific layers
|
| 120 |
+
attention_patterns = [
|
| 121 |
+
"q_proj",
|
| 122 |
+
"k_proj",
|
| 123 |
+
"v_proj",
|
| 124 |
+
"out_proj",
|
| 125 |
+
"query",
|
| 126 |
+
"key",
|
| 127 |
+
"value",
|
| 128 |
+
]
|
| 129 |
+
for pattern in attention_patterns:
|
| 130 |
+
if pattern not in seen_names:
|
| 131 |
+
target_modules.append(pattern)
|
| 132 |
+
|
| 133 |
+
# Return all unique layer names to apply LoRA to all layers
|
| 134 |
+
return list(target_modules)
|
| 135 |
+
|
| 136 |
+
def _prep_for_training(self) -> LoraConfig:
|
| 137 |
+
"""
|
| 138 |
+
Load and configure the DNALLMModel.
|
| 139 |
+
"""
|
| 140 |
+
|
| 141 |
+
# Freeze DNA encoder parameters
|
| 142 |
+
if self.dna_model_finetune:
|
| 143 |
+
pass
|
| 144 |
+
else:
|
| 145 |
+
if self.dna_is_evo2:
|
| 146 |
+
for param in self.dna_model.model.parameters():
|
| 147 |
+
param.requires_grad = False
|
| 148 |
+
else:
|
| 149 |
+
for param in self.dna_model.parameters():
|
| 150 |
+
param.requires_grad = False
|
| 151 |
+
|
| 152 |
+
if self.text_model_finetune:
|
| 153 |
+
target_modules = self._get_target_modules()
|
| 154 |
+
|
| 155 |
+
lora_config = LoraConfig(
|
| 156 |
+
r=self.lora_rank,
|
| 157 |
+
lora_alpha=self.lora_alpha,
|
| 158 |
+
lora_dropout=self.lora_dropout,
|
| 159 |
+
target_modules=target_modules,
|
| 160 |
+
init_lora_weights="gaussian",
|
| 161 |
+
bias="none",
|
| 162 |
+
task_type="CAUSAL_LM",
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Prepare text model for training
|
| 166 |
+
self.text_model = prepare_model_for_kbit_training(self.text_model)
|
| 167 |
+
self.text_model = get_peft_model(self.text_model, lora_config)
|
| 168 |
+
else:
|
| 169 |
+
# Freeze text model parameters
|
| 170 |
+
for param in self.text_model.parameters():
|
| 171 |
+
param.requires_grad = False
|
| 172 |
+
|
| 173 |
+
# Make projection layer trainable
|
| 174 |
+
for param in self.dna_projection.parameters():
|
| 175 |
+
param.requires_grad = True
|
| 176 |
+
|
| 177 |
+
return lora_config
|
| 178 |
+
|
| 179 |
+
def _step(self, batch: Dict, batch_idx: int, prefix: str) -> torch.Tensor:
|
| 180 |
+
"""
|
| 181 |
+
Performs a single step for training, validation, or testing.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
batch: Dictionary containing the batch data
|
| 185 |
+
batch_idx: Integer indicating the batch index
|
| 186 |
+
prefix: String indicating the step type ('train', 'val', or 'test')
|
| 187 |
+
|
| 188 |
+
Returns:
|
| 189 |
+
torch.Tensor: The computed loss for this batch
|
| 190 |
+
"""
|
| 191 |
+
if prefix == "test":
|
| 192 |
+
return {"loss": torch.tensor(0.0, device=self.device)}
|
| 193 |
+
|
| 194 |
+
# Get batch data from the collate function
|
| 195 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 196 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 197 |
+
labels = batch["labels"].to(self.device) if "labels" in batch else None
|
| 198 |
+
dna_tokenized = batch.get("dna_tokenized")
|
| 199 |
+
if dna_tokenized is not None:
|
| 200 |
+
dna_tokenized = dna_tokenized.to(self.device)
|
| 201 |
+
batch_idx_map = batch.get("batch_idx_map")
|
| 202 |
+
|
| 203 |
+
# Forward pass through the model
|
| 204 |
+
outputs = self.model(
|
| 205 |
+
input_ids=input_ids,
|
| 206 |
+
attention_mask=attention_mask,
|
| 207 |
+
dna_tokenized=dna_tokenized,
|
| 208 |
+
batch_idx_map=batch_idx_map,
|
| 209 |
+
labels=labels,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# Get the loss from model outputs
|
| 213 |
+
loss = outputs.loss
|
| 214 |
+
|
| 215 |
+
# Occasionally show generations for debugging purposes - ONLY during training/validation
|
| 216 |
+
# You can reduce the frequency of generations by increasing the step size to make the model train faster
|
| 217 |
+
if (prefix == "train" and (self.global_step % 3000 == 0)) or (prefix == "val" and (batch_idx % 300 == 0)):
|
| 218 |
+
try:
|
| 219 |
+
# Select first example from batch for demonstration
|
| 220 |
+
example_idx = 0
|
| 221 |
+
|
| 222 |
+
print(
|
| 223 |
+
f"\n=== Sample Generation (step {self.global_step} / {self.trainer.estimated_stepping_batches}) ==="
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# Get the tokens that define the assistant pattern
|
| 227 |
+
assistant_start_marker = "<|im_start|>assistant\n"
|
| 228 |
+
assistant_marker_tokens = self.tokenizer.encode(assistant_start_marker, add_special_tokens=False)
|
| 229 |
+
marker_tensor = torch.tensor(assistant_marker_tokens, device=input_ids.device)
|
| 230 |
+
marker_len = len(assistant_marker_tokens)
|
| 231 |
+
|
| 232 |
+
# Find non-padding tokens in input
|
| 233 |
+
non_pad = (input_ids[example_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
|
| 234 |
+
if len(non_pad) > 0:
|
| 235 |
+
start_idx = non_pad[0].item() # First non-padding token
|
| 236 |
+
else:
|
| 237 |
+
start_idx = 0
|
| 238 |
+
|
| 239 |
+
# For each position, check if the next marker_len tokens match the pattern
|
| 240 |
+
matches = []
|
| 241 |
+
for pos in range(start_idx, input_ids.size(1) - marker_len + 1):
|
| 242 |
+
if torch.all(input_ids[example_idx, pos : pos + marker_len] == marker_tensor):
|
| 243 |
+
matches.append(pos)
|
| 244 |
+
break # Stop at first match
|
| 245 |
+
|
| 246 |
+
assistant_pos = matches[0] if matches else None
|
| 247 |
+
|
| 248 |
+
if assistant_pos is not None:
|
| 249 |
+
# Get input up to and including the assistant marker
|
| 250 |
+
gen_input_ids = input_ids[
|
| 251 |
+
example_idx : example_idx + 1, start_idx : assistant_pos + marker_len
|
| 252 |
+
]
|
| 253 |
+
gen_attention_mask = attention_mask[
|
| 254 |
+
example_idx : example_idx + 1, start_idx : assistant_pos + marker_len
|
| 255 |
+
]
|
| 256 |
+
|
| 257 |
+
# Extract DNA data for this example
|
| 258 |
+
example_dna_data = None
|
| 259 |
+
example_batch_map = None
|
| 260 |
+
|
| 261 |
+
if dna_tokenized is not None and batch_idx_map is not None:
|
| 262 |
+
# Find DNA sequences for this example
|
| 263 |
+
example_indices = [i for i, idx in enumerate(batch_idx_map) if idx == example_idx]
|
| 264 |
+
|
| 265 |
+
if len(example_indices) > 0:
|
| 266 |
+
# Extract just this example's DNA data
|
| 267 |
+
example_dna_data = BatchEncoding(
|
| 268 |
+
{
|
| 269 |
+
"input_ids": dna_tokenized.input_ids[example_indices].to(self.device),
|
| 270 |
+
"attention_mask": dna_tokenized.attention_mask[example_indices].to(self.device),
|
| 271 |
+
}
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# For generation we need all sequences mapped to index 0
|
| 275 |
+
example_batch_map = [0] * len(example_indices)
|
| 276 |
+
|
| 277 |
+
# Generate text
|
| 278 |
+
with torch.no_grad():
|
| 279 |
+
generated = self.model.generate(
|
| 280 |
+
input_ids=gen_input_ids,
|
| 281 |
+
attention_mask=gen_attention_mask,
|
| 282 |
+
dna_tokenized=example_dna_data,
|
| 283 |
+
batch_idx_map=example_batch_map,
|
| 284 |
+
max_new_tokens=800,
|
| 285 |
+
temperature=0.6,
|
| 286 |
+
top_p=0.95,
|
| 287 |
+
top_k=20,
|
| 288 |
+
do_sample=True,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Decode and display
|
| 292 |
+
user_input = self.tokenizer.decode(gen_input_ids[0], skip_special_tokens=False).strip()
|
| 293 |
+
generation = self.tokenizer.decode(generated[0], skip_special_tokens=False).strip()
|
| 294 |
+
|
| 295 |
+
# Free memory early
|
| 296 |
+
del generated, gen_input_ids, gen_attention_mask, example_dna_data, example_batch_map
|
| 297 |
+
gc.collect()
|
| 298 |
+
|
| 299 |
+
print(f"=====[Sample {prefix} {batch_idx}]=====")
|
| 300 |
+
print(f"=====[User input]=====\n{user_input}")
|
| 301 |
+
print(f"=====[Complete generation]=====\n{generation}")
|
| 302 |
+
|
| 303 |
+
# Get ground truth if available
|
| 304 |
+
ground_truth = ""
|
| 305 |
+
if labels is not None:
|
| 306 |
+
# Find all positions where we have valid labels (not -100)
|
| 307 |
+
valid_label_pos = (labels[example_idx] != -100).nonzero(as_tuple=True)[0]
|
| 308 |
+
|
| 309 |
+
if len(valid_label_pos) > 0:
|
| 310 |
+
# Check if valid labels start after assistant marker
|
| 311 |
+
if valid_label_pos[0] >= assistant_pos + marker_len:
|
| 312 |
+
ground_truth = self.tokenizer.decode(
|
| 313 |
+
input_ids[example_idx, valid_label_pos], skip_special_tokens=False
|
| 314 |
+
).strip()
|
| 315 |
+
print(f"=====[Ground truth]=====\n{ground_truth}")
|
| 316 |
+
|
| 317 |
+
# Log to wandb
|
| 318 |
+
timestamp = time.time()
|
| 319 |
+
step_id = f"gen_{self.global_step}-{timestamp}"
|
| 320 |
+
wandb_logger = self.logger.experiment
|
| 321 |
+
wandb_logger.log(
|
| 322 |
+
{
|
| 323 |
+
step_id: wandb.Table(
|
| 324 |
+
columns=["timestamp", "prefix", "batch_idx", "user_input", "generation", "ground_truth"],
|
| 325 |
+
data=[[timestamp, prefix, batch_idx, user_input, generation, ground_truth]],
|
| 326 |
+
)
|
| 327 |
+
}
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
# Clean up memory
|
| 331 |
+
del user_input, generation, ground_truth
|
| 332 |
+
torch.cuda.empty_cache()
|
| 333 |
+
gc.collect()
|
| 334 |
+
|
| 335 |
+
else:
|
| 336 |
+
print("No assistant marker found in the input sequence")
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
print(f"Error during sample generation: {str(e)}")
|
| 340 |
+
traceback.print_exc()
|
| 341 |
+
|
| 342 |
+
# Get current learning rate (skip during test as scheduler might not be available)
|
| 343 |
+
if prefix != "test":
|
| 344 |
+
current_lr = self.lr_schedulers().get_last_lr()[0]
|
| 345 |
+
else:
|
| 346 |
+
current_lr = 0
|
| 347 |
+
|
| 348 |
+
# Logging metrics
|
| 349 |
+
self.log(
|
| 350 |
+
f"{prefix}_loss",
|
| 351 |
+
loss,
|
| 352 |
+
on_step=True,
|
| 353 |
+
on_epoch=False,
|
| 354 |
+
prog_bar=True,
|
| 355 |
+
logger=True,
|
| 356 |
+
)
|
| 357 |
+
self.log(
|
| 358 |
+
f"{prefix}_loss_epoch",
|
| 359 |
+
loss,
|
| 360 |
+
on_step=False,
|
| 361 |
+
on_epoch=True,
|
| 362 |
+
prog_bar=True,
|
| 363 |
+
logger=True,
|
| 364 |
+
sync_dist=True,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
# Only log learning rate during training/validation
|
| 368 |
+
if prefix != "test":
|
| 369 |
+
self.log(
|
| 370 |
+
"lr",
|
| 371 |
+
current_lr,
|
| 372 |
+
on_step=True,
|
| 373 |
+
on_epoch=True,
|
| 374 |
+
prog_bar=True,
|
| 375 |
+
logger=True,
|
| 376 |
+
sync_dist=True,
|
| 377 |
+
)
|
| 378 |
+
|
| 379 |
+
return loss
|
| 380 |
+
|
| 381 |
+
def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
|
| 382 |
+
"""Perform a single training step."""
|
| 383 |
+
return self._step(batch, batch_idx, prefix="train")
|
| 384 |
+
|
| 385 |
+
def validation_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
|
| 386 |
+
"""Perform a single validation step."""
|
| 387 |
+
return self._step(batch, batch_idx, prefix="val")
|
| 388 |
+
|
| 389 |
+
def test_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
|
| 390 |
+
"""Perform a single test step."""
|
| 391 |
+
return self._step(batch, batch_idx, prefix="test")
|
| 392 |
+
|
| 393 |
+
def configure_optimizers(self):
|
| 394 |
+
"""
|
| 395 |
+
Configure optimizers and learning rate schedulers.
|
| 396 |
+
|
| 397 |
+
Returns:
|
| 398 |
+
Tuple[List, List]: A tuple containing a list of optimizers and schedulers
|
| 399 |
+
"""
|
| 400 |
+
optimizer = AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
|
| 401 |
+
|
| 402 |
+
total_steps = self.trainer.estimated_stepping_batches
|
| 403 |
+
warmup_steps = int(0.1 * total_steps)
|
| 404 |
+
|
| 405 |
+
scheduler = get_cosine_schedule_with_warmup(
|
| 406 |
+
optimizer,
|
| 407 |
+
num_warmup_steps=warmup_steps,
|
| 408 |
+
num_training_steps=total_steps,
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
|
| 412 |
+
|
| 413 |
+
def train_dataloader(self) -> DataLoader:
|
| 414 |
+
"""Create and return the training DataLoader."""
|
| 415 |
+
# Load dataset based on type specified in hyperparameters
|
| 416 |
+
|
| 417 |
+
if self.hparams.dataset_type == "kegg":
|
| 418 |
+
# Use Hugging Face dataset if provided
|
| 419 |
+
dataset = load_dataset(self.hparams.kegg_data_dir_huggingface)
|
| 420 |
+
dataset = dataset.map(get_format_kegg_function(self.hparams.model_type))
|
| 421 |
+
|
| 422 |
+
labels = []
|
| 423 |
+
for split, data in dataset.items():
|
| 424 |
+
labels.extend(data["answer"])
|
| 425 |
+
self.labels = sorted(list(set(labels)))
|
| 426 |
+
|
| 427 |
+
train_dataset = dataset["train"]
|
| 428 |
+
|
| 429 |
+
if self.hparams.truncate_dna_per_side:
|
| 430 |
+
train_dataset = train_dataset.map(
|
| 431 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
processor = DLProcessor(
|
| 435 |
+
tokenizer=self.model.text_tokenizer,
|
| 436 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
# Create partial function with all required arguments except the batch
|
| 440 |
+
collate_fn = partial(
|
| 441 |
+
qwen_dna_collate_fn,
|
| 442 |
+
processor=processor,
|
| 443 |
+
max_length_text=self.max_length_text,
|
| 444 |
+
max_length_dna=self.max_length_dna,
|
| 445 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
elif self.hparams.dataset_type == "variant_effect_coding":
|
| 450 |
+
dataset = load_dataset(self.hparams.variant_effect_coding_data_dir_huggingface)
|
| 451 |
+
cleaned_dataset = dataset.map(clean_variant_effect_example)
|
| 452 |
+
dataset = dataset.map(get_format_variant_effect_function(self.hparams.model_type))
|
| 453 |
+
|
| 454 |
+
labels = []
|
| 455 |
+
for split, data in cleaned_dataset.items():
|
| 456 |
+
labels.extend(data["answer"])
|
| 457 |
+
self.labels = sorted(list(set(labels)))
|
| 458 |
+
|
| 459 |
+
train_dataset = dataset["train"]
|
| 460 |
+
|
| 461 |
+
if self.hparams.truncate_dna_per_side:
|
| 462 |
+
train_dataset = train_dataset.map(
|
| 463 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
processor = DLProcessor(
|
| 467 |
+
tokenizer=self.model.text_tokenizer,
|
| 468 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
# Create partial function with all required arguments except the batch
|
| 472 |
+
collate_fn = partial(
|
| 473 |
+
qwen_dna_collate_fn,
|
| 474 |
+
processor=processor,
|
| 475 |
+
max_length_text=self.max_length_text,
|
| 476 |
+
max_length_dna=self.max_length_dna,
|
| 477 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
elif self.hparams.dataset_type == "variant_effect_non_snv":
|
| 481 |
+
dataset = load_dataset(self.hparams.variant_effect_non_snv_data_dir_huggingface)
|
| 482 |
+
dataset = dataset.map(clean_variant_effect_non_snv_example)
|
| 483 |
+
cleaned_dataset = dataset.map(clean_variant_effect_example)
|
| 484 |
+
dataset = dataset.rename_column("mutated_sequence", "variant_sequence")
|
| 485 |
+
|
| 486 |
+
labels = []
|
| 487 |
+
for split, data in cleaned_dataset.items():
|
| 488 |
+
labels.extend(data["answer"])
|
| 489 |
+
self.labels = sorted(list(set(labels)))
|
| 490 |
+
|
| 491 |
+
train_dataset = dataset["train"]
|
| 492 |
+
|
| 493 |
+
if self.hparams.truncate_dna_per_side:
|
| 494 |
+
train_dataset = train_dataset.map(
|
| 495 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 496 |
+
)
|
| 497 |
+
train_dataset = train_dataset.map(get_format_variant_effect_function(self.hparams.model_type))
|
| 498 |
+
|
| 499 |
+
processor = DLProcessor(
|
| 500 |
+
tokenizer=self.model.text_tokenizer,
|
| 501 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# Create partial function with all required arguments except the batch
|
| 505 |
+
collate_fn = partial(
|
| 506 |
+
qwen_dna_collate_fn,
|
| 507 |
+
processor=processor,
|
| 508 |
+
max_length_text=self.max_length_text,
|
| 509 |
+
max_length_dna=self.max_length_dna,
|
| 510 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 511 |
+
)
|
| 512 |
+
|
| 513 |
+
else:
|
| 514 |
+
raise ValueError(f"Unknown dataset type: {self.hparams.dataset_type}")
|
| 515 |
+
|
| 516 |
+
return DataLoader(
|
| 517 |
+
train_dataset,
|
| 518 |
+
batch_size=self.hparams.batch_size,
|
| 519 |
+
shuffle=True,
|
| 520 |
+
collate_fn=collate_fn,
|
| 521 |
+
num_workers=self.hparams.num_workers,
|
| 522 |
+
persistent_workers=False,
|
| 523 |
+
pin_memory=False,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
def val_dataloader(self) -> DataLoader:
|
| 527 |
+
"""Create and return the validation DataLoader."""
|
| 528 |
+
|
| 529 |
+
if self.hparams.dataset_type == "kegg":
|
| 530 |
+
# Use Hugging Face dataset
|
| 531 |
+
dataset = load_dataset(self.hparams.kegg_data_dir_huggingface)
|
| 532 |
+
dataset = dataset.map(get_format_kegg_function(self.hparams.model_type))
|
| 533 |
+
|
| 534 |
+
if self.hparams.merge_val_test_set:
|
| 535 |
+
val_dataset = concatenate_datasets([dataset['test'], dataset['val']])
|
| 536 |
+
else:
|
| 537 |
+
val_dataset = dataset["val"]
|
| 538 |
+
|
| 539 |
+
labels = []
|
| 540 |
+
for split, data in dataset.items():
|
| 541 |
+
labels.extend(data["answer"])
|
| 542 |
+
self.labels = sorted(list(set(labels)))
|
| 543 |
+
|
| 544 |
+
if self.hparams.truncate_dna_per_side:
|
| 545 |
+
val_dataset = val_dataset.map(
|
| 546 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
processor = DLProcessor(
|
| 550 |
+
tokenizer=self.model.text_tokenizer,
|
| 551 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
# Create partial function with all required arguments except the batch
|
| 555 |
+
collate_fn = partial(
|
| 556 |
+
qwen_dna_collate_fn,
|
| 557 |
+
processor=processor,
|
| 558 |
+
max_length_text=self.max_length_text,
|
| 559 |
+
max_length_dna=self.max_length_dna,
|
| 560 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
elif self.hparams.dataset_type == "variant_effect_coding":
|
| 564 |
+
dataset = load_dataset(self.hparams.variant_effect_coding_data_dir_huggingface)
|
| 565 |
+
cleaned_dataset = dataset.map(clean_variant_effect_example)
|
| 566 |
+
dataset = dataset.map(get_format_variant_effect_function(self.hparams.model_type))
|
| 567 |
+
|
| 568 |
+
labels = []
|
| 569 |
+
for split, data in cleaned_dataset.items():
|
| 570 |
+
labels.extend(data["answer"])
|
| 571 |
+
self.labels = sorted(list(set(labels)))
|
| 572 |
+
|
| 573 |
+
val_dataset = dataset["test"]
|
| 574 |
+
|
| 575 |
+
if self.hparams.truncate_dna_per_side:
|
| 576 |
+
val_dataset = val_dataset.map(
|
| 577 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
processor = DLProcessor(
|
| 581 |
+
tokenizer=self.model.text_tokenizer,
|
| 582 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
# Create partial function with all required arguments except the batch
|
| 586 |
+
collate_fn = partial(
|
| 587 |
+
qwen_dna_collate_fn,
|
| 588 |
+
processor=processor,
|
| 589 |
+
max_length_text=self.max_length_text,
|
| 590 |
+
max_length_dna=self.max_length_dna,
|
| 591 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 592 |
+
)
|
| 593 |
+
|
| 594 |
+
elif self.hparams.dataset_type == "variant_effect_non_snv":
|
| 595 |
+
dataset = load_dataset(self.hparams.variant_effect_non_snv_data_dir_huggingface)
|
| 596 |
+
cleaned_dataset = dataset.map(clean_variant_effect_example)
|
| 597 |
+
dataset = dataset.map(clean_variant_effect_non_snv_example)
|
| 598 |
+
|
| 599 |
+
labels = []
|
| 600 |
+
for split, data in cleaned_dataset.items():
|
| 601 |
+
labels.extend(data["answer"])
|
| 602 |
+
self.labels = sorted(list(set(labels)))
|
| 603 |
+
|
| 604 |
+
dataset = dataset.rename_column("mutated_sequence", "variant_sequence")
|
| 605 |
+
val_dataset = dataset["test"]
|
| 606 |
+
|
| 607 |
+
if self.hparams.truncate_dna_per_side:
|
| 608 |
+
val_dataset = val_dataset.map(
|
| 609 |
+
truncate_dna, fn_kwargs={"truncate_dna_per_side": self.hparams.truncate_dna_per_side}
|
| 610 |
+
)
|
| 611 |
+
val_dataset = val_dataset.map(get_format_variant_effect_function(self.hparams.model_type))
|
| 612 |
+
|
| 613 |
+
processor = DLProcessor(
|
| 614 |
+
tokenizer=self.model.text_tokenizer,
|
| 615 |
+
dna_tokenizer=self.model.dna_tokenizer,
|
| 616 |
+
)
|
| 617 |
+
|
| 618 |
+
# Create partial function with all required arguments except the batch
|
| 619 |
+
collate_fn = partial(
|
| 620 |
+
qwen_dna_collate_fn,
|
| 621 |
+
processor=processor,
|
| 622 |
+
max_length_text=self.max_length_text,
|
| 623 |
+
max_length_dna=self.max_length_dna,
|
| 624 |
+
return_answer_in_batch=self.return_answer_in_batch,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
else:
|
| 628 |
+
raise ValueError(f"Unknown dataset type: {self.hparams.dataset_type}")
|
| 629 |
+
|
| 630 |
+
return DataLoader(
|
| 631 |
+
val_dataset,
|
| 632 |
+
batch_size=self.hparams.batch_size,
|
| 633 |
+
shuffle=False,
|
| 634 |
+
collate_fn=collate_fn,
|
| 635 |
+
num_workers=self.hparams.num_workers,
|
| 636 |
+
persistent_workers=False,
|
| 637 |
+
pin_memory=False,
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
def test_dataloader(self) -> DataLoader:
|
| 641 |
+
"""Create and return the test DataLoader."""
|
| 642 |
+
return self.val_dataloader()
|
| 643 |
+
|
| 644 |
+
# Only for VEP datasets, for KEGG use the resulting generations in W&B
|
| 645 |
+
def on_test_epoch_end(self):
|
| 646 |
+
"""
|
| 647 |
+
Called at the end of test epoch to generate text for all test examples
|
| 648 |
+
and calculate accuracy, precision, recall, and F1 score based on whether
|
| 649 |
+
the label appears in the generated response.
|
| 650 |
+
"""
|
| 651 |
+
# Get wandb logger
|
| 652 |
+
wandb_logger = self.logger.experiment
|
| 653 |
+
wandb_logger.log({"test_progress": 0.0, "status": "starting test generation"})
|
| 654 |
+
|
| 655 |
+
# Set model to eval mode
|
| 656 |
+
self.model.eval()
|
| 657 |
+
|
| 658 |
+
# Get test dataloader
|
| 659 |
+
test_dataloader = self.test_dataloader()
|
| 660 |
+
total_batches = len(test_dataloader)
|
| 661 |
+
|
| 662 |
+
# Get negative and positive labels
|
| 663 |
+
neg_label = self.labels[0] # Negative label (first item)
|
| 664 |
+
pos_label = self.labels[1] # Positive label (second item)
|
| 665 |
+
|
| 666 |
+
# Log label information
|
| 667 |
+
wandb_logger.log({
|
| 668 |
+
"positive_label": pos_label,
|
| 669 |
+
"negative_label": neg_label
|
| 670 |
+
})
|
| 671 |
+
print(f"Using labels - Positive: '{pos_label}', Negative: '{neg_label}'")
|
| 672 |
+
|
| 673 |
+
# Initialize counters and storage for generations
|
| 674 |
+
total_examples = 0
|
| 675 |
+
true_positives = 0
|
| 676 |
+
true_negatives = 0
|
| 677 |
+
false_positives = 0
|
| 678 |
+
false_negatives = 0
|
| 679 |
+
processed_batches = 0
|
| 680 |
+
generations = []
|
| 681 |
+
|
| 682 |
+
# Process each batch in the test dataloader
|
| 683 |
+
for batch_idx, batch in enumerate(test_dataloader):
|
| 684 |
+
# Log batch start to wandb
|
| 685 |
+
wandb_logger.log({
|
| 686 |
+
"test_progress": batch_idx / total_batches,
|
| 687 |
+
"status": f"processing batch {batch_idx}/{total_batches}"
|
| 688 |
+
})
|
| 689 |
+
|
| 690 |
+
# Get batch data
|
| 691 |
+
input_ids = batch["input_ids"].to(self.device)
|
| 692 |
+
attention_mask = batch["attention_mask"].to(self.device)
|
| 693 |
+
answer = batch["answer"]
|
| 694 |
+
dna_tokenized = batch.get("dna_tokenized")
|
| 695 |
+
if dna_tokenized is not None:
|
| 696 |
+
dna_tokenized = dna_tokenized.to(self.device)
|
| 697 |
+
batch_idx_map = batch.get("batch_idx_map")
|
| 698 |
+
|
| 699 |
+
# Get assistant marker position
|
| 700 |
+
assistant_start_marker = "<|im_start|>assistant\n"
|
| 701 |
+
assistant_marker_tokens = self.tokenizer.encode(assistant_start_marker, add_special_tokens=False)
|
| 702 |
+
marker_tensor = torch.tensor(assistant_marker_tokens, device=input_ids.device)
|
| 703 |
+
marker_len = len(assistant_marker_tokens)
|
| 704 |
+
|
| 705 |
+
# Log batch metadata to wandb
|
| 706 |
+
wandb_logger.log({
|
| 707 |
+
"batch_size": input_ids.shape[0],
|
| 708 |
+
"input_sequence_length": input_ids.shape[1]
|
| 709 |
+
})
|
| 710 |
+
|
| 711 |
+
# Process examples in the batch
|
| 712 |
+
examples_in_batch = 0
|
| 713 |
+
for example_idx in range(input_ids.size(0)):
|
| 714 |
+
# Log example progress to wandb
|
| 715 |
+
if total_examples % 10 == 0:
|
| 716 |
+
current_accuracy = (true_positives + true_negatives) / max(1, total_examples)
|
| 717 |
+
wandb_logger.log({
|
| 718 |
+
"examples_processed": total_examples,
|
| 719 |
+
"current_accuracy": current_accuracy
|
| 720 |
+
})
|
| 721 |
+
|
| 722 |
+
# Find non-padding tokens
|
| 723 |
+
non_pad = (input_ids[example_idx] != self.tokenizer.pad_token_id).nonzero(as_tuple=True)[0]
|
| 724 |
+
start_idx = non_pad[0].item() if len(non_pad) > 0 else 0
|
| 725 |
+
|
| 726 |
+
# Find assistant marker position
|
| 727 |
+
assistant_pos = None
|
| 728 |
+
for pos in range(start_idx, input_ids.size(1) - marker_len + 1):
|
| 729 |
+
if torch.all(input_ids[example_idx, pos:pos + marker_len] == marker_tensor):
|
| 730 |
+
assistant_pos = pos
|
| 731 |
+
break
|
| 732 |
+
|
| 733 |
+
# Log to wandb if assistant marker was found
|
| 734 |
+
wandb_logger.log({"assistant_marker_found": assistant_pos is not None})
|
| 735 |
+
|
| 736 |
+
if assistant_pos is not None:
|
| 737 |
+
# Prepare input for generation
|
| 738 |
+
gen_input_ids = input_ids[example_idx:example_idx + 1, start_idx:assistant_pos + marker_len]
|
| 739 |
+
gen_attention_mask = attention_mask[example_idx:example_idx + 1, start_idx:assistant_pos + marker_len]
|
| 740 |
+
|
| 741 |
+
# Extract DNA data for this example
|
| 742 |
+
example_dna_data = None
|
| 743 |
+
example_batch_map = None
|
| 744 |
+
|
| 745 |
+
if dna_tokenized is not None and batch_idx_map is not None:
|
| 746 |
+
example_indices = [i for i, idx in enumerate(batch_idx_map) if idx == example_idx]
|
| 747 |
+
|
| 748 |
+
if example_indices:
|
| 749 |
+
example_dna_data = BatchEncoding({
|
| 750 |
+
"input_ids": dna_tokenized.input_ids[example_indices].to(self.device),
|
| 751 |
+
"attention_mask": dna_tokenized.attention_mask[example_indices].to(self.device),
|
| 752 |
+
})
|
| 753 |
+
example_batch_map = [0] * len(example_indices)
|
| 754 |
+
|
| 755 |
+
# Log generation start to wandb
|
| 756 |
+
wandb_logger.log({"status": f"generating for example {example_idx} in batch {batch_idx}"})
|
| 757 |
+
|
| 758 |
+
# Generate text
|
| 759 |
+
with torch.no_grad():
|
| 760 |
+
generated = self.model.generate(
|
| 761 |
+
input_ids=gen_input_ids,
|
| 762 |
+
attention_mask=gen_attention_mask,
|
| 763 |
+
dna_tokenized=example_dna_data,
|
| 764 |
+
batch_idx_map=example_batch_map,
|
| 765 |
+
max_new_tokens=800,
|
| 766 |
+
temperature=0.6,
|
| 767 |
+
top_p=0.95,
|
| 768 |
+
top_k=20,
|
| 769 |
+
do_sample=True,
|
| 770 |
+
)
|
| 771 |
+
|
| 772 |
+
# Decode user input and generated text
|
| 773 |
+
user_input = self.tokenizer.decode(gen_input_ids[0], skip_special_tokens=False).strip()
|
| 774 |
+
generation = self.tokenizer.decode(generated[0], skip_special_tokens=False).strip()
|
| 775 |
+
|
| 776 |
+
# Get ground truth and clean it if needed
|
| 777 |
+
ground_truth = answer[example_idx]
|
| 778 |
+
if ";" in ground_truth:
|
| 779 |
+
ground_truth = ground_truth.split(";")[0]
|
| 780 |
+
|
| 781 |
+
# Determine if this is a positive or negative example
|
| 782 |
+
is_positive_example = ground_truth.lower() == pos_label.lower()
|
| 783 |
+
is_negative_example = ground_truth.lower() == neg_label.lower()
|
| 784 |
+
|
| 785 |
+
# Check if the generated text contains the ground truth
|
| 786 |
+
generation_contains_ground_truth = ground_truth.lower() in generation.lower()
|
| 787 |
+
|
| 788 |
+
# Update metrics based on the classification
|
| 789 |
+
total_examples += 1
|
| 790 |
+
examples_in_batch += 1
|
| 791 |
+
|
| 792 |
+
if is_positive_example and generation_contains_ground_truth:
|
| 793 |
+
true_positives += 1
|
| 794 |
+
elif is_positive_example and not generation_contains_ground_truth:
|
| 795 |
+
false_negatives += 1
|
| 796 |
+
elif is_negative_example and generation_contains_ground_truth:
|
| 797 |
+
true_negatives += 1
|
| 798 |
+
elif is_negative_example and not generation_contains_ground_truth:
|
| 799 |
+
false_positives += 1
|
| 800 |
+
|
| 801 |
+
# Add metadata about the prediction
|
| 802 |
+
prediction_category = (
|
| 803 |
+
"TP" if is_positive_example and generation_contains_ground_truth else
|
| 804 |
+
"FN" if is_positive_example and not generation_contains_ground_truth else
|
| 805 |
+
"TN" if is_negative_example and generation_contains_ground_truth else
|
| 806 |
+
"FP"
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
# Store generation data
|
| 810 |
+
generations.append({
|
| 811 |
+
"batch_idx": batch_idx,
|
| 812 |
+
"example_idx": example_idx,
|
| 813 |
+
"user_input": user_input,
|
| 814 |
+
"generation": generation,
|
| 815 |
+
"ground_truth": ground_truth,
|
| 816 |
+
"contains_ground_truth": generation_contains_ground_truth,
|
| 817 |
+
"is_positive_example": is_positive_example,
|
| 818 |
+
"prediction_category": prediction_category
|
| 819 |
+
})
|
| 820 |
+
|
| 821 |
+
# Clean up memory
|
| 822 |
+
torch.cuda.empty_cache()
|
| 823 |
+
gc.collect()
|
| 824 |
+
|
| 825 |
+
# Log batch completion to wandb
|
| 826 |
+
processed_batches += 1
|
| 827 |
+
|
| 828 |
+
# Calculate current metrics
|
| 829 |
+
current_accuracy = (true_positives + true_negatives) / max(total_examples, 1)
|
| 830 |
+
current_precision = true_positives / max(true_positives + false_positives, 1)
|
| 831 |
+
current_recall = true_positives / max(true_positives + false_negatives, 1)
|
| 832 |
+
current_f1 = 2 * current_precision * current_recall / max(current_precision + current_recall, 1e-8)
|
| 833 |
+
|
| 834 |
+
wandb_logger.log({
|
| 835 |
+
"batches_processed": processed_batches,
|
| 836 |
+
"examples_processed": total_examples,
|
| 837 |
+
"examples_in_last_batch": examples_in_batch,
|
| 838 |
+
"current_accuracy": current_accuracy,
|
| 839 |
+
"current_precision": current_precision,
|
| 840 |
+
"current_recall": current_recall,
|
| 841 |
+
"current_f1": current_f1,
|
| 842 |
+
"progress_percentage": (batch_idx + 1) / total_batches * 100
|
| 843 |
+
})
|
| 844 |
+
|
| 845 |
+
# Calculate final metrics
|
| 846 |
+
accuracy = (true_positives + true_negatives) / max(total_examples, 1)
|
| 847 |
+
precision = true_positives / max(true_positives + false_positives, 1)
|
| 848 |
+
recall = true_positives / max(true_positives + false_negatives, 1)
|
| 849 |
+
f1 = 2 * precision * recall / max(precision + recall, 1e-8)
|
| 850 |
+
|
| 851 |
+
# Log final metrics to wandb
|
| 852 |
+
wandb_logger.log({
|
| 853 |
+
"test_accuracy": accuracy,
|
| 854 |
+
"test_precision": precision,
|
| 855 |
+
"test_recall": recall,
|
| 856 |
+
"test_f1": f1,
|
| 857 |
+
"true_positives": true_positives,
|
| 858 |
+
"false_positives": false_positives,
|
| 859 |
+
"true_negatives": true_negatives,
|
| 860 |
+
"false_negatives": false_negatives,
|
| 861 |
+
"total_examples_processed": total_examples,
|
| 862 |
+
"positive_examples": true_positives + false_negatives,
|
| 863 |
+
"negative_examples": true_negatives + false_positives,
|
| 864 |
+
"test_status": "completed"
|
| 865 |
+
})
|
| 866 |
+
|
| 867 |
+
# Create a confusion matrix
|
| 868 |
+
confusion_matrix = {
|
| 869 |
+
"True Positives": true_positives,
|
| 870 |
+
"False Positives": false_positives,
|
| 871 |
+
"True Negatives": true_negatives,
|
| 872 |
+
"False Negatives": false_negatives
|
| 873 |
+
}
|
| 874 |
+
wandb_logger.log({"confusion_matrix": confusion_matrix})
|
| 875 |
+
|
| 876 |
+
# Create a table with all the generations
|
| 877 |
+
if generations:
|
| 878 |
+
columns = [
|
| 879 |
+
"batch_idx",
|
| 880 |
+
"example_idx",
|
| 881 |
+
"user_input",
|
| 882 |
+
"generation",
|
| 883 |
+
"ground_truth",
|
| 884 |
+
"contains_ground_truth",
|
| 885 |
+
"is_positive_example",
|
| 886 |
+
"prediction_category"
|
| 887 |
+
]
|
| 888 |
+
data = []
|
| 889 |
+
for g in generations:
|
| 890 |
+
# Handle any missing keys
|
| 891 |
+
row = [g.get(c, "") for c in columns]
|
| 892 |
+
data.append(row)
|
| 893 |
+
|
| 894 |
+
wandb_logger.log({
|
| 895 |
+
f"test_generations_{time.strftime('%Y%m%d-%H%M%S')}:": wandb.Table(columns=columns, data=data)
|
| 896 |
+
})
|
| 897 |
+
|
| 898 |
+
# Save generations to a CSV file
|
| 899 |
+
model_name = self.hparams.text_model_name.split('/')[-1]
|
| 900 |
+
if self.hparams.ckpt_path:
|
| 901 |
+
csv_path = os.path.join(self.hparams.ckpt_path, f"{time.strftime('%Y%m%d-%H%M%S')}-test_generations_{model_name}.csv")
|
| 902 |
+
else:
|
| 903 |
+
csv_path = os.path.join(self.hparams.checkpoint_dir, f"{time.strftime('%Y%m%d-%H%M%S')}-test_generations_{model_name}.csv")
|
| 904 |
+
|
| 905 |
+
try:
|
| 906 |
+
with open(csv_path, 'w', newline='', encoding='utf-8') as f:
|
| 907 |
+
if generations:
|
| 908 |
+
writer = csv.DictWriter(f, fieldnames=generations[0].keys())
|
| 909 |
+
writer.writeheader()
|
| 910 |
+
for g in generations:
|
| 911 |
+
writer.writerow(g)
|
| 912 |
+
|
| 913 |
+
wandb_logger.log({"csv_saved": True, "csv_path": csv_path})
|
| 914 |
+
except Exception as e:
|
| 915 |
+
wandb_logger.log({"csv_saved": False, "csv_path": csv_path, "error": str(e)})
|
| 916 |
+
|
| 917 |
+
# Log a summary of the metrics
|
| 918 |
+
summary = (
|
| 919 |
+
f"Test Results Summary:\n"
|
| 920 |
+
f"Total examples: {total_examples}\n"
|
| 921 |
+
f"Accuracy: {accuracy:.4f}\n"
|
| 922 |
+
f"Precision: {precision:.4f}\n"
|
| 923 |
+
f"Recall: {recall:.4f}\n"
|
| 924 |
+
f"F1 Score: {f1:.4f}\n"
|
| 925 |
+
f"TP: {true_positives}, FP: {false_positives}, TN: {true_negatives}, FN: {false_negatives}"
|
| 926 |
+
)
|
| 927 |
+
print(summary)
|
| 928 |
+
wandb_logger.log({"test_summary": summary})
|
| 929 |
+
|
| 930 |
+
# Force garbage collection
|
| 931 |
+
torch.cuda.empty_cache()
|
| 932 |
+
gc.collect()
|
| 933 |
+
|
| 934 |
+
return {
|
| 935 |
+
"test_accuracy": accuracy,
|
| 936 |
+
"test_precision": precision,
|
| 937 |
+
"test_recall": recall,
|
| 938 |
+
"test_f1": f1
|
| 939 |
+
}
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
def main(args: ArgumentParser):
|
| 943 |
+
"""
|
| 944 |
+
Main function to run the DNA-Text fine-tuning process.
|
| 945 |
+
|
| 946 |
+
Args:
|
| 947 |
+
args (ArgumentParser): Parsed command-line arguments
|
| 948 |
+
"""
|
| 949 |
+
# Set random seed and environment variables
|
| 950 |
+
pl.seed_everything(args.seed)
|
| 951 |
+
torch.cuda.empty_cache()
|
| 952 |
+
torch.set_float32_matmul_precision("medium")
|
| 953 |
+
|
| 954 |
+
# Setup directories
|
| 955 |
+
run_name = f"{args.wandb_project}-{args.dataset_type}-{args.text_model_name.split('/')[-1]}"
|
| 956 |
+
args.checkpoint_dir = f"{args.checkpoint_dir}/{run_name}-{time.strftime('%Y%m%d-%H%M%S')}"
|
| 957 |
+
|
| 958 |
+
# Initialize model
|
| 959 |
+
model = DNALLMFineTuner(args)
|
| 960 |
+
|
| 961 |
+
# Setup callbacks
|
| 962 |
+
callbacks = [
|
| 963 |
+
ModelCheckpoint(
|
| 964 |
+
dirpath=args.checkpoint_dir,
|
| 965 |
+
filename=f"{run_name}-" + "{epoch:02d}-{val_loss_epoch:.4f}",
|
| 966 |
+
save_top_k=2,
|
| 967 |
+
monitor="val_loss_epoch",
|
| 968 |
+
mode="min",
|
| 969 |
+
save_last=True,
|
| 970 |
+
),
|
| 971 |
+
LearningRateMonitor(logging_interval="step"),
|
| 972 |
+
]
|
| 973 |
+
|
| 974 |
+
# Setup logger
|
| 975 |
+
is_resuming = args.ckpt_path is not None
|
| 976 |
+
logger = WandbLogger(
|
| 977 |
+
project=args.wandb_project,
|
| 978 |
+
entity=args.wandb_entity,
|
| 979 |
+
save_dir=args.log_dir,
|
| 980 |
+
name=run_name,
|
| 981 |
+
resume="allow" if is_resuming else None, # Allow resuming existing run
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
# Initialize the PyTorch Lightning Trainer
|
| 985 |
+
trainer = pl.Trainer(
|
| 986 |
+
max_epochs=args.max_epochs,
|
| 987 |
+
accelerator="gpu",
|
| 988 |
+
devices=args.num_gpus,
|
| 989 |
+
strategy=(
|
| 990 |
+
"ddp"
|
| 991 |
+
if args.strategy == "ddp"
|
| 992 |
+
else DeepSpeedStrategy(stage=2, offload_optimizer=False, allgather_bucket_size=5e8, reduce_bucket_size=5e8)
|
| 993 |
+
),
|
| 994 |
+
precision="bf16-mixed",
|
| 995 |
+
callbacks=callbacks,
|
| 996 |
+
logger=logger,
|
| 997 |
+
deterministic=False,
|
| 998 |
+
enable_checkpointing=True,
|
| 999 |
+
enable_progress_bar=True,
|
| 1000 |
+
enable_model_summary=True,
|
| 1001 |
+
log_every_n_steps=5,
|
| 1002 |
+
accumulate_grad_batches=args.gradient_accumulation_steps,
|
| 1003 |
+
gradient_clip_val=1.0,
|
| 1004 |
+
val_check_interval=1 / 3,
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
# Start the training process
|
| 1008 |
+
trainer.fit(model, ckpt_path=args.ckpt_path)
|
| 1009 |
+
trainer.test(model, ckpt_path=args.ckpt_path if args.ckpt_path else "best")
|
| 1010 |
+
|
| 1011 |
+
if __name__ == "__main__":
|
| 1012 |
+
parser = ArgumentParser()
|
| 1013 |
+
|
| 1014 |
+
# Model configuration
|
| 1015 |
+
parser.add_argument("--model_type", type=str, choices=["llm", "dna-llm"], default="dna-llm")
|
| 1016 |
+
parser.add_argument("--text_model_name", type=str, default="Qwen/Qwen3-1.7B")
|
| 1017 |
+
parser.add_argument("--dna_model_name", type=str, default="InstaDeepAI/nucleotide-transformer-v2-500m-multi-species")
|
| 1018 |
+
parser.add_argument("--text_model_finetune", type=bool, default=True)
|
| 1019 |
+
parser.add_argument("--dna_model_finetune", type=bool, default=False)
|
| 1020 |
+
parser.add_argument("--dna_is_evo2", type=bool, default=False)
|
| 1021 |
+
parser.add_argument("--dna_embedding_layer", type=str, default=None)
|
| 1022 |
+
|
| 1023 |
+
# Training parameters
|
| 1024 |
+
parser.add_argument("--seed", type=int, default=23)
|
| 1025 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
| 1026 |
+
parser.add_argument("--max_epochs", type=int, default=5)
|
| 1027 |
+
parser.add_argument("--learning_rate", type=float, default=5e-5)
|
| 1028 |
+
parser.add_argument("--weight_decay", type=float, default=0.01)
|
| 1029 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
|
| 1030 |
+
parser.add_argument("--max_length_dna", type=int, default=1024)
|
| 1031 |
+
parser.add_argument("--max_length_text", type=int, default=1024)
|
| 1032 |
+
parser.add_argument("--truncate_dna_per_side", type=int, default=1024)
|
| 1033 |
+
parser.add_argument("--return_answer_in_batch", type=bool, default=False)
|
| 1034 |
+
|
| 1035 |
+
# LoRA parameters
|
| 1036 |
+
parser.add_argument("--lora_rank", type=int, default=32)
|
| 1037 |
+
parser.add_argument("--lora_alpha", type=int, default=64)
|
| 1038 |
+
parser.add_argument("--lora_dropout", type=float, default=0.05)
|
| 1039 |
+
|
| 1040 |
+
# Infrastructure and paths
|
| 1041 |
+
parser.add_argument("--checkpoint_dir", type=str, default="checkpoints")
|
| 1042 |
+
parser.add_argument("--log_dir", type=str, default="logs")
|
| 1043 |
+
parser.add_argument("--cache_dir", type=str, default="/model-weights")
|
| 1044 |
+
parser.add_argument("--ckpt_path", type=str, default=None)
|
| 1045 |
+
parser.add_argument("--num_workers", type=int, default=4)
|
| 1046 |
+
parser.add_argument("--num_gpus", type=int, default=1)
|
| 1047 |
+
parser.add_argument("--strategy", type=str, default="ddp")
|
| 1048 |
+
|
| 1049 |
+
# Dataset configuration
|
| 1050 |
+
parser.add_argument("--dataset_type", type=str, choices=["kegg", "variant_effect_coding", "variant_effect_non_snv"], default="kegg")
|
| 1051 |
+
parser.add_argument("--use_qwen_dna_collate_fn", type=bool, default=True)
|
| 1052 |
+
parser.add_argument("--kegg_data_dir_local", type=str, default="data/kegg")
|
| 1053 |
+
parser.add_argument("--kegg_data_dir_huggingface", type=str, default="wanglab/kegg")
|
| 1054 |
+
parser.add_argument("--variant_effect_coding_data_dir_huggingface", type=str, default="wanglab/variant_effect_coding")
|
| 1055 |
+
parser.add_argument("--variant_effect_non_snv_data_dir_huggingface", type=str, default="wanglab/variant_effect_non_snv")
|
| 1056 |
+
parser.add_argument("--merge_val_test_set", type=bool, default=False)
|
| 1057 |
+
|
| 1058 |
+
# Logging and monitoring
|
| 1059 |
+
parser.add_argument("--wandb_project", type=str, default="nt-500m-qwen3-1.7b-finetune")
|
| 1060 |
+
parser.add_argument("--wandb_entity", type=str)
|
| 1061 |
+
|
| 1062 |
+
args = parser.parse_args()
|
| 1063 |
+
|
| 1064 |
+
main(args)
|