Batch upload part 19
- nl_tasks/expsBOFT/seed42/ft/special_tokens_map.json +24 -0
- nl_tasks/expsBOFT/seed42/ft/tokenizer.json +0 -0
- nl_tasks/expsBOFT/seed42/ft/tokenizer.model +3 -0
- nl_tasks/expsBOFT/seed42/ft/tokenizer_config.json +43 -0
- nl_tasks/expsBOFT/seed42/ft2/README.md +205 -0
- nl_tasks/expsBOFT/seed42/ft2/adapter_config.json +27 -0
- nl_tasks/expsBOFT/seed42/ft2/adapter_model.safetensors +3 -0
- nl_tasks/expsBOFT/seed42/trainer_state.json +218 -0
- nl_tasks/expsBOFT/seed43/ft/special_tokens_map.json +24 -0
- nl_tasks/expsBOFT/seed43/ft/tokenizer.json +0 -0
- nl_tasks/expsBOFT/seed43/ft/tokenizer.model +3 -0
- nl_tasks/expsBOFT/seed43/ft/tokenizer_config.json +43 -0
- nl_tasks/expsBOFT/seed43/ft2/README.md +205 -0
- nl_tasks/expsBOFT/seed43/ft2/adapter_config.json +27 -0
- nl_tasks/expsBOFT/seed43/ft2/adapter_model.safetensors +3 -0
- nl_tasks/expsOFT/seed42/ft/special_tokens_map.json +24 -0
- nl_tasks/expsOFT/seed42/ft/tokenizer.json +0 -0
- nl_tasks/expsOFT/seed42/ft/tokenizer.model +3 -0
- nl_tasks/expsOFT/seed42/ft/tokenizer_config.json +43 -0
- nl_tasks/expsOFT/seed42/ft2/README.md +205 -0
- nl_tasks/expsOFT/seed42/ft2/adapter_config.json +31 -0
- nl_tasks/expsOFT/seed42/ft2/adapter_model.safetensors +3 -0
- nl_tasks/expsOFT/seed42/trainer_state.json +218 -0
- nl_tasks/expsOFT/seed43/ft/special_tokens_map.json +24 -0
- nl_tasks/expsOFT/seed43/ft/tokenizer.json +0 -0
- nl_tasks/expsOFT/seed43/ft/tokenizer.model +3 -0
- nl_tasks/expsOFT/seed43/ft/tokenizer_config.json +43 -0
- nl_tasks/expsOFT/seed43/ft2/README.md +205 -0
- nl_tasks/expsOFT/seed43/ft2/adapter_config.json +31 -0
- nl_tasks/expsOFT/seed43/ft2/adapter_model.safetensors +3 -0
- nl_tasks/expsOFT/seed43/trainer_state.json +218 -0
- nl_tasks/expsOFT/seed44/ft/special_tokens_map.json +24 -0
- nl_tasks/expsOFT/seed44/ft/tokenizer.json +0 -0
- nl_tasks/expsOFT/seed44/ft/tokenizer.model +3 -0
- nl_tasks/expsOFT/seed44/ft/tokenizer_config.json +43 -0
- nl_tasks/expsOFT/seed44/ft2/README.md +205 -0
- nl_tasks/expsOFT/seed44/ft2/adapter_config.json +31 -0
- nl_tasks/expsOFT/seed44/ft2/adapter_model.safetensors +3 -0
- nl_tasks/expsOFT/seed44/trainer_state.json +218 -0
- omini/__init__.py +0 -0
- omini/pipeline/flux_omini.py +734 -0
- omini/pipeline/flux_omini_ablate_qkv.py +772 -0
- omini/pipeline/flux_omini_ablate_scale.py +748 -0
- omini/rotation/__init__.py +3 -0
- omini/rotation/layer.py +313 -0
- omini/rotation/layer_test.py +296 -0
- omini/rotation/model.py +390 -0
- omini/rotation/rotation_config.py +81 -0
- omini/train_flux/train_custom.py +50 -0
- omini/train_flux/train_multi_condition.py +160 -0
nl_tasks/expsBOFT/seed42/ft/special_tokens_map.json
ADDED
{
  "bos_token": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "eos_token": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "pad_token": "<unk>",
  "unk_token": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
}
nl_tasks/expsBOFT/seed42/ft/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
nl_tasks/expsBOFT/seed42/ft/tokenizer.model
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
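The three lines above are a Git LFS pointer, not the SentencePiece model itself: the actual `tokenizer.model` blob is stored via LFS and identified only by its SHA-256 and byte size. As a rough illustration (not part of the upload; the local file path is an assumption), a downloaded copy can be checked against such a pointer like this:

```python
# Sketch: verify a locally fetched tokenizer.model against the LFS pointer above.
# Only the oid and size come from the pointer; the path is assumed.
import hashlib

expected_oid = "9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347"
expected_size = 499723

with open("nl_tasks/expsBOFT/seed42/ft/tokenizer.model", "rb") as f:
    data = f.read()

assert len(data) == expected_size, "size does not match the LFS pointer"
assert hashlib.sha256(data).hexdigest() == expected_oid, "sha256 does not match the LFS pointer"
print("tokenizer.model matches its Git LFS pointer")
```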
nl_tasks/expsBOFT/seed42/ft/tokenizer_config.json
ADDED
{
  "add_bos_token": true,
  "add_eos_token": false,
  "add_prefix_space": null,
  "added_tokens_decoder": {
    "0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "legacy": false,
  "model_max_length": 512,
  "pad_token": "<unk>",
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
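The `ft` folders each save a plain Llama-2 tokenizer: `model_max_length` 512, right padding, and `<unk>` doubling as the pad token (see `special_tokens_map.json` above). A minimal loading sketch with `transformers`, assuming the directory is available locally:

```python
# Sketch: load the saved tokenizer directory and check the padding setup
# recorded in tokenizer_config.json / special_tokens_map.json (path assumed).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("nl_tasks/expsBOFT/seed42/ft")

print(type(tok).__name__)            # LlamaTokenizer / LlamaTokenizerFast
print(tok.model_max_length)          # 512
print(tok.padding_side)              # "right"
print(tok.pad_token, tok.unk_token)  # "<unk>" "<unk>"
```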
nl_tasks/expsBOFT/seed42/ft2/README.md
ADDED
---
base_model: meta-llama/Llama-2-7b-hf
library_name: peft
tags:
- base_model:adapter:meta-llama/Llama-2-7b-hf
- transformers
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->



## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->



- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]


#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary



## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
### Framework versions

- PEFT 0.18.0
nl_tasks/expsBOFT/seed42/ft2/adapter_config.json
ADDED
{
  "auto_mapping": {
    "base_model_class": "LlamaForCausalLM",
    "parent_library": "transformers.models.llama.modeling_llama"
  },
  "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
  "bias": "none",
  "boft_block_num": 0,
  "boft_block_size": 16,
  "boft_dropout": 0.05,
  "boft_n_butterfly_factor": 2,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "modules_to_save": null,
  "peft_type": "BOFT",
  "peft_version": "0.18.0",
  "revision": null,
  "target_modules": ["v_proj", "q_proj"],
  "task_type": null
}
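The `ft2` folders hold PEFT BOFT adapters for Llama-2-7B: orthogonal butterfly factors on the `q_proj`/`v_proj` attention projections, block size 16, 2 butterfly factors, dropout 0.05. Below is a minimal sketch of re-attaching such an adapter with `peft`; the local adapter path and dtype are assumptions, and the base weights must be available separately.

```python
# Sketch: load the Llama-2-7B base model and attach the BOFT adapter saved
# under nl_tasks/expsBOFT/seed42/ft2 (assumed to be a local directory).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_id = "meta-llama/Llama-2-7b-hf"
adapter_dir = "nl_tasks/expsBOFT/seed42/ft2"

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16)

# PeftModel reads adapter_config.json (peft_type: BOFT) and loads the weights
# from adapter_model.safetensors on top of the q_proj/v_proj modules.
model = PeftModel.from_pretrained(model, adapter_dir)
model.eval()
```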
nl_tasks/expsBOFT/seed42/ft2/adapter_model.safetensors
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:584526a06a1f45f2f77e6a89a7201b05aa25a3d6be60f231b255a32c48c4b261
size 34619504
nl_tasks/expsBOFT/seed42/trainer_state.json
ADDED
{
  "best_global_step": null,
  "best_metric": null,
  "best_model_checkpoint": null,
  "epoch": 2.0,
  "eval_steps": 500,
  "global_step": 1250,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {"epoch": 0.08, "grad_norm": 0.08375173062086105, "learning_rate": 0.000392, "loss": 0.5193, "step": 50},
    {"epoch": 0.16, "grad_norm": 0.09268203377723694, "learning_rate": 0.0007920000000000001, "loss": 0.3316, "step": 100},
    {"epoch": 0.24, "grad_norm": 0.08198747783899307, "learning_rate": 0.0007964216926581925, "loss": 0.304, "step": 150},
    {"epoch": 0.32, "grad_norm": 0.0816216915845871, "learning_rate": 0.0007854602918076551, "loss": 0.2918, "step": 200},
    {"epoch": 0.4, "grad_norm": 0.07457849383354187, "learning_rate": 0.0007673184950396212, "loss": 0.274, "step": 250},
    {"epoch": 0.48, "grad_norm": 0.07685171067714691, "learning_rate": 0.0007423342497022817, "loss": 0.2687, "step": 300},
    {"epoch": 0.56, "grad_norm": 0.07849128544330597, "learning_rate": 0.0007109729650142636, "loss": 0.2651, "step": 350},
    {"epoch": 0.64, "grad_norm": 0.07266736030578613, "learning_rate": 0.0006738188423714755, "loss": 0.2575, "step": 400},
    {"epoch": 0.72, "grad_norm": 0.06927025318145752, "learning_rate": 0.0006315639927804526, "loss": 0.2525, "step": 450},
    {"epoch": 0.8, "grad_norm": 0.08536054193973541, "learning_rate": 0.00058499554413983, "loss": 0.2494, "step": 500},
    {"epoch": 0.88, "grad_norm": 0.07602768391370773, "learning_rate": 0.000534980978536894, "loss": 0.2429, "step": 550},
    {"epoch": 0.96, "grad_norm": 0.07055249065160751, "learning_rate": 0.00048245197269763485, "loss": 0.2457, "step": 600},
    {"epoch": 1.04, "grad_norm": 0.07144515216350555, "learning_rate": 0.00042838704261214224, "loss": 0.2292, "step": 650},
    {"epoch": 1.12, "grad_norm": 0.07937044650316238, "learning_rate": 0.00037379331563313267, "loss": 0.2169, "step": 700},
    {"epoch": 1.2, "grad_norm": 0.07409252226352692, "learning_rate": 0.00031968776959892677, "loss": 0.2098, "step": 750},
    {"epoch": 1.28, "grad_norm": 0.07844420522451401, "learning_rate": 0.00026707828846051743, "loss": 0.2145, "step": 800},
    {"epoch": 1.3599999999999999, "grad_norm": 0.07791652530431747, "learning_rate": 0.00021694488731055218, "loss": 0.2082, "step": 850},
    {"epoch": 1.44, "grad_norm": 0.0782908946275711, "learning_rate": 0.00017022145655641685, "loss": 0.2077, "step": 900},
    {"epoch": 1.52, "grad_norm": 0.0826650932431221, "learning_rate": 0.00012777836530893536, "loss": 0.2137, "step": 950},
    {"epoch": 1.6, "grad_norm": 0.0696156919002533, "learning_rate": 9.040624805263558e-05, "loss": 0.2076, "step": 1000},
    {"epoch": 1.6800000000000002, "grad_norm": 0.06966507434844971, "learning_rate": 5.880127662124091e-05, "loss": 0.2108, "step": 1050},
    {"epoch": 1.76, "grad_norm": 0.08326321095228195, "learning_rate": 3.355219183361582e-05, "loss": 0.2106, "step": 1100},
    {"epoch": 1.8399999999999999, "grad_norm": 0.0792745053768158, "learning_rate": 1.512933636625089e-05, "loss": 0.2073, "step": 1150},
    {"epoch": 1.92, "grad_norm": 0.07648582756519318, "learning_rate": 3.8758931591217575e-06, "loss": 0.209, "step": 1200},
    {"epoch": 2.0, "grad_norm": 0.0787830799818039, "learning_rate": 1.4925668450960217e-09, "loss": 0.2124, "step": 1250},
    {"epoch": 2.0, "step": 1250, "total_flos": 1.62594677587968e+18, "train_loss": 0.25041088790893556, "train_runtime": 3370.9131, "train_samples_per_second": 23.732, "train_steps_per_second": 0.371}
  ],
  "logging_steps": 50,
  "max_steps": 1250,
  "num_input_tokens_seen": 0,
  "num_train_epochs": 2,
  "save_steps": 0,
  "stateful_callbacks": {
    "TrainerControl": {
      "args": {"should_epoch_stop": false, "should_evaluate": false, "should_log": false, "should_save": false, "should_training_stop": false},
      "attributes": {}
    }
  },
  "total_flos": 1.62594677587968e+18,
  "train_batch_size": 32,
  "trial_name": null,
  "trial_params": null
}
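Each `trainer_state.json` is the Hugging Face `Trainer` log: 1250 optimizer steps over 2 epochs, loss logged every 50 steps, with this BOFT seed-42 run ending at a mean train_loss of about 0.250. A small stdlib-only sketch for pulling the loss curve out of such a file (local path assumed):

```python
# Sketch: extract the (step, loss) curve from a saved trainer_state.json.
import json

with open("nl_tasks/expsBOFT/seed42/trainer_state.json") as f:  # assumed local path
    state = json.load(f)

# Per-step log entries carry "loss" and "grad_norm"; the final summary entry does not.
curve = [(e["step"], e["loss"]) for e in state["log_history"] if "grad_norm" in e]
for step, loss in curve:
    print(f"step {step:4d}  loss {loss:.4f}")

print("train_loss:", state["log_history"][-1]["train_loss"])
```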
nl_tasks/expsBOFT/seed43/ft/special_tokens_map.json
ADDED
{
  "bos_token": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "eos_token": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "pad_token": "<unk>",
  "unk_token": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
}
nl_tasks/expsBOFT/seed43/ft/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
nl_tasks/expsBOFT/seed43/ft/tokenizer.model
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
nl_tasks/expsBOFT/seed43/ft/tokenizer_config.json
ADDED
{
  "add_bos_token": true,
  "add_eos_token": false,
  "add_prefix_space": null,
  "added_tokens_decoder": {
    "0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "legacy": false,
  "model_max_length": 512,
  "pad_token": "<unk>",
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
nl_tasks/expsBOFT/seed43/ft2/README.md
ADDED
Identical to nl_tasks/expsBOFT/seed42/ft2/README.md above: the auto-generated PEFT model card template (base_model: meta-llama/Llama-2-7b-hf, PEFT 0.18.0).
nl_tasks/expsBOFT/seed43/ft2/adapter_config.json
ADDED
{
  "auto_mapping": {
    "base_model_class": "LlamaForCausalLM",
    "parent_library": "transformers.models.llama.modeling_llama"
  },
  "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
  "bias": "none",
  "boft_block_num": 0,
  "boft_block_size": 16,
  "boft_dropout": 0.05,
  "boft_n_butterfly_factor": 2,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "modules_to_save": null,
  "peft_type": "BOFT",
  "peft_version": "0.18.0",
  "revision": null,
  "target_modules": ["v_proj", "q_proj"],
  "task_type": null
}
|
nl_tasks/expsBOFT/seed43/ft2/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:584526a06a1f45f2f77e6a89a7201b05aa25a3d6be60f231b255a32c48c4b261
|
| 3 |
+
size 34619504
|
nl_tasks/expsOFT/seed42/ft/special_tokens_map.json
ADDED
{
  "bos_token": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "eos_token": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false},
  "pad_token": "<unk>",
  "unk_token": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false}
}
nl_tasks/expsOFT/seed42/ft/tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
nl_tasks/expsOFT/seed42/ft/tokenizer.model
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
nl_tasks/expsOFT/seed42/ft/tokenizer_config.json
ADDED
{
  "add_bos_token": true,
  "add_eos_token": false,
  "add_prefix_space": null,
  "added_tokens_decoder": {
    "0": {"content": "<unk>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "1": {"content": "<s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true},
    "2": {"content": "</s>", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, "special": true}
  },
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "extra_special_tokens": {},
  "legacy": false,
  "model_max_length": 512,
  "pad_token": "<unk>",
  "padding_side": "right",
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
nl_tasks/expsOFT/seed42/ft2/README.md
ADDED
Identical to nl_tasks/expsBOFT/seed42/ft2/README.md above: the auto-generated PEFT model card template (base_model: meta-llama/Llama-2-7b-hf, PEFT 0.18.0).
nl_tasks/expsOFT/seed42/ft2/adapter_config.json
ADDED
{
  "auto_mapping": {
    "base_model_class": "LlamaForCausalLM",
    "parent_library": "transformers.models.llama.modeling_llama"
  },
  "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
  "bias": "none",
  "block_share": false,
  "coft": false,
  "eps": 6e-05,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_weights": true,
  "layers_pattern": null,
  "layers_to_transform": null,
  "module_dropout": 0.05,
  "modules_to_save": null,
  "num_cayley_neumann_terms": 5,
  "oft_block_size": 64,
  "peft_type": "OFT",
  "peft_version": "0.18.0",
  "r": 0,
  "revision": null,
  "target_modules": ["q_proj", "v_proj"],
  "task_type": null,
  "use_cayley_neumann": true
}
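The OFT variant targets the same `q_proj`/`v_proj` modules but with 64-dimensional orthogonal blocks and the Cayley-Neumann parametrization (5 series terms). A rough sketch of building the equivalent configuration programmatically with `peft` (field names mirror the JSON above; exact keyword availability depends on the installed peft version, here 0.18.0):

```python
# Sketch: recreate the OFT adapter configuration above and wrap the base model.
from transformers import AutoModelForCausalLM
from peft import OFTConfig, get_peft_model

oft_config = OFTConfig(
    r=0,                       # with r=0, oft_block_size determines the block structure
    oft_block_size=64,
    module_dropout=0.05,
    coft=False,
    eps=6e-05,
    block_share=False,
    use_cayley_neumann=True,   # Cayley-Neumann approximation of the orthogonal blocks
    num_cayley_neumann_terms=5,
    target_modules=["q_proj", "v_proj"],
)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = get_peft_model(model, oft_config)
model.print_trainable_parameters()
```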
nl_tasks/expsOFT/seed42/ft2/adapter_model.safetensors
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
size 33038696
nl_tasks/expsOFT/seed42/trainer_state.json
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 2.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 1250,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.08,
|
| 14 |
+
"grad_norm": 0.15338309109210968,
|
| 15 |
+
"learning_rate": 0.000392,
|
| 16 |
+
"loss": 0.4726,
|
| 17 |
+
"step": 50
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.16,
|
| 21 |
+
"grad_norm": 0.1656411737203598,
|
| 22 |
+
"learning_rate": 0.0007920000000000001,
|
| 23 |
+
"loss": 0.3098,
|
| 24 |
+
"step": 100
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.24,
|
| 28 |
+
"grad_norm": 0.161162331700325,
|
| 29 |
+
"learning_rate": 0.0007964216926581925,
|
| 30 |
+
"loss": 0.2883,
|
| 31 |
+
"step": 150
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.32,
|
| 35 |
+
"grad_norm": 0.14719629287719727,
|
| 36 |
+
"learning_rate": 0.0007854602918076551,
|
| 37 |
+
"loss": 0.2773,
|
| 38 |
+
"step": 200
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.4,
|
| 42 |
+
"grad_norm": 0.1362672597169876,
|
| 43 |
+
"learning_rate": 0.0007673184950396212,
|
| 44 |
+
"loss": 0.2606,
|
| 45 |
+
"step": 250
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.48,
|
| 49 |
+
"grad_norm": 0.1420401930809021,
|
| 50 |
+
"learning_rate": 0.0007423342497022817,
|
| 51 |
+
"loss": 0.2549,
|
| 52 |
+
"step": 300
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.56,
|
| 56 |
+
"grad_norm": 0.15255458652973175,
|
| 57 |
+
"learning_rate": 0.0007109729650142636,
|
| 58 |
+
"loss": 0.2516,
|
| 59 |
+
"step": 350
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.64,
|
| 63 |
+
"grad_norm": 0.13546934723854065,
|
| 64 |
+
"learning_rate": 0.0006738188423714755,
|
| 65 |
+
"loss": 0.2439,
|
| 66 |
+
"step": 400
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.72,
|
| 70 |
+
"grad_norm": 0.1296033263206482,
|
| 71 |
+
"learning_rate": 0.0006315639927804526,
|
| 72 |
+
"loss": 0.2383,
|
| 73 |
+
"step": 450
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.8,
|
| 77 |
+
"grad_norm": 0.14936736226081848,
|
| 78 |
+
"learning_rate": 0.00058499554413983,
|
| 79 |
+
"loss": 0.2348,
|
| 80 |
+
"step": 500
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.88,
|
| 84 |
+
"grad_norm": 0.12654532492160797,
|
| 85 |
+
"learning_rate": 0.000534980978536894,
|
| 86 |
+
"loss": 0.2274,
|
| 87 |
+
"step": 550
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.96,
|
| 91 |
+
"grad_norm": 0.1250297725200653,
|
| 92 |
+
"learning_rate": 0.00048245197269763485,
|
| 93 |
+
"loss": 0.2298,
|
| 94 |
+
"step": 600
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 1.04,
|
| 98 |
+
"grad_norm": 0.1344439834356308,
|
| 99 |
+
"learning_rate": 0.00042838704261214224,
|
| 100 |
+
"loss": 0.2065,
|
| 101 |
+
"step": 650
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 1.12,
|
| 105 |
+
"grad_norm": 0.12664927542209625,
|
| 106 |
+
"learning_rate": 0.00037379331563313267,
|
| 107 |
+
"loss": 0.1907,
|
| 108 |
+
"step": 700
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 1.2,
|
| 112 |
+
"grad_norm": 0.1543550342321396,
|
| 113 |
+
"learning_rate": 0.00031968776959892677,
|
| 114 |
+
"loss": 0.1887,
|
| 115 |
+
"step": 750
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 1.28,
|
| 119 |
+
"grad_norm": 0.13837428390979767,
|
| 120 |
+
"learning_rate": 0.00026707828846051743,
|
| 121 |
+
"loss": 0.185,
|
| 122 |
+
"step": 800
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 1.3599999999999999,
|
| 126 |
+
"grad_norm": 0.12324073910713196,
|
| 127 |
+
"learning_rate": 0.00021694488731055218,
|
| 128 |
+
"loss": 0.1787,
|
| 129 |
+
"step": 850
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 1.44,
|
| 133 |
+
"grad_norm": 0.14447391033172607,
|
| 134 |
+
"learning_rate": 0.00017022145655641685,
|
| 135 |
+
"loss": 0.1779,
|
| 136 |
+
"step": 900
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 1.52,
|
| 140 |
+
"grad_norm": 0.13559409976005554,
|
| 141 |
+
"learning_rate": 0.00012777836530893536,
|
| 142 |
+
"loss": 0.1785,
|
| 143 |
+
"step": 950
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 1.6,
|
| 147 |
+
"grad_norm": 0.13572397828102112,
|
| 148 |
+
"learning_rate": 9.040624805263558e-05,
|
| 149 |
+
"loss": 0.176,
|
| 150 |
+
"step": 1000
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 1.6800000000000002,
|
| 154 |
+
"grad_norm": 0.13348858058452606,
|
| 155 |
+
"learning_rate": 5.880127662124091e-05,
|
| 156 |
+
"loss": 0.1743,
|
| 157 |
+
"step": 1050
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 1.76,
|
| 161 |
+
"grad_norm": 0.1402943730354309,
|
| 162 |
+
"learning_rate": 3.355219183361582e-05,
|
| 163 |
+
"loss": 0.1755,
|
| 164 |
+
"step": 1100
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 1.8399999999999999,
|
| 168 |
+
"grad_norm": 0.14928816258907318,
|
| 169 |
+
"learning_rate": 1.512933636625089e-05,
|
| 170 |
+
"loss": 0.1729,
|
| 171 |
+
"step": 1150
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 1.92,
|
| 175 |
+
"grad_norm": 0.14678366482257843,
|
| 176 |
+
"learning_rate": 3.8758931591217575e-06,
|
| 177 |
+
"loss": 0.1785,
|
| 178 |
+
"step": 1200
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 2.0,
|
| 182 |
+
"grad_norm": 0.13319681584835052,
|
| 183 |
+
"learning_rate": 1.4925668450960217e-09,
|
| 184 |
+
"loss": 0.1739,
|
| 185 |
+
"step": 1250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 2.0,
|
| 189 |
+
"step": 1250,
|
| 190 |
+
"total_flos": 1.62585013911552e+18,
|
| 191 |
+
"train_loss": 0.2258549835205078,
|
| 192 |
+
"train_runtime": 2135.866,
|
| 193 |
+
"train_samples_per_second": 37.456,
|
| 194 |
+
"train_steps_per_second": 0.585
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"logging_steps": 50,
|
| 198 |
+
"max_steps": 1250,
|
| 199 |
+
"num_input_tokens_seen": 0,
|
| 200 |
+
"num_train_epochs": 2,
|
| 201 |
+
"save_steps": 0,
|
| 202 |
+
"stateful_callbacks": {
|
| 203 |
+
"TrainerControl": {
|
| 204 |
+
"args": {
|
| 205 |
+
"should_epoch_stop": false,
|
| 206 |
+
"should_evaluate": false,
|
| 207 |
+
"should_log": false,
|
| 208 |
+
"should_save": false,
|
| 209 |
+
"should_training_stop": false
|
| 210 |
+
},
|
| 211 |
+
"attributes": {}
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"total_flos": 1.62585013911552e+18,
|
| 215 |
+
"train_batch_size": 64,
|
| 216 |
+
"trial_name": null,
|
| 217 |
+
"trial_params": null
|
| 218 |
+
}
|
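The trainer_state.json above logs one entry every 50 optimizer steps (see `logging_steps`) plus a final summary entry. For reference, the loss curve can be pulled out with a short script; this is a minimal sketch, and the local path is an assumption based on the file location in this commit:

```python
import json

# Hypothetical local path, matching the location added in this commit.
path = "nl_tasks/expsBOFT/seed42/trainer_state.json"

with open(path) as f:
    state = json.load(f)

# Per-step logging entries carry step, loss, learning_rate, and grad_norm.
for entry in state["log_history"]:
    if "loss" in entry and "learning_rate" in entry:
        print(f"step {entry['step']:>5}  loss {entry['loss']:.4f}  lr {entry['learning_rate']:.2e}")

# The last entry is the aggregate summary (train_loss, train_runtime, throughput).
summary = state["log_history"][-1]
print("train_loss:", summary.get("train_loss"), "train_runtime:", summary.get("train_runtime"))
```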
nl_tasks/expsOFT/seed43/ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "<unk>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<unk>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
nl_tasks/expsOFT/seed43/ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
nl_tasks/expsOFT/seed43/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
nl_tasks/expsOFT/seed43/ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"bos_token": "<s>",
|
| 32 |
+
"clean_up_tokenization_spaces": false,
|
| 33 |
+
"eos_token": "</s>",
|
| 34 |
+
"extra_special_tokens": {},
|
| 35 |
+
"legacy": false,
|
| 36 |
+
"model_max_length": 512,
|
| 37 |
+
"pad_token": "<unk>",
|
| 38 |
+
"padding_side": "right",
|
| 39 |
+
"sp_model_kwargs": {},
|
| 40 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 41 |
+
"unk_token": "<unk>",
|
| 42 |
+
"use_default_system_prompt": false
|
| 43 |
+
}
|
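Together with tokenizer.model and special_tokens_map.json above, this tokenizer_config.json forms a standard LlamaTokenizer bundle. A minimal loading sketch with transformers follows; the directory path is an assumption based on this commit's layout:

```python
from transformers import AutoTokenizer

# Hypothetical local checkout path for the tokenizer files added above.
tok_dir = "nl_tasks/expsOFT/seed43/ft"

# tokenizer_config.json selects LlamaTokenizer with model_max_length=512 and pad_token="<unk>".
tokenizer = AutoTokenizer.from_pretrained(tok_dir)

enc = tokenizer("Hello world", return_tensors="pt")
print(enc["input_ids"].shape, tokenizer.pad_token, tokenizer.model_max_length)
```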
nl_tasks/expsOFT/seed43/ft2/README.md
ADDED
|
@@ -0,0 +1,205 @@
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
tags:
|
| 5 |
+
- base_model:adapter:meta-llama/Llama-2-7b-hf
|
| 6 |
+
- transformers
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Model Card for Model ID
|
| 10 |
+
|
| 11 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## Model Details
|
| 16 |
+
|
| 17 |
+
### Model Description
|
| 18 |
+
|
| 19 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
- **Developed by:** [More Information Needed]
|
| 24 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 25 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 26 |
+
- **Model type:** [More Information Needed]
|
| 27 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 28 |
+
- **License:** [More Information Needed]
|
| 29 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 30 |
+
|
| 31 |
+
### Model Sources [optional]
|
| 32 |
+
|
| 33 |
+
<!-- Provide the basic links for the model. -->
|
| 34 |
+
|
| 35 |
+
- **Repository:** [More Information Needed]
|
| 36 |
+
- **Paper [optional]:** [More Information Needed]
|
| 37 |
+
- **Demo [optional]:** [More Information Needed]
|
| 38 |
+
|
| 39 |
+
## Uses
|
| 40 |
+
|
| 41 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 42 |
+
|
| 43 |
+
### Direct Use
|
| 44 |
+
|
| 45 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 46 |
+
|
| 47 |
+
[More Information Needed]
|
| 48 |
+
|
| 49 |
+
### Downstream Use [optional]
|
| 50 |
+
|
| 51 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 52 |
+
|
| 53 |
+
[More Information Needed]
|
| 54 |
+
|
| 55 |
+
### Out-of-Scope Use
|
| 56 |
+
|
| 57 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 58 |
+
|
| 59 |
+
[More Information Needed]
|
| 60 |
+
|
| 61 |
+
## Bias, Risks, and Limitations
|
| 62 |
+
|
| 63 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 64 |
+
|
| 65 |
+
[More Information Needed]
|
| 66 |
+
|
| 67 |
+
### Recommendations
|
| 68 |
+
|
| 69 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 70 |
+
|
| 71 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 72 |
+
|
| 73 |
+
## How to Get Started with the Model
|
| 74 |
+
|
| 75 |
+
Use the code below to get started with the model.
|
| 76 |
+
|
| 77 |
+
[More Information Needed]
|
| 78 |
+
|
| 79 |
+
## Training Details
|
| 80 |
+
|
| 81 |
+
### Training Data
|
| 82 |
+
|
| 83 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 84 |
+
|
| 85 |
+
[More Information Needed]
|
| 86 |
+
|
| 87 |
+
### Training Procedure
|
| 88 |
+
|
| 89 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 90 |
+
|
| 91 |
+
#### Preprocessing [optional]
|
| 92 |
+
|
| 93 |
+
[More Information Needed]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
#### Training Hyperparameters
|
| 97 |
+
|
| 98 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 99 |
+
|
| 100 |
+
#### Speeds, Sizes, Times [optional]
|
| 101 |
+
|
| 102 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 103 |
+
|
| 104 |
+
[More Information Needed]
|
| 105 |
+
|
| 106 |
+
## Evaluation
|
| 107 |
+
|
| 108 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 109 |
+
|
| 110 |
+
### Testing Data, Factors & Metrics
|
| 111 |
+
|
| 112 |
+
#### Testing Data
|
| 113 |
+
|
| 114 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 115 |
+
|
| 116 |
+
[More Information Needed]
|
| 117 |
+
|
| 118 |
+
#### Factors
|
| 119 |
+
|
| 120 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 121 |
+
|
| 122 |
+
[More Information Needed]
|
| 123 |
+
|
| 124 |
+
#### Metrics
|
| 125 |
+
|
| 126 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 127 |
+
|
| 128 |
+
[More Information Needed]
|
| 129 |
+
|
| 130 |
+
### Results
|
| 131 |
+
|
| 132 |
+
[More Information Needed]
|
| 133 |
+
|
| 134 |
+
#### Summary
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
## Model Examination [optional]
|
| 139 |
+
|
| 140 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 141 |
+
|
| 142 |
+
[More Information Needed]
|
| 143 |
+
|
| 144 |
+
## Environmental Impact
|
| 145 |
+
|
| 146 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 147 |
+
|
| 148 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 149 |
+
|
| 150 |
+
- **Hardware Type:** [More Information Needed]
|
| 151 |
+
- **Hours used:** [More Information Needed]
|
| 152 |
+
- **Cloud Provider:** [More Information Needed]
|
| 153 |
+
- **Compute Region:** [More Information Needed]
|
| 154 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 155 |
+
|
| 156 |
+
## Technical Specifications [optional]
|
| 157 |
+
|
| 158 |
+
### Model Architecture and Objective
|
| 159 |
+
|
| 160 |
+
[More Information Needed]
|
| 161 |
+
|
| 162 |
+
### Compute Infrastructure
|
| 163 |
+
|
| 164 |
+
[More Information Needed]
|
| 165 |
+
|
| 166 |
+
#### Hardware
|
| 167 |
+
|
| 168 |
+
[More Information Needed]
|
| 169 |
+
|
| 170 |
+
#### Software
|
| 171 |
+
|
| 172 |
+
[More Information Needed]
|
| 173 |
+
|
| 174 |
+
## Citation [optional]
|
| 175 |
+
|
| 176 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 177 |
+
|
| 178 |
+
**BibTeX:**
|
| 179 |
+
|
| 180 |
+
[More Information Needed]
|
| 181 |
+
|
| 182 |
+
**APA:**
|
| 183 |
+
|
| 184 |
+
[More Information Needed]
|
| 185 |
+
|
| 186 |
+
## Glossary [optional]
|
| 187 |
+
|
| 188 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 189 |
+
|
| 190 |
+
[More Information Needed]
|
| 191 |
+
|
| 192 |
+
## More Information [optional]
|
| 193 |
+
|
| 194 |
+
[More Information Needed]
|
| 195 |
+
|
| 196 |
+
## Model Card Authors [optional]
|
| 197 |
+
|
| 198 |
+
[More Information Needed]
|
| 199 |
+
|
| 200 |
+
## Model Card Contact
|
| 201 |
+
|
| 202 |
+
[More Information Needed]
|
| 203 |
+
### Framework versions
|
| 204 |
+
|
| 205 |
+
- PEFT 0.18.0
|
nl_tasks/expsOFT/seed43/ft2/adapter_config.json
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
{
|
| 2 |
+
"auto_mapping": {
|
| 3 |
+
"base_model_class": "LlamaForCausalLM",
|
| 4 |
+
"parent_library": "transformers.models.llama.modeling_llama"
|
| 5 |
+
},
|
| 6 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 7 |
+
"bias": "none",
|
| 8 |
+
"block_share": false,
|
| 9 |
+
"coft": false,
|
| 10 |
+
"eps": 6e-05,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_weights": true,
|
| 15 |
+
"layers_pattern": null,
|
| 16 |
+
"layers_to_transform": null,
|
| 17 |
+
"module_dropout": 0.05,
|
| 18 |
+
"modules_to_save": null,
|
| 19 |
+
"num_cayley_neumann_terms": 5,
|
| 20 |
+
"oft_block_size": 64,
|
| 21 |
+
"peft_type": "OFT",
|
| 22 |
+
"peft_version": "0.18.0",
|
| 23 |
+
"r": 0,
|
| 24 |
+
"revision": null,
|
| 25 |
+
"target_modules": [
|
| 26 |
+
"q_proj",
|
| 27 |
+
"v_proj"
|
| 28 |
+
],
|
| 29 |
+
"task_type": null,
|
| 30 |
+
"use_cayley_neumann": true
|
| 31 |
+
}
|
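This adapter_config.json describes an OFT adapter (peft_type "OFT", oft_block_size 64, Cayley–Neumann parametrization) on the q_proj/v_proj modules of Llama-2-7B. A minimal sketch of attaching it with PEFT; the adapter path is an assumption based on this commit, and the gated base model must be available locally or via the Hub:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

# base_model_name_or_path taken from adapter_config.json (gated weights on the Hub).
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
)

# Hypothetical local path to the ft2 adapter directory added in this commit.
model = PeftModel.from_pretrained(base, "nl_tasks/expsOFT/seed43/ft2")
model.eval()  # the adapter was saved with inference_mode=true
```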
nl_tasks/expsOFT/seed43/ft2/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
|
| 3 |
+
size 33038696
|
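The safetensors weights are stored through Git LFS, so only the pointer (oid and size, roughly 33 MB) appears in the diff. After `git lfs pull`, the adapter tensors can be inspected; a minimal sketch, with the path again an assumption:

```python
from safetensors import safe_open

# Hypothetical local path; requires the LFS payload, not just the pointer shown above.
path = "nl_tasks/expsOFT/seed43/ft2/adapter_model.safetensors"

with safe_open(path, framework="pt") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
```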
nl_tasks/expsOFT/seed43/trainer_state.json
ADDED
|
@@ -0,0 +1,218 @@
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 2.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 1250,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.08,
|
| 14 |
+
"grad_norm": 0.15338309109210968,
|
| 15 |
+
"learning_rate": 0.000392,
|
| 16 |
+
"loss": 0.4726,
|
| 17 |
+
"step": 50
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.16,
|
| 21 |
+
"grad_norm": 0.1656411737203598,
|
| 22 |
+
"learning_rate": 0.0007920000000000001,
|
| 23 |
+
"loss": 0.3098,
|
| 24 |
+
"step": 100
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.24,
|
| 28 |
+
"grad_norm": 0.161162331700325,
|
| 29 |
+
"learning_rate": 0.0007964216926581925,
|
| 30 |
+
"loss": 0.2883,
|
| 31 |
+
"step": 150
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.32,
|
| 35 |
+
"grad_norm": 0.14719629287719727,
|
| 36 |
+
"learning_rate": 0.0007854602918076551,
|
| 37 |
+
"loss": 0.2773,
|
| 38 |
+
"step": 200
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.4,
|
| 42 |
+
"grad_norm": 0.1362672597169876,
|
| 43 |
+
"learning_rate": 0.0007673184950396212,
|
| 44 |
+
"loss": 0.2606,
|
| 45 |
+
"step": 250
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.48,
|
| 49 |
+
"grad_norm": 0.1420401930809021,
|
| 50 |
+
"learning_rate": 0.0007423342497022817,
|
| 51 |
+
"loss": 0.2549,
|
| 52 |
+
"step": 300
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.56,
|
| 56 |
+
"grad_norm": 0.15255458652973175,
|
| 57 |
+
"learning_rate": 0.0007109729650142636,
|
| 58 |
+
"loss": 0.2516,
|
| 59 |
+
"step": 350
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.64,
|
| 63 |
+
"grad_norm": 0.13546934723854065,
|
| 64 |
+
"learning_rate": 0.0006738188423714755,
|
| 65 |
+
"loss": 0.2439,
|
| 66 |
+
"step": 400
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.72,
|
| 70 |
+
"grad_norm": 0.1296033263206482,
|
| 71 |
+
"learning_rate": 0.0006315639927804526,
|
| 72 |
+
"loss": 0.2383,
|
| 73 |
+
"step": 450
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.8,
|
| 77 |
+
"grad_norm": 0.14936736226081848,
|
| 78 |
+
"learning_rate": 0.00058499554413983,
|
| 79 |
+
"loss": 0.2348,
|
| 80 |
+
"step": 500
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.88,
|
| 84 |
+
"grad_norm": 0.12654532492160797,
|
| 85 |
+
"learning_rate": 0.000534980978536894,
|
| 86 |
+
"loss": 0.2274,
|
| 87 |
+
"step": 550
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.96,
|
| 91 |
+
"grad_norm": 0.1250297725200653,
|
| 92 |
+
"learning_rate": 0.00048245197269763485,
|
| 93 |
+
"loss": 0.2298,
|
| 94 |
+
"step": 600
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 1.04,
|
| 98 |
+
"grad_norm": 0.1344439834356308,
|
| 99 |
+
"learning_rate": 0.00042838704261214224,
|
| 100 |
+
"loss": 0.2065,
|
| 101 |
+
"step": 650
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 1.12,
|
| 105 |
+
"grad_norm": 0.12664927542209625,
|
| 106 |
+
"learning_rate": 0.00037379331563313267,
|
| 107 |
+
"loss": 0.1907,
|
| 108 |
+
"step": 700
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 1.2,
|
| 112 |
+
"grad_norm": 0.1543550342321396,
|
| 113 |
+
"learning_rate": 0.00031968776959892677,
|
| 114 |
+
"loss": 0.1887,
|
| 115 |
+
"step": 750
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 1.28,
|
| 119 |
+
"grad_norm": 0.13837428390979767,
|
| 120 |
+
"learning_rate": 0.00026707828846051743,
|
| 121 |
+
"loss": 0.185,
|
| 122 |
+
"step": 800
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 1.3599999999999999,
|
| 126 |
+
"grad_norm": 0.12324073910713196,
|
| 127 |
+
"learning_rate": 0.00021694488731055218,
|
| 128 |
+
"loss": 0.1787,
|
| 129 |
+
"step": 850
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 1.44,
|
| 133 |
+
"grad_norm": 0.14447391033172607,
|
| 134 |
+
"learning_rate": 0.00017022145655641685,
|
| 135 |
+
"loss": 0.1779,
|
| 136 |
+
"step": 900
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 1.52,
|
| 140 |
+
"grad_norm": 0.13559409976005554,
|
| 141 |
+
"learning_rate": 0.00012777836530893536,
|
| 142 |
+
"loss": 0.1785,
|
| 143 |
+
"step": 950
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 1.6,
|
| 147 |
+
"grad_norm": 0.13572397828102112,
|
| 148 |
+
"learning_rate": 9.040624805263558e-05,
|
| 149 |
+
"loss": 0.176,
|
| 150 |
+
"step": 1000
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 1.6800000000000002,
|
| 154 |
+
"grad_norm": 0.13348858058452606,
|
| 155 |
+
"learning_rate": 5.880127662124091e-05,
|
| 156 |
+
"loss": 0.1743,
|
| 157 |
+
"step": 1050
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 1.76,
|
| 161 |
+
"grad_norm": 0.1402943730354309,
|
| 162 |
+
"learning_rate": 3.355219183361582e-05,
|
| 163 |
+
"loss": 0.1755,
|
| 164 |
+
"step": 1100
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 1.8399999999999999,
|
| 168 |
+
"grad_norm": 0.14928816258907318,
|
| 169 |
+
"learning_rate": 1.512933636625089e-05,
|
| 170 |
+
"loss": 0.1729,
|
| 171 |
+
"step": 1150
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 1.92,
|
| 175 |
+
"grad_norm": 0.14678366482257843,
|
| 176 |
+
"learning_rate": 3.8758931591217575e-06,
|
| 177 |
+
"loss": 0.1785,
|
| 178 |
+
"step": 1200
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 2.0,
|
| 182 |
+
"grad_norm": 0.13319681584835052,
|
| 183 |
+
"learning_rate": 1.4925668450960217e-09,
|
| 184 |
+
"loss": 0.1739,
|
| 185 |
+
"step": 1250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 2.0,
|
| 189 |
+
"step": 1250,
|
| 190 |
+
"total_flos": 1.62585013911552e+18,
|
| 191 |
+
"train_loss": 0.2258549835205078,
|
| 192 |
+
"train_runtime": 2134.8975,
|
| 193 |
+
"train_samples_per_second": 37.473,
|
| 194 |
+
"train_steps_per_second": 0.586
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"logging_steps": 50,
|
| 198 |
+
"max_steps": 1250,
|
| 199 |
+
"num_input_tokens_seen": 0,
|
| 200 |
+
"num_train_epochs": 2,
|
| 201 |
+
"save_steps": 0,
|
| 202 |
+
"stateful_callbacks": {
|
| 203 |
+
"TrainerControl": {
|
| 204 |
+
"args": {
|
| 205 |
+
"should_epoch_stop": false,
|
| 206 |
+
"should_evaluate": false,
|
| 207 |
+
"should_log": false,
|
| 208 |
+
"should_save": false,
|
| 209 |
+
"should_training_stop": false
|
| 210 |
+
},
|
| 211 |
+
"attributes": {}
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"total_flos": 1.62585013911552e+18,
|
| 215 |
+
"train_batch_size": 64,
|
| 216 |
+
"trial_name": null,
|
| 217 |
+
"trial_params": null
|
| 218 |
+
}
|
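Because every seed directory ships the same trainer_state.json layout, the final summaries can be compared across seeds in a few lines; a minimal sketch, assuming a local checkout with the directory structure used in this upload:

```python
import glob
import json

# Hypothetical checkout layout matching the paths in this commit.
for path in sorted(glob.glob("nl_tasks/expsOFT/seed*/trainer_state.json")):
    with open(path) as f:
        summary = json.load(f)["log_history"][-1]
    print(path, "train_loss:", summary["train_loss"], "train_runtime:", summary["train_runtime"])
```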
nl_tasks/expsOFT/seed44/ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "<unk>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<unk>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
nl_tasks/expsOFT/seed44/ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
nl_tasks/expsOFT/seed44/ft/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
nl_tasks/expsOFT/seed44/ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,43 @@
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
}
|
| 30 |
+
},
|
| 31 |
+
"bos_token": "<s>",
|
| 32 |
+
"clean_up_tokenization_spaces": false,
|
| 33 |
+
"eos_token": "</s>",
|
| 34 |
+
"extra_special_tokens": {},
|
| 35 |
+
"legacy": false,
|
| 36 |
+
"model_max_length": 512,
|
| 37 |
+
"pad_token": "<unk>",
|
| 38 |
+
"padding_side": "right",
|
| 39 |
+
"sp_model_kwargs": {},
|
| 40 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 41 |
+
"unk_token": "<unk>",
|
| 42 |
+
"use_default_system_prompt": false
|
| 43 |
+
}
|
nl_tasks/expsOFT/seed44/ft2/README.md
ADDED
|
@@ -0,0 +1,205 @@
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
tags:
|
| 5 |
+
- base_model:adapter:meta-llama/Llama-2-7b-hf
|
| 6 |
+
- transformers
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# Model Card for Model ID
|
| 10 |
+
|
| 11 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## Model Details
|
| 16 |
+
|
| 17 |
+
### Model Description
|
| 18 |
+
|
| 19 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
- **Developed by:** [More Information Needed]
|
| 24 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 25 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 26 |
+
- **Model type:** [More Information Needed]
|
| 27 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 28 |
+
- **License:** [More Information Needed]
|
| 29 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 30 |
+
|
| 31 |
+
### Model Sources [optional]
|
| 32 |
+
|
| 33 |
+
<!-- Provide the basic links for the model. -->
|
| 34 |
+
|
| 35 |
+
- **Repository:** [More Information Needed]
|
| 36 |
+
- **Paper [optional]:** [More Information Needed]
|
| 37 |
+
- **Demo [optional]:** [More Information Needed]
|
| 38 |
+
|
| 39 |
+
## Uses
|
| 40 |
+
|
| 41 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 42 |
+
|
| 43 |
+
### Direct Use
|
| 44 |
+
|
| 45 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 46 |
+
|
| 47 |
+
[More Information Needed]
|
| 48 |
+
|
| 49 |
+
### Downstream Use [optional]
|
| 50 |
+
|
| 51 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 52 |
+
|
| 53 |
+
[More Information Needed]
|
| 54 |
+
|
| 55 |
+
### Out-of-Scope Use
|
| 56 |
+
|
| 57 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 58 |
+
|
| 59 |
+
[More Information Needed]
|
| 60 |
+
|
| 61 |
+
## Bias, Risks, and Limitations
|
| 62 |
+
|
| 63 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 64 |
+
|
| 65 |
+
[More Information Needed]
|
| 66 |
+
|
| 67 |
+
### Recommendations
|
| 68 |
+
|
| 69 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 70 |
+
|
| 71 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 72 |
+
|
| 73 |
+
## How to Get Started with the Model
|
| 74 |
+
|
| 75 |
+
Use the code below to get started with the model.
|
| 76 |
+
|
| 77 |
+
[More Information Needed]
|
| 78 |
+
|
| 79 |
+
## Training Details
|
| 80 |
+
|
| 81 |
+
### Training Data
|
| 82 |
+
|
| 83 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 84 |
+
|
| 85 |
+
[More Information Needed]
|
| 86 |
+
|
| 87 |
+
### Training Procedure
|
| 88 |
+
|
| 89 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 90 |
+
|
| 91 |
+
#### Preprocessing [optional]
|
| 92 |
+
|
| 93 |
+
[More Information Needed]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
#### Training Hyperparameters
|
| 97 |
+
|
| 98 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 99 |
+
|
| 100 |
+
#### Speeds, Sizes, Times [optional]
|
| 101 |
+
|
| 102 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 103 |
+
|
| 104 |
+
[More Information Needed]
|
| 105 |
+
|
| 106 |
+
## Evaluation
|
| 107 |
+
|
| 108 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 109 |
+
|
| 110 |
+
### Testing Data, Factors & Metrics
|
| 111 |
+
|
| 112 |
+
#### Testing Data
|
| 113 |
+
|
| 114 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 115 |
+
|
| 116 |
+
[More Information Needed]
|
| 117 |
+
|
| 118 |
+
#### Factors
|
| 119 |
+
|
| 120 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 121 |
+
|
| 122 |
+
[More Information Needed]
|
| 123 |
+
|
| 124 |
+
#### Metrics
|
| 125 |
+
|
| 126 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 127 |
+
|
| 128 |
+
[More Information Needed]
|
| 129 |
+
|
| 130 |
+
### Results
|
| 131 |
+
|
| 132 |
+
[More Information Needed]
|
| 133 |
+
|
| 134 |
+
#### Summary
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
## Model Examination [optional]
|
| 139 |
+
|
| 140 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 141 |
+
|
| 142 |
+
[More Information Needed]
|
| 143 |
+
|
| 144 |
+
## Environmental Impact
|
| 145 |
+
|
| 146 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 147 |
+
|
| 148 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 149 |
+
|
| 150 |
+
- **Hardware Type:** [More Information Needed]
|
| 151 |
+
- **Hours used:** [More Information Needed]
|
| 152 |
+
- **Cloud Provider:** [More Information Needed]
|
| 153 |
+
- **Compute Region:** [More Information Needed]
|
| 154 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 155 |
+
|
| 156 |
+
## Technical Specifications [optional]
|
| 157 |
+
|
| 158 |
+
### Model Architecture and Objective
|
| 159 |
+
|
| 160 |
+
[More Information Needed]
|
| 161 |
+
|
| 162 |
+
### Compute Infrastructure
|
| 163 |
+
|
| 164 |
+
[More Information Needed]
|
| 165 |
+
|
| 166 |
+
#### Hardware
|
| 167 |
+
|
| 168 |
+
[More Information Needed]
|
| 169 |
+
|
| 170 |
+
#### Software
|
| 171 |
+
|
| 172 |
+
[More Information Needed]
|
| 173 |
+
|
| 174 |
+
## Citation [optional]
|
| 175 |
+
|
| 176 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 177 |
+
|
| 178 |
+
**BibTeX:**
|
| 179 |
+
|
| 180 |
+
[More Information Needed]
|
| 181 |
+
|
| 182 |
+
**APA:**
|
| 183 |
+
|
| 184 |
+
[More Information Needed]
|
| 185 |
+
|
| 186 |
+
## Glossary [optional]
|
| 187 |
+
|
| 188 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 189 |
+
|
| 190 |
+
[More Information Needed]
|
| 191 |
+
|
| 192 |
+
## More Information [optional]
|
| 193 |
+
|
| 194 |
+
[More Information Needed]
|
| 195 |
+
|
| 196 |
+
## Model Card Authors [optional]
|
| 197 |
+
|
| 198 |
+
[More Information Needed]
|
| 199 |
+
|
| 200 |
+
## Model Card Contact
|
| 201 |
+
|
| 202 |
+
[More Information Needed]
|
| 203 |
+
### Framework versions
|
| 204 |
+
|
| 205 |
+
- PEFT 0.18.0
|
nl_tasks/expsOFT/seed44/ft2/adapter_config.json
ADDED
|
@@ -0,0 +1,31 @@
|
| 1 |
+
{
|
| 2 |
+
"auto_mapping": {
|
| 3 |
+
"base_model_class": "LlamaForCausalLM",
|
| 4 |
+
"parent_library": "transformers.models.llama.modeling_llama"
|
| 5 |
+
},
|
| 6 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 7 |
+
"bias": "none",
|
| 8 |
+
"block_share": false,
|
| 9 |
+
"coft": false,
|
| 10 |
+
"eps": 6e-05,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_weights": true,
|
| 15 |
+
"layers_pattern": null,
|
| 16 |
+
"layers_to_transform": null,
|
| 17 |
+
"module_dropout": 0.05,
|
| 18 |
+
"modules_to_save": null,
|
| 19 |
+
"num_cayley_neumann_terms": 5,
|
| 20 |
+
"oft_block_size": 64,
|
| 21 |
+
"peft_type": "OFT",
|
| 22 |
+
"peft_version": "0.18.0",
|
| 23 |
+
"r": 0,
|
| 24 |
+
"revision": null,
|
| 25 |
+
"target_modules": [
|
| 26 |
+
"q_proj",
|
| 27 |
+
"v_proj"
|
| 28 |
+
],
|
| 29 |
+
"task_type": null,
|
| 30 |
+
"use_cayley_neumann": true
|
| 31 |
+
}
|
nl_tasks/expsOFT/seed44/ft2/adapter_model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
|
| 3 |
+
size 33038696
|
nl_tasks/expsOFT/seed44/trainer_state.json
ADDED
|
@@ -0,0 +1,218 @@
|
| 1 |
+
{
|
| 2 |
+
"best_global_step": null,
|
| 3 |
+
"best_metric": null,
|
| 4 |
+
"best_model_checkpoint": null,
|
| 5 |
+
"epoch": 2.0,
|
| 6 |
+
"eval_steps": 500,
|
| 7 |
+
"global_step": 1250,
|
| 8 |
+
"is_hyper_param_search": false,
|
| 9 |
+
"is_local_process_zero": true,
|
| 10 |
+
"is_world_process_zero": true,
|
| 11 |
+
"log_history": [
|
| 12 |
+
{
|
| 13 |
+
"epoch": 0.08,
|
| 14 |
+
"grad_norm": 0.15338309109210968,
|
| 15 |
+
"learning_rate": 0.000392,
|
| 16 |
+
"loss": 0.4726,
|
| 17 |
+
"step": 50
|
| 18 |
+
},
|
| 19 |
+
{
|
| 20 |
+
"epoch": 0.16,
|
| 21 |
+
"grad_norm": 0.1656411737203598,
|
| 22 |
+
"learning_rate": 0.0007920000000000001,
|
| 23 |
+
"loss": 0.3098,
|
| 24 |
+
"step": 100
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"epoch": 0.24,
|
| 28 |
+
"grad_norm": 0.161162331700325,
|
| 29 |
+
"learning_rate": 0.0007964216926581925,
|
| 30 |
+
"loss": 0.2883,
|
| 31 |
+
"step": 150
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"epoch": 0.32,
|
| 35 |
+
"grad_norm": 0.14719629287719727,
|
| 36 |
+
"learning_rate": 0.0007854602918076551,
|
| 37 |
+
"loss": 0.2773,
|
| 38 |
+
"step": 200
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"epoch": 0.4,
|
| 42 |
+
"grad_norm": 0.1362672597169876,
|
| 43 |
+
"learning_rate": 0.0007673184950396212,
|
| 44 |
+
"loss": 0.2606,
|
| 45 |
+
"step": 250
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"epoch": 0.48,
|
| 49 |
+
"grad_norm": 0.1420401930809021,
|
| 50 |
+
"learning_rate": 0.0007423342497022817,
|
| 51 |
+
"loss": 0.2549,
|
| 52 |
+
"step": 300
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"epoch": 0.56,
|
| 56 |
+
"grad_norm": 0.15255458652973175,
|
| 57 |
+
"learning_rate": 0.0007109729650142636,
|
| 58 |
+
"loss": 0.2516,
|
| 59 |
+
"step": 350
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"epoch": 0.64,
|
| 63 |
+
"grad_norm": 0.13546934723854065,
|
| 64 |
+
"learning_rate": 0.0006738188423714755,
|
| 65 |
+
"loss": 0.2439,
|
| 66 |
+
"step": 400
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"epoch": 0.72,
|
| 70 |
+
"grad_norm": 0.1296033263206482,
|
| 71 |
+
"learning_rate": 0.0006315639927804526,
|
| 72 |
+
"loss": 0.2383,
|
| 73 |
+
"step": 450
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"epoch": 0.8,
|
| 77 |
+
"grad_norm": 0.14936736226081848,
|
| 78 |
+
"learning_rate": 0.00058499554413983,
|
| 79 |
+
"loss": 0.2348,
|
| 80 |
+
"step": 500
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"epoch": 0.88,
|
| 84 |
+
"grad_norm": 0.12654532492160797,
|
| 85 |
+
"learning_rate": 0.000534980978536894,
|
| 86 |
+
"loss": 0.2274,
|
| 87 |
+
"step": 550
|
| 88 |
+
},
|
| 89 |
+
{
|
| 90 |
+
"epoch": 0.96,
|
| 91 |
+
"grad_norm": 0.1250297725200653,
|
| 92 |
+
"learning_rate": 0.00048245197269763485,
|
| 93 |
+
"loss": 0.2298,
|
| 94 |
+
"step": 600
|
| 95 |
+
},
|
| 96 |
+
{
|
| 97 |
+
"epoch": 1.04,
|
| 98 |
+
"grad_norm": 0.1344439834356308,
|
| 99 |
+
"learning_rate": 0.00042838704261214224,
|
| 100 |
+
"loss": 0.2065,
|
| 101 |
+
"step": 650
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"epoch": 1.12,
|
| 105 |
+
"grad_norm": 0.12664927542209625,
|
| 106 |
+
"learning_rate": 0.00037379331563313267,
|
| 107 |
+
"loss": 0.1907,
|
| 108 |
+
"step": 700
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"epoch": 1.2,
|
| 112 |
+
"grad_norm": 0.1543550342321396,
|
| 113 |
+
"learning_rate": 0.00031968776959892677,
|
| 114 |
+
"loss": 0.1887,
|
| 115 |
+
"step": 750
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"epoch": 1.28,
|
| 119 |
+
"grad_norm": 0.13837428390979767,
|
| 120 |
+
"learning_rate": 0.00026707828846051743,
|
| 121 |
+
"loss": 0.185,
|
| 122 |
+
"step": 800
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"epoch": 1.3599999999999999,
|
| 126 |
+
"grad_norm": 0.12324073910713196,
|
| 127 |
+
"learning_rate": 0.00021694488731055218,
|
| 128 |
+
"loss": 0.1787,
|
| 129 |
+
"step": 850
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"epoch": 1.44,
|
| 133 |
+
"grad_norm": 0.14447391033172607,
|
| 134 |
+
"learning_rate": 0.00017022145655641685,
|
| 135 |
+
"loss": 0.1779,
|
| 136 |
+
"step": 900
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"epoch": 1.52,
|
| 140 |
+
"grad_norm": 0.13559409976005554,
|
| 141 |
+
"learning_rate": 0.00012777836530893536,
|
| 142 |
+
"loss": 0.1785,
|
| 143 |
+
"step": 950
|
| 144 |
+
},
|
| 145 |
+
{
|
| 146 |
+
"epoch": 1.6,
|
| 147 |
+
"grad_norm": 0.13572397828102112,
|
| 148 |
+
"learning_rate": 9.040624805263558e-05,
|
| 149 |
+
"loss": 0.176,
|
| 150 |
+
"step": 1000
|
| 151 |
+
},
|
| 152 |
+
{
|
| 153 |
+
"epoch": 1.6800000000000002,
|
| 154 |
+
"grad_norm": 0.13348858058452606,
|
| 155 |
+
"learning_rate": 5.880127662124091e-05,
|
| 156 |
+
"loss": 0.1743,
|
| 157 |
+
"step": 1050
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"epoch": 1.76,
|
| 161 |
+
"grad_norm": 0.1402943730354309,
|
| 162 |
+
"learning_rate": 3.355219183361582e-05,
|
| 163 |
+
"loss": 0.1755,
|
| 164 |
+
"step": 1100
|
| 165 |
+
},
|
| 166 |
+
{
|
| 167 |
+
"epoch": 1.8399999999999999,
|
| 168 |
+
"grad_norm": 0.14928816258907318,
|
| 169 |
+
"learning_rate": 1.512933636625089e-05,
|
| 170 |
+
"loss": 0.1729,
|
| 171 |
+
"step": 1150
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"epoch": 1.92,
|
| 175 |
+
"grad_norm": 0.14678366482257843,
|
| 176 |
+
"learning_rate": 3.8758931591217575e-06,
|
| 177 |
+
"loss": 0.1785,
|
| 178 |
+
"step": 1200
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"epoch": 2.0,
|
| 182 |
+
"grad_norm": 0.13319681584835052,
|
| 183 |
+
"learning_rate": 1.4925668450960217e-09,
|
| 184 |
+
"loss": 0.1739,
|
| 185 |
+
"step": 1250
|
| 186 |
+
},
|
| 187 |
+
{
|
| 188 |
+
"epoch": 2.0,
|
| 189 |
+
"step": 1250,
|
| 190 |
+
"total_flos": 1.62585013911552e+18,
|
| 191 |
+
"train_loss": 0.2258549835205078,
|
| 192 |
+
"train_runtime": 2124.0047,
|
| 193 |
+
"train_samples_per_second": 37.665,
|
| 194 |
+
"train_steps_per_second": 0.589
|
| 195 |
+
}
|
| 196 |
+
],
|
| 197 |
+
"logging_steps": 50,
|
| 198 |
+
"max_steps": 1250,
|
| 199 |
+
"num_input_tokens_seen": 0,
|
| 200 |
+
"num_train_epochs": 2,
|
| 201 |
+
"save_steps": 0,
|
| 202 |
+
"stateful_callbacks": {
|
| 203 |
+
"TrainerControl": {
|
| 204 |
+
"args": {
|
| 205 |
+
"should_epoch_stop": false,
|
| 206 |
+
"should_evaluate": false,
|
| 207 |
+
"should_log": false,
|
| 208 |
+
"should_save": false,
|
| 209 |
+
"should_training_stop": false
|
| 210 |
+
},
|
| 211 |
+
"attributes": {}
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"total_flos": 1.62585013911552e+18,
|
| 215 |
+
"train_batch_size": 64,
|
| 216 |
+
"trial_name": null,
|
| 217 |
+
"trial_params": null
|
| 218 |
+
}
|
omini/__init__.py
ADDED
|
File without changes
|
omini/pipeline/flux_omini.py
ADDED
|
@@ -0,0 +1,734 @@
|
| 1 |
+
import torch
|
| 2 |
+
from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
|
| 3 |
+
|
| 4 |
+
from diffusers.pipelines import FluxPipeline
|
| 5 |
+
from diffusers.pipelines.flux.pipeline_flux import (
|
| 6 |
+
FluxPipelineOutput,
|
| 7 |
+
FluxTransformer2DModel,
|
| 8 |
+
calculate_shift,
|
| 9 |
+
retrieve_timesteps,
|
| 10 |
+
np,
|
| 11 |
+
)
|
| 12 |
+
from diffusers.models.attention_processor import Attention, F
|
| 13 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 14 |
+
from transformers import pipeline
|
| 15 |
+
|
| 16 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
| 17 |
+
from accelerate.utils import is_torch_version
|
| 18 |
+
|
| 19 |
+
from contextlib import contextmanager
|
| 20 |
+
|
| 21 |
+
import cv2
|
| 22 |
+
|
| 23 |
+
from PIL import Image, ImageFilter
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def seed_everything(seed: int = 42):
|
| 27 |
+
torch.backends.cudnn.deterministic = True
|
| 28 |
+
torch.manual_seed(seed)
|
| 29 |
+
np.random.seed(seed)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 33 |
+
if hidden_states.dtype == torch.float16:
|
| 34 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 35 |
+
return hidden_states
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
|
| 39 |
+
"""
|
| 40 |
+
Encodes the images into tokens and ids for FLUX pipeline.
|
| 41 |
+
"""
|
| 42 |
+
images = pipeline.image_processor.preprocess(images)
|
| 43 |
+
images = images.to(pipeline.device).to(pipeline.dtype)
|
| 44 |
+
images = pipeline.vae.encode(images).latent_dist.sample()
|
| 45 |
+
images = (
|
| 46 |
+
images - pipeline.vae.config.shift_factor
|
| 47 |
+
) * pipeline.vae.config.scaling_factor
|
| 48 |
+
images_tokens = pipeline._pack_latents(images, *images.shape)
|
| 49 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 50 |
+
images.shape[0],
|
| 51 |
+
images.shape[2],
|
| 52 |
+
images.shape[3],
|
| 53 |
+
pipeline.device,
|
| 54 |
+
pipeline.dtype,
|
| 55 |
+
)
|
| 56 |
+
if images_tokens.shape[1] != images_ids.shape[0]:
|
| 57 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 58 |
+
images.shape[0],
|
| 59 |
+
images.shape[2] // 2,
|
| 60 |
+
images.shape[3] // 2,
|
| 61 |
+
pipeline.device,
|
| 62 |
+
pipeline.dtype,
|
| 63 |
+
)
|
| 64 |
+
return images_tokens, images_ids
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
depth_pipe = None
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def convert_to_condition(
|
| 71 |
+
condition_type: str,
|
| 72 |
+
raw_img: Union[Image.Image, torch.Tensor],
|
| 73 |
+
blur_radius: Optional[int] = 5,
|
| 74 |
+
) -> Union[Image.Image, torch.Tensor]:
|
| 75 |
+
if condition_type == "depth":
|
| 76 |
+
global depth_pipe
|
| 77 |
+
depth_pipe = depth_pipe or pipeline(
|
| 78 |
+
task="depth-estimation",
|
| 79 |
+
model="LiheYoung/depth-anything-small-hf",
|
| 80 |
+
device="cpu", # Use "cpu" to enable parallel processing
|
| 81 |
+
)
|
| 82 |
+
source_image = raw_img.convert("RGB")
|
| 83 |
+
condition_img = depth_pipe(source_image)["depth"].convert("RGB")
|
| 84 |
+
return condition_img
|
| 85 |
+
elif condition_type == "canny":
|
| 86 |
+
img = np.array(raw_img)
|
| 87 |
+
edges = cv2.Canny(img, 100, 200)
|
| 88 |
+
edges = Image.fromarray(edges).convert("RGB")
|
| 89 |
+
return edges
|
| 90 |
+
elif condition_type == "coloring":
|
| 91 |
+
return raw_img.convert("L").convert("RGB")
|
| 92 |
+
elif condition_type == "deblurring":
|
| 93 |
+
condition_image = (
|
| 94 |
+
raw_img.convert("RGB")
|
| 95 |
+
.filter(ImageFilter.GaussianBlur(blur_radius))
|
| 96 |
+
.convert("RGB")
|
| 97 |
+
)
|
| 98 |
+
return condition_image
|
| 99 |
+
else:
|
| 100 |
+
print("Warning: Returning the raw image.")
|
| 101 |
+
return raw_img.convert("RGB")
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Condition(object):
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
condition: Union[Image.Image, torch.Tensor],
|
| 108 |
+
adapter_setting: Union[str, dict],
|
| 109 |
+
position_delta=None,
|
| 110 |
+
position_scale=1.0,
|
| 111 |
+
latent_mask=None,
|
| 112 |
+
is_complement=False,
|
| 113 |
+
) -> None:
|
| 114 |
+
self.condition = condition
|
| 115 |
+
self.adapter = adapter_setting
|
| 116 |
+
self.position_delta = position_delta
|
| 117 |
+
self.position_scale = position_scale
|
| 118 |
+
self.latent_mask = (
|
| 119 |
+
latent_mask.T.reshape(-1) if latent_mask is not None else None
|
| 120 |
+
)
|
| 121 |
+
self.is_complement = is_complement
|
| 122 |
+
|
| 123 |
+
def encode(
|
| 124 |
+
self, pipe: FluxPipeline, empty: bool = False
|
| 125 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 126 |
+
condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
|
| 127 |
+
tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
|
| 128 |
+
|
| 129 |
+
if self.position_delta is not None:
|
| 130 |
+
ids[:, 1] += self.position_delta[0]
|
| 131 |
+
ids[:, 2] += self.position_delta[1]
|
| 132 |
+
|
| 133 |
+
if self.position_scale != 1.0:
|
| 134 |
+
scale_bias = (self.position_scale - 1.0) / 2
|
| 135 |
+
ids[:, 1:] *= self.position_scale
|
| 136 |
+
ids[:, 1:] += scale_bias
|
| 137 |
+
|
| 138 |
+
if self.latent_mask is not None:
|
| 139 |
+
tokens = tokens[:, self.latent_mask]
|
| 140 |
+
ids = ids[self.latent_mask]
|
| 141 |
+
|
| 142 |
+
return tokens, ids
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
@contextmanager
|
| 146 |
+
def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora):
|
| 147 |
+
# Filter valid lora modules
|
| 148 |
+
valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
|
| 149 |
+
# Save original scales
|
| 150 |
+
original_scales = [
|
| 151 |
+
{
|
| 152 |
+
adapter: module.scaling[adapter]
|
| 153 |
+
for adapter in module.active_adapters
|
| 154 |
+
if adapter in module.scaling
|
| 155 |
+
}
|
| 156 |
+
for module in valid_lora_modules
|
| 157 |
+
]
|
| 158 |
+
# Enter context: adjust scaling
|
| 159 |
+
for module in valid_lora_modules:
|
| 160 |
+
for adapter in module.active_adapters:
|
| 161 |
+
if adapter in module.scaling:
|
| 162 |
+
module.scaling[adapter] = 1 if adapter == specified_lora else 0
|
| 163 |
+
try:
|
| 164 |
+
yield
|
| 165 |
+
finally:
|
| 166 |
+
# Exit context: restore original scales
|
| 167 |
+
for module, scales in zip(valid_lora_modules, original_scales):
|
| 168 |
+
for adapter in module.active_adapters:
|
| 169 |
+
if adapter in module.scaling:
|
| 170 |
+
module.scaling[adapter] = scales[adapter]
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def attn_forward(
|
| 174 |
+
attn: Attention,
|
| 175 |
+
hidden_states: List[torch.FloatTensor],
|
| 176 |
+
adapters: List[str],
|
| 177 |
+
hidden_states2: Optional[List[torch.FloatTensor]] = [],
|
| 178 |
+
position_embs: Optional[List[torch.Tensor]] = None,
|
| 179 |
+
group_mask: Optional[torch.Tensor] = None,
|
| 180 |
+
cache_mode: Optional[str] = None,
|
| 181 |
+
# to determine whether to cache the keys and values for this branch
|
| 182 |
+
to_cache: Optional[List[torch.Tensor]] = None,
|
| 183 |
+
cache_storage: Optional[List[torch.Tensor]] = None,
|
| 184 |
+
**kwargs: dict,
|
| 185 |
+
) -> torch.FloatTensor:
|
| 186 |
+
bs, _, _ = hidden_states[0].shape
|
| 187 |
+
h2_n = len(hidden_states2)
|
| 188 |
+
|
| 189 |
+
queries, keys, values = [], [], []
|
| 190 |
+
|
| 191 |
+
# Prepare query, key, value for each encoder hidden state (text branch)
|
| 192 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 193 |
+
query = attn.add_q_proj(hidden_state)
|
| 194 |
+
key = attn.add_k_proj(hidden_state)
|
| 195 |
+
value = attn.add_v_proj(hidden_state)
|
| 196 |
+
|
| 197 |
+
head_dim = key.shape[-1] // attn.heads
|
| 198 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 199 |
+
|
| 200 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 201 |
+
query, key = attn.norm_added_q(query), attn.norm_added_k(key)
|
| 202 |
+
|
| 203 |
+
queries.append(query)
|
| 204 |
+
keys.append(key)
|
| 205 |
+
values.append(value)
|
| 206 |
+
|
| 207 |
+
# Prepare query, key, value for each hidden state (image branch)
|
| 208 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 209 |
+
with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
|
| 210 |
+
query = attn.to_q(hidden_state)
|
| 211 |
+
key = attn.to_k(hidden_state)
|
| 212 |
+
value = attn.to_v(hidden_state)
|
| 213 |
+
|
| 214 |
+
head_dim = key.shape[-1] // attn.heads
|
| 215 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 216 |
+
|
| 217 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 218 |
+
query, key = attn.norm_q(query), attn.norm_k(key)
|
| 219 |
+
|
| 220 |
+
queries.append(query)
|
| 221 |
+
keys.append(key)
|
| 222 |
+
values.append(value)
|
| 223 |
+
|
| 224 |
+
# Apply rotary embedding
|
| 225 |
+
if position_embs is not None:
|
| 226 |
+
queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
|
| 227 |
+
keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
|
| 228 |
+
|
| 229 |
+
if cache_mode == "write":
|
| 230 |
+
for i, (k, v) in enumerate(zip(keys, values)):
|
| 231 |
+
if to_cache[i]:
|
| 232 |
+
cache_storage[attn.cache_idx][0].append(k)
|
| 233 |
+
cache_storage[attn.cache_idx][1].append(v)
|
| 234 |
+
|
| 235 |
+
attn_outputs = []
|
| 236 |
+
for i, query in enumerate(queries):
|
| 237 |
+
keys_, values_ = [], []
|
| 238 |
+
# Add keys and values from other branches
|
| 239 |
+
for j, (k, v) in enumerate(zip(keys, values)):
|
| 240 |
+
if (group_mask is not None) and not (group_mask[i][j].item()):
|
| 241 |
+
continue
|
| 242 |
+
keys_.append(k)
|
| 243 |
+
values_.append(v)
|
| 244 |
+
if cache_mode == "read":
|
| 245 |
+
keys_.extend(cache_storage[attn.cache_idx][0])
|
| 246 |
+
values_.extend(cache_storage[attn.cache_idx][1])
|
| 247 |
+
# Add keys and values from cache TODO
|
| 248 |
+
# Attention computation
|
| 249 |
+
attn_output = F.scaled_dot_product_attention(
|
| 250 |
+
query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
|
| 251 |
+
).to(query.dtype)
|
| 252 |
+
attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
|
| 253 |
+
attn_outputs.append(attn_output)
|
| 254 |
+
|
| 255 |
+
# Reshape attention output to match the original hidden states
|
| 256 |
+
h_out, h2_out = [], []
|
| 257 |
+
|
| 258 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 259 |
+
h2_out.append(attn.to_add_out(attn_outputs[i]))
|
| 260 |
+
|
| 261 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 262 |
+
h = attn_outputs[i + h2_n]
|
| 263 |
+
if getattr(attn, "to_out", None) is not None:
|
| 264 |
+
with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
|
| 265 |
+
h = attn.to_out[0](h)
|
| 266 |
+
h_out.append(h)
|
| 267 |
+
|
| 268 |
+
return (h_out, h2_out) if h2_n else h_out
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def block_forward(
|
| 272 |
+
self,
|
| 273 |
+
image_hidden_states: List[torch.FloatTensor],
|
| 274 |
+
text_hidden_states: List[torch.FloatTensor],
|
| 275 |
+
tembs: List[torch.FloatTensor],
|
| 276 |
+
adapters: List[str],
|
| 277 |
+
position_embs=None,
|
| 278 |
+
attn_forward=attn_forward,
|
| 279 |
+
**kwargs: dict,
|
| 280 |
+
):
|
| 281 |
+
txt_n = len(text_hidden_states)
|
| 282 |
+
|
| 283 |
+
img_variables, txt_variables = [], []
|
| 284 |
+
|
| 285 |
+
for i, text_h in enumerate(text_hidden_states):
|
| 286 |
+
txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
|
| 287 |
+
|
| 288 |
+
for i, image_h in enumerate(image_hidden_states):
|
| 289 |
+
with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
|
| 290 |
+
img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
|
| 291 |
+
|
| 292 |
+
# Attention.
|
| 293 |
+
img_attn_output, txt_attn_output = attn_forward(
|
| 294 |
+
self.attn,
|
| 295 |
+
hidden_states=[each[0] for each in img_variables],
|
| 296 |
+
hidden_states2=[each[0] for each in txt_variables],
|
| 297 |
+
position_embs=position_embs,
|
| 298 |
+
adapters=adapters,
|
| 299 |
+
**kwargs,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
text_out = []
|
| 303 |
+
for i in range(len(text_hidden_states)):
|
| 304 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
|
| 305 |
+
text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
|
| 306 |
+
norm_h = (
|
| 307 |
+
self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 308 |
+
)
|
| 309 |
+
text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
|
| 310 |
+
text_out.append(clip_hidden_states(text_h))
|
| 311 |
+
|
| 312 |
+
image_out = []
|
| 313 |
+
for i in range(len(image_hidden_states)):
|
| 314 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
|
| 315 |
+
image_h = (
|
| 316 |
+
image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
|
| 317 |
+
).to(image_hidden_states[i].dtype)
|
| 318 |
+
norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 319 |
+
with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
|
| 320 |
+
image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
|
| 321 |
+
image_out.append(clip_hidden_states(image_h))
|
| 322 |
+
return image_out, text_out
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def single_block_forward(
|
| 326 |
+
self,
|
| 327 |
+
hidden_states: List[torch.FloatTensor],
|
| 328 |
+
tembs: List[torch.FloatTensor],
|
| 329 |
+
adapters: List[str],
|
| 330 |
+
position_embs=None,
|
| 331 |
+
attn_forward=attn_forward,
|
| 332 |
+
**kwargs: dict,
|
| 333 |
+
):
|
| 334 |
+
mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
|
| 335 |
+
|
| 336 |
+
hidden_state_norm = []
|
| 337 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 338 |
+
# [NOTE]!: This function's output is slightly DIFFERENT from the original
|
| 339 |
+
# FLUX version. In the original implementation, the gates were computed using
|
| 340 |
+
# the combined hidden states from both the image and text branches. Here, each
|
| 341 |
+
# branch computes its gate using only its own hidden state.
|
| 342 |
+
with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
|
| 343 |
+
h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
|
| 344 |
+
mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
|
| 345 |
+
hidden_state_norm.append(h_norm)
|
| 346 |
+
|
| 347 |
+
attn_outputs = attn_forward(
|
| 348 |
+
self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
h_out = []
|
| 352 |
+
for i in range(len(hidden_states)):
|
| 353 |
+
with specify_lora((self.proj_out,), adapters[i]):
|
| 354 |
+
h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
|
| 355 |
+
h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
|
| 356 |
+
h_out.append(clip_hidden_states(h))
|
| 357 |
+
|
| 358 |
+
return h_out
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def transformer_forward(
|
| 362 |
+
transformer: FluxTransformer2DModel,
|
| 363 |
+
image_features: List[torch.Tensor],
|
| 364 |
+
text_features: List[torch.Tensor] = None,
|
| 365 |
+
img_ids: List[torch.Tensor] = None,
|
| 366 |
+
txt_ids: List[torch.Tensor] = None,
|
| 367 |
+
pooled_projections: List[torch.Tensor] = None,
|
| 368 |
+
timesteps: List[torch.LongTensor] = None,
|
| 369 |
+
guidances: List[torch.Tensor] = None,
|
| 370 |
+
adapters: List[str] = None,
|
| 371 |
+
# Assign the function to be used for the forward pass
|
| 372 |
+
single_block_forward=single_block_forward,
|
| 373 |
+
block_forward=block_forward,
|
| 374 |
+
attn_forward=attn_forward,
|
| 375 |
+
**kwargs: dict,
|
| 376 |
+
):
|
| 377 |
+
self = transformer
|
| 378 |
+
txt_n = len(text_features) if text_features is not None else 0
|
| 379 |
+
|
| 380 |
+
adapters = adapters or [None] * (txt_n + len(image_features))
|
| 381 |
+
assert len(adapters) == len(timesteps)
|
| 382 |
+
|
| 383 |
+
# Preprocess the image_features
|
| 384 |
+
image_hidden_states = []
|
| 385 |
+
for i, image_feature in enumerate(image_features):
|
| 386 |
+
with specify_lora((self.x_embedder,), adapters[i + txt_n]):
|
| 387 |
+
image_hidden_states.append(self.x_embedder(image_feature))
|
| 388 |
+
|
| 389 |
+
# Preprocess the text_features
|
| 390 |
+
text_hidden_states = []
|
| 391 |
+
for text_feature in text_features:
|
| 392 |
+
text_hidden_states.append(self.context_embedder(text_feature))
|
| 393 |
+
|
| 394 |
+
# Prepare embeddings of (timestep, guidance, pooled_projections)
|
| 395 |
+
assert len(timesteps) == len(image_features) + len(text_features)
|
| 396 |
+
|
| 397 |
+
def get_temb(timestep, guidance, pooled_projection):
|
| 398 |
+
timestep = timestep.to(image_hidden_states[0].dtype) * 1000
|
| 399 |
+
if guidance is not None:
|
| 400 |
+
guidance = guidance.to(image_hidden_states[0].dtype) * 1000
|
| 401 |
+
return self.time_text_embed(timestep, guidance, pooled_projection)
|
| 402 |
+
else:
|
| 403 |
+
return self.time_text_embed(timestep, pooled_projection)
|
| 404 |
+
|
| 405 |
+
tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
|
| 406 |
+
|
| 407 |
+
# Prepare position embeddings for each token
|
| 408 |
+
position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
|
| 409 |
+
|
| 410 |
+
# Prepare the gradient checkpointing kwargs
|
| 411 |
+
gckpt_kwargs: Dict[str, Any] = (
|
| 412 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# dual branch blocks
|
| 416 |
+
for block in self.transformer_blocks:
|
| 417 |
+
block_kwargs = {
|
| 418 |
+
"self": block,
|
| 419 |
+
"image_hidden_states": image_hidden_states,
|
| 420 |
+
"text_hidden_states": text_hidden_states,
|
| 421 |
+
"tembs": tembs,
|
| 422 |
+
"position_embs": position_embs,
|
| 423 |
+
"adapters": adapters,
|
| 424 |
+
"attn_forward": attn_forward,
|
| 425 |
+
**kwargs,
|
| 426 |
+
}
|
| 427 |
+
if self.training and self.gradient_checkpointing:
|
| 428 |
+
image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 429 |
+
block_forward, **block_kwargs, **gckpt_kwargs
|
| 430 |
+
)
|
| 431 |
+
else:
|
| 432 |
+
image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
|
| 433 |
+
|
| 434 |
+
# combine image and text hidden states then pass through the single transformer blocks
|
| 435 |
+
all_hidden_states = [*text_hidden_states, *image_hidden_states]
|
| 436 |
+
for block in self.single_transformer_blocks:
|
| 437 |
+
block_kwargs = {
|
| 438 |
+
"self": block,
|
| 439 |
+
"hidden_states": all_hidden_states,
|
| 440 |
+
"tembs": tembs,
|
| 441 |
+
"position_embs": position_embs,
|
| 442 |
+
"adapters": adapters,
|
| 443 |
+
"attn_forward": attn_forward,
|
| 444 |
+
**kwargs,
|
| 445 |
+
}
|
| 446 |
+
if self.training and self.gradient_checkpointing:
|
| 447 |
+
all_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 448 |
+
single_block_forward, **block_kwargs, **gckpt_kwargs
|
| 449 |
+
)
|
| 450 |
+
else:
|
| 451 |
+
all_hidden_states = single_block_forward(**block_kwargs)
|
| 452 |
+
|
| 453 |
+
image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
|
| 454 |
+
output = self.proj_out(image_hidden_states)
|
| 455 |
+
|
| 456 |
+
return (output,)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
@torch.no_grad()
|
| 460 |
+
def generate(
|
| 461 |
+
pipeline: FluxPipeline,
|
| 462 |
+
prompt: Union[str, List[str]] = None,
|
| 463 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 464 |
+
height: Optional[int] = 512,
|
| 465 |
+
width: Optional[int] = 512,
|
| 466 |
+
num_inference_steps: int = 28,
|
| 467 |
+
timesteps: List[int] = None,
|
| 468 |
+
guidance_scale: float = 3.5,
|
| 469 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 470 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 471 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 472 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 473 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 474 |
+
output_type: Optional[str] = "pil",
|
| 475 |
+
return_dict: bool = True,
|
| 476 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 477 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 478 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 479 |
+
max_sequence_length: int = 512,
|
| 480 |
+
# Condition Parameters (Optional)
|
| 481 |
+
main_adapter: Optional[List[str]] = None,
|
| 482 |
+
conditions: List[Condition] = [],
|
| 483 |
+
image_guidance_scale: float = 1.0,
|
| 484 |
+
transformer_kwargs: Optional[Dict[str, Any]] = {},
|
| 485 |
+
kv_cache=False,
|
| 486 |
+
latent_mask=None,
|
| 487 |
+
**params: dict,
|
| 488 |
+
):
|
| 489 |
+
self = pipeline
|
| 490 |
+
|
| 491 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 492 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 493 |
+
|
| 494 |
+
# Check inputs. Raise error if not correct
|
| 495 |
+
self.check_inputs(
|
| 496 |
+
prompt,
|
| 497 |
+
prompt_2,
|
| 498 |
+
height,
|
| 499 |
+
width,
|
| 500 |
+
prompt_embeds=prompt_embeds,
|
| 501 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 502 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 503 |
+
max_sequence_length=max_sequence_length,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
self._guidance_scale = guidance_scale
|
| 507 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 508 |
+
|
| 509 |
+
# Define call parameters
|
| 510 |
+
if prompt is not None and isinstance(prompt, str):
|
| 511 |
+
batch_size = 1
|
| 512 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 513 |
+
batch_size = len(prompt)
|
| 514 |
+
else:
|
| 515 |
+
batch_size = prompt_embeds.shape[0]
|
| 516 |
+
|
| 517 |
+
device = self._execution_device
|
| 518 |
+
|
| 519 |
+
# Prepare prompt embeddings
|
| 520 |
+
(
|
| 521 |
+
prompt_embeds,
|
| 522 |
+
pooled_prompt_embeds,
|
| 523 |
+
text_ids,
|
| 524 |
+
) = self.encode_prompt(
|
| 525 |
+
prompt=prompt,
|
| 526 |
+
prompt_2=prompt_2,
|
| 527 |
+
prompt_embeds=prompt_embeds,
|
| 528 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 529 |
+
device=device,
|
| 530 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 531 |
+
max_sequence_length=max_sequence_length,
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Prepare latent variables
|
| 535 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 536 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 537 |
+
batch_size * num_images_per_prompt,
|
| 538 |
+
num_channels_latents,
|
| 539 |
+
height,
|
| 540 |
+
width,
|
| 541 |
+
prompt_embeds.dtype,
|
| 542 |
+
device,
|
| 543 |
+
generator,
|
| 544 |
+
latents,
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
if latent_mask is not None:
|
| 548 |
+
latent_mask = latent_mask.T.reshape(-1)
|
| 549 |
+
latents = latents[:, latent_mask]
|
| 550 |
+
latent_image_ids = latent_image_ids[latent_mask]
|
| 551 |
+
|
| 552 |
+
# Prepare conditions
|
| 553 |
+
c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
|
| 554 |
+
c_projections, c_guidances, c_adapters = ([], [], [])
|
| 555 |
+
complement_cond = None
|
| 556 |
+
for condition in conditions:
|
| 557 |
+
tokens, ids = condition.encode(self)
|
| 558 |
+
c_latents.append(tokens) # [batch_size, token_n, token_dim]
|
| 559 |
+
# Empty condition for unconditioned image
|
| 560 |
+
if image_guidance_scale != 1.0:
|
| 561 |
+
uc_latents.append(condition.encode(self, empty=True)[0])
|
| 562 |
+
c_ids.append(ids) # [token_n, id_dim(3)]
|
| 563 |
+
c_timesteps.append(torch.zeros([1], device=device))
|
| 564 |
+
c_projections.append(pooled_prompt_embeds)
|
| 565 |
+
c_guidances.append(torch.ones([1], device=device))
|
| 566 |
+
c_adapters.append(condition.adapter)
|
| 567 |
+
# This complement_condition will be combined with the original image.
|
| 568 |
+
# See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
|
| 569 |
+
if condition.is_complement:
|
| 570 |
+
complement_cond = (tokens, ids)
|
| 571 |
+
|
| 572 |
+
# Prepare timesteps
|
| 573 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 574 |
+
image_seq_len = latents.shape[1]
|
| 575 |
+
mu = calculate_shift(
|
| 576 |
+
image_seq_len,
|
| 577 |
+
self.scheduler.config.base_image_seq_len,
|
| 578 |
+
self.scheduler.config.max_image_seq_len,
|
| 579 |
+
self.scheduler.config.base_shift,
|
| 580 |
+
self.scheduler.config.max_shift,
|
| 581 |
+
)
|
| 582 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 583 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
| 584 |
+
)
|
| 585 |
+
num_warmup_steps = max(
|
| 586 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
| 587 |
+
)
|
| 588 |
+
self._num_timesteps = len(timesteps)
|
| 589 |
+
|
| 590 |
+
if kv_cache:
|
| 591 |
+
attn_counter = 0
|
| 592 |
+
for module in self.transformer.modules():
|
| 593 |
+
if isinstance(module, Attention):
|
| 594 |
+
setattr(module, "cache_idx", attn_counter)
|
| 595 |
+
attn_counter += 1
|
| 596 |
+
kv_cond = [[[], []] for _ in range(attn_counter)]
|
| 597 |
+
kv_uncond = [[[], []] for _ in range(attn_counter)]
|
| 598 |
+
|
| 599 |
+
def clear_cache():
|
| 600 |
+
for storage in [kv_cond, kv_uncond]:
|
| 601 |
+
for keys, values in storage:
|
| 602 |
+
keys.clear()
|
| 603 |
+
values.clear()
|
| 604 |
+
|
| 605 |
+
branch_n = len(conditions) + 2
|
| 606 |
+
group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
|
| 607 |
+
# Disable attention across different condition branches
|
| 608 |
+
group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
|
| 609 |
+
# Disable attention from the condition branches to the image and text branches
|
| 610 |
+
if kv_cache:
|
| 611 |
+
group_mask[2:, :2] = False
|
| 612 |
+
|
| 613 |
+
# Denoising loop
|
| 614 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 615 |
+
for i, t in enumerate(timesteps):
|
| 616 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 617 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
|
| 618 |
+
|
| 619 |
+
# handle guidance
|
| 620 |
+
if self.transformer.config.guidance_embeds:
|
| 621 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
| 622 |
+
guidance = guidance.expand(latents.shape[0])
|
| 623 |
+
else:
|
| 624 |
+
guidance, c_guidances = None, [None for _ in c_guidances]
|
| 625 |
+
|
| 626 |
+
if kv_cache:
|
| 627 |
+
mode = "write" if i == 0 else "read"
|
| 628 |
+
if mode == "write":
|
| 629 |
+
clear_cache()
|
| 630 |
+
use_cond = not (kv_cache) or mode == "write"
|
| 631 |
+
|
| 632 |
+
noise_pred = transformer_forward(
|
| 633 |
+
self.transformer,
|
| 634 |
+
image_features=[latents] + (c_latents if use_cond else []),
|
| 635 |
+
text_features=[prompt_embeds],
|
| 636 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 637 |
+
txt_ids=[text_ids],
|
| 638 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 639 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 640 |
+
+ (c_projections if use_cond else []),
|
| 641 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 642 |
+
return_dict=False,
|
| 643 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 644 |
+
cache_mode=mode if kv_cache else None,
|
| 645 |
+
cache_storage=kv_cond if kv_cache else None,
|
| 646 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 647 |
+
group_mask=group_mask,
|
| 648 |
+
**transformer_kwargs,
|
| 649 |
+
)[0]
|
| 650 |
+
|
| 651 |
+
if image_guidance_scale != 1.0:
|
| 652 |
+
unc_pred = transformer_forward(
|
| 653 |
+
self.transformer,
|
| 654 |
+
image_features=[latents] + (uc_latents if use_cond else []),
|
| 655 |
+
text_features=[prompt_embeds],
|
| 656 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 657 |
+
txt_ids=[text_ids],
|
| 658 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 659 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 660 |
+
+ (c_projections if use_cond else []),
|
| 661 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 662 |
+
return_dict=False,
|
| 663 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 664 |
+
cache_mode=mode if kv_cache else None,
|
| 665 |
+
cache_storage=kv_uncond if kv_cache else None,
|
| 666 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 667 |
+
**transformer_kwargs,
|
| 668 |
+
)[0]
|
| 669 |
+
|
| 670 |
+
noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
|
| 671 |
+
|
| 672 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 673 |
+
latents_dtype = latents.dtype
|
| 674 |
+
latents = self.scheduler.step(noise_pred, t, latents)[0]
|
| 675 |
+
|
| 676 |
+
if latents.dtype != latents_dtype:
|
| 677 |
+
if torch.backends.mps.is_available():
|
| 678 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 679 |
+
latents = latents.to(latents_dtype)
|
| 680 |
+
|
| 681 |
+
if callback_on_step_end is not None:
|
| 682 |
+
callback_kwargs = {}
|
| 683 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 684 |
+
callback_kwargs[k] = locals()[k]
|
| 685 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 686 |
+
|
| 687 |
+
latents = callback_outputs.pop("latents", latents)
|
| 688 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 689 |
+
|
| 690 |
+
# call the callback, if provided
|
| 691 |
+
if i == len(timesteps) - 1 or (
|
| 692 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 693 |
+
):
|
| 694 |
+
progress_bar.update()
|
| 695 |
+
|
| 696 |
+
if latent_mask is not None:
|
| 697 |
+
# Combine the generated latents and the complement condition
|
| 698 |
+
assert complement_cond is not None
|
| 699 |
+
comp_latent, comp_ids = complement_cond
|
| 700 |
+
all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
|
| 701 |
+
shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
|
| 702 |
+
H, W = shape[1].item(), shape[2].item()
|
| 703 |
+
B, _, C = latents.shape
|
| 704 |
+
# Create an empty canvas
|
| 705 |
+
canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
|
| 706 |
+
|
| 707 |
+
# Stash the latents and the complement condition
|
| 708 |
+
def _stash(canvas, tokens, ids, H, W) -> None:
|
| 709 |
+
B, T, C = tokens.shape
|
| 710 |
+
ids = ids.to(torch.long)
|
| 711 |
+
flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
|
| 712 |
+
canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
|
| 713 |
+
|
| 714 |
+
_stash(canvas, latents, latent_image_ids, H, W)
|
| 715 |
+
_stash(canvas, comp_latent, comp_ids, H, W)
|
| 716 |
+
latents = canvas.view(B, H * W, C)
|
| 717 |
+
|
| 718 |
+
if output_type == "latent":
|
| 719 |
+
image = latents
|
| 720 |
+
else:
|
| 721 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 722 |
+
latents = (
|
| 723 |
+
latents / self.vae.config.scaling_factor
|
| 724 |
+
) + self.vae.config.shift_factor
|
| 725 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 726 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 727 |
+
|
| 728 |
+
# Offload all models
|
| 729 |
+
self.maybe_free_model_hooks()
|
| 730 |
+
|
| 731 |
+
if not return_dict:
|
| 732 |
+
return (image,)
|
| 733 |
+
|
| 734 |
+
return FluxPipelineOutput(images=image)
|
omini/pipeline/flux_omini_ablate_qkv.py
ADDED
|
@@ -0,0 +1,772 @@
|
| 1 |
+
"""
|
| 2 |
+
This version is for an ablation study of the rotation (T) applied to the Q/K/V projections of the LoRA adapters.
|
| 3 |
+
|
| 4 |
+
The `generate` function is modified to include `global_T_Q`, `global_T_K`, and `global_T_V` parameters;
|
| 5 |
+
the global `T_Q`, `T_K`, and `T_V` variables are set at the start of generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
|
| 10 |
+
|
| 11 |
+
from diffusers.pipelines import FluxPipeline
|
| 12 |
+
from diffusers.pipelines.flux.pipeline_flux import (
|
| 13 |
+
FluxPipelineOutput,
|
| 14 |
+
FluxTransformer2DModel,
|
| 15 |
+
calculate_shift,
|
| 16 |
+
retrieve_timesteps,
|
| 17 |
+
np,
|
| 18 |
+
)
|
| 19 |
+
from diffusers.models.attention_processor import Attention, F
|
| 20 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 21 |
+
from transformers import pipeline
|
| 22 |
+
|
| 23 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
| 24 |
+
from accelerate.utils import is_torch_version
|
| 25 |
+
|
| 26 |
+
from contextlib import contextmanager
|
| 27 |
+
|
| 28 |
+
import cv2
|
| 29 |
+
|
| 30 |
+
from PIL import Image, ImageFilter
|
| 31 |
+
|
| 32 |
+
T_Q = None
|
| 33 |
+
T_K = None
|
| 34 |
+
T_V = None
|
| 35 |
+
|
| 36 |
+
def seed_everything(seed: int = 42):
|
| 37 |
+
torch.backends.cudnn.deterministic = True
|
| 38 |
+
torch.manual_seed(seed)
|
| 39 |
+
np.random.seed(seed)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 43 |
+
if hidden_states.dtype == torch.float16:
|
| 44 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 45 |
+
return hidden_states
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
|
| 49 |
+
"""
|
| 50 |
+
Encodes the images into tokens and ids for FLUX pipeline.
|
| 51 |
+
"""
|
| 52 |
+
images = pipeline.image_processor.preprocess(images)
|
| 53 |
+
images = images.to(pipeline.device).to(pipeline.dtype)
|
| 54 |
+
images = pipeline.vae.encode(images).latent_dist.sample()
|
| 55 |
+
images = (
|
| 56 |
+
images - pipeline.vae.config.shift_factor
|
| 57 |
+
) * pipeline.vae.config.scaling_factor
|
| 58 |
+
images_tokens = pipeline._pack_latents(images, *images.shape)
|
| 59 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 60 |
+
images.shape[0],
|
| 61 |
+
images.shape[2],
|
| 62 |
+
images.shape[3],
|
| 63 |
+
pipeline.device,
|
| 64 |
+
pipeline.dtype,
|
| 65 |
+
)
|
| 66 |
+
if images_tokens.shape[1] != images_ids.shape[0]:
|
| 67 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 68 |
+
images.shape[0],
|
| 69 |
+
images.shape[2] // 2,
|
| 70 |
+
images.shape[3] // 2,
|
| 71 |
+
pipeline.device,
|
| 72 |
+
pipeline.dtype,
|
| 73 |
+
)
|
| 74 |
+
return images_tokens, images_ids
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
depth_pipe = None
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def convert_to_condition(
|
| 81 |
+
condition_type: str,
|
| 82 |
+
raw_img: Union[Image.Image, torch.Tensor],
|
| 83 |
+
blur_radius: Optional[int] = 5,
|
| 84 |
+
) -> Union[Image.Image, torch.Tensor]:
|
| 85 |
+
if condition_type == "depth":
|
| 86 |
+
global depth_pipe
|
| 87 |
+
depth_pipe = depth_pipe or pipeline(
|
| 88 |
+
task="depth-estimation",
|
| 89 |
+
model="LiheYoung/depth-anything-small-hf",
|
| 90 |
+
device="cpu", # Use "cpu" to enable parallel processing
|
| 91 |
+
)
|
| 92 |
+
source_image = raw_img.convert("RGB")
|
| 93 |
+
condition_img = depth_pipe(source_image)["depth"].convert("RGB")
|
| 94 |
+
return condition_img
|
| 95 |
+
elif condition_type == "canny":
|
| 96 |
+
img = np.array(raw_img)
|
| 97 |
+
edges = cv2.Canny(img, 100, 200)
|
| 98 |
+
edges = Image.fromarray(edges).convert("RGB")
|
| 99 |
+
return edges
|
| 100 |
+
elif condition_type == "coloring":
|
| 101 |
+
return raw_img.convert("L").convert("RGB")
|
| 102 |
+
elif condition_type == "deblurring":
|
| 103 |
+
condition_image = (
|
| 104 |
+
raw_img.convert("RGB")
|
| 105 |
+
.filter(ImageFilter.GaussianBlur(blur_radius))
|
| 106 |
+
.convert("RGB")
|
| 107 |
+
)
|
| 108 |
+
return condition_image
|
| 109 |
+
else:
|
| 110 |
+
print("Warning: Returning the raw image.")
|
| 111 |
+
return raw_img.convert("RGB")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class Condition(object):
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
condition: Union[Image.Image, torch.Tensor],
|
| 118 |
+
adapter_setting: Union[str, dict],
|
| 119 |
+
position_delta=None,
|
| 120 |
+
position_scale=1.0,
|
| 121 |
+
latent_mask=None,
|
| 122 |
+
is_complement=False,
|
| 123 |
+
) -> None:
|
| 124 |
+
self.condition = condition
|
| 125 |
+
self.adapter = adapter_setting
|
| 126 |
+
self.position_delta = position_delta
|
| 127 |
+
self.position_scale = position_scale
|
| 128 |
+
self.latent_mask = (
|
| 129 |
+
latent_mask.T.reshape(-1) if latent_mask is not None else None
|
| 130 |
+
)
|
| 131 |
+
self.is_complement = is_complement
|
| 132 |
+
|
| 133 |
+
def encode(
|
| 134 |
+
self, pipe: FluxPipeline, empty: bool = False
|
| 135 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 136 |
+
condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
|
| 137 |
+
tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
|
| 138 |
+
|
| 139 |
+
if self.position_delta is not None:
|
| 140 |
+
ids[:, 1] += self.position_delta[0]
|
| 141 |
+
ids[:, 2] += self.position_delta[1]
|
| 142 |
+
|
| 143 |
+
if self.position_scale != 1.0:
|
| 144 |
+
scale_bias = (self.position_scale - 1.0) / 2
|
| 145 |
+
ids[:, 1:] *= self.position_scale
|
| 146 |
+
ids[:, 1:] += scale_bias
|
| 147 |
+
|
| 148 |
+
if self.latent_mask is not None:
|
| 149 |
+
tokens = tokens[:, self.latent_mask]
|
| 150 |
+
ids = ids[self.latent_mask]
|
| 151 |
+
|
| 152 |
+
return tokens, ids
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@contextmanager
|
| 156 |
+
def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora, T=None):
|
| 157 |
+
# Filter valid lora modules
|
| 158 |
+
valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
|
| 159 |
+
# Save original scales
|
| 160 |
+
original_scales = [
|
| 161 |
+
{
|
| 162 |
+
adapter: module.scaling[adapter]
|
| 163 |
+
for adapter in module.active_adapters
|
| 164 |
+
if adapter in module.scaling
|
| 165 |
+
}
|
| 166 |
+
for module in valid_lora_modules
|
| 167 |
+
]
|
| 168 |
+
# Enter context: adjust scaling
|
| 169 |
+
for module in valid_lora_modules:
|
| 170 |
+
for adapter in module.active_adapters:
|
| 171 |
+
if adapter in module.scaling:
|
| 172 |
+
module.scaling[adapter] = 1. if adapter == specified_lora else 0
|
| 173 |
+
|
| 174 |
+
if hasattr(module, 'rotation') and T is not None:
|
| 175 |
+
# alter T if specified
|
| 176 |
+
if adapter in module.rotation:
|
| 177 |
+
# print("FOR DEBUG:entering specify_lora context: setting T")
|
| 178 |
+
module.rotation[adapter].T = T
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
yield
|
| 182 |
+
finally:
|
| 183 |
+
# Exit context: restore original scales
|
| 184 |
+
for module, scales in zip(valid_lora_modules, original_scales):
|
| 185 |
+
for adapter in module.active_adapters:
|
| 186 |
+
if adapter in module.scaling:
|
| 187 |
+
module.scaling[adapter] = scales[adapter]
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def attn_forward(
|
| 191 |
+
attn: Attention,
|
| 192 |
+
hidden_states: List[torch.FloatTensor],
|
| 193 |
+
adapters: List[str],
|
| 194 |
+
hidden_states2: Optional[List[torch.FloatTensor]] = [],
|
| 195 |
+
position_embs: Optional[List[torch.Tensor]] = None,
|
| 196 |
+
group_mask: Optional[torch.Tensor] = None,
|
| 197 |
+
cache_mode: Optional[str] = None,
|
| 198 |
+
# to determine whether to cache the keys and values for this branch
|
| 199 |
+
to_cache: Optional[List[torch.Tensor]] = None,
|
| 200 |
+
cache_storage: Optional[List[torch.Tensor]] = None,
|
| 201 |
+
**kwargs: dict,
|
| 202 |
+
) -> torch.FloatTensor:
|
| 203 |
+
bs, _, _ = hidden_states[0].shape
|
| 204 |
+
h2_n = len(hidden_states2)
|
| 205 |
+
|
| 206 |
+
queries, keys, values = [], [], []
|
| 207 |
+
|
| 208 |
+
# Prepare query, key, value for each encoder hidden state (text branch)
|
| 209 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 210 |
+
query = attn.add_q_proj(hidden_state)
|
| 211 |
+
key = attn.add_k_proj(hidden_state)
|
| 212 |
+
value = attn.add_v_proj(hidden_state)
|
| 213 |
+
|
| 214 |
+
head_dim = key.shape[-1] // attn.heads
|
| 215 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 216 |
+
|
| 217 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 218 |
+
query, key = attn.norm_added_q(query), attn.norm_added_k(key)
|
| 219 |
+
|
| 220 |
+
queries.append(query)
|
| 221 |
+
keys.append(key)
|
| 222 |
+
values.append(value)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
## THIS IS THE MODIFIED PART TO ABLATE THE QKV ROTATION T ##
|
| 226 |
+
# Prepare query, key, value for each hidden state (image branch)
|
| 227 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 228 |
+
with specify_lora((attn.to_q,), adapters[i + h2_n], T=T_Q):
|
| 229 |
+
query = attn.to_q(hidden_state)
|
| 230 |
+
|
| 231 |
+
with specify_lora((attn.to_k,), adapters[i + h2_n], T=T_K):
|
| 232 |
+
key = attn.to_k(hidden_state)
|
| 233 |
+
|
| 234 |
+
with specify_lora((attn.to_v,), adapters[i + h2_n], T=T_V):
|
| 235 |
+
value = attn.to_v(hidden_state)
|
| 236 |
+
|
| 237 |
+
head_dim = key.shape[-1] // attn.heads
|
| 238 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 239 |
+
|
| 240 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 241 |
+
query, key = attn.norm_q(query), attn.norm_k(key)
|
| 242 |
+
|
| 243 |
+
queries.append(query)
|
| 244 |
+
keys.append(key)
|
| 245 |
+
values.append(value)
|
| 246 |
+
|
| 247 |
+
# Apply rotary embedding
|
| 248 |
+
if position_embs is not None:
|
| 249 |
+
queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
|
| 250 |
+
keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
|
| 251 |
+
|
| 252 |
+
if cache_mode == "write":
|
| 253 |
+
for i, (k, v) in enumerate(zip(keys, values)):
|
| 254 |
+
if to_cache[i]:
|
| 255 |
+
cache_storage[attn.cache_idx][0].append(k)
|
| 256 |
+
cache_storage[attn.cache_idx][1].append(v)
|
| 257 |
+
|
| 258 |
+
attn_outputs = []
|
| 259 |
+
for i, query in enumerate(queries):
|
| 260 |
+
keys_, values_ = [], []
|
| 261 |
+
# Add keys and values from other branches
|
| 262 |
+
for j, (k, v) in enumerate(zip(keys, values)):
|
| 263 |
+
if (group_mask is not None) and not (group_mask[i][j].item()):
|
| 264 |
+
continue
|
| 265 |
+
keys_.append(k)
|
| 266 |
+
values_.append(v)
|
| 267 |
+
if cache_mode == "read":
|
| 268 |
+
keys_.extend(cache_storage[attn.cache_idx][0])
|
| 269 |
+
values_.extend(cache_storage[attn.cache_idx][1])
|
| 270 |
+
# Add keys and values from cache TODO
|
| 271 |
+
# Attention computation
|
| 272 |
+
attn_output = F.scaled_dot_product_attention(
|
| 273 |
+
query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
|
| 274 |
+
).to(query.dtype)
|
| 275 |
+
attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
|
| 276 |
+
attn_outputs.append(attn_output)
|
| 277 |
+
|
| 278 |
+
# Reshape attention output to match the original hidden states
|
| 279 |
+
h_out, h2_out = [], []
|
| 280 |
+
|
| 281 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 282 |
+
h2_out.append(attn.to_add_out(attn_outputs[i]))
|
| 283 |
+
|
| 284 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 285 |
+
h = attn_outputs[i + h2_n]
|
| 286 |
+
if getattr(attn, "to_out", None) is not None:
|
| 287 |
+
with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
|
| 288 |
+
h = attn.to_out[0](h)
|
| 289 |
+
h_out.append(h)
|
| 290 |
+
|
| 291 |
+
return (h_out, h2_out) if h2_n else h_out
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def block_forward(
|
| 295 |
+
self,
|
| 296 |
+
image_hidden_states: List[torch.FloatTensor],
|
| 297 |
+
text_hidden_states: List[torch.FloatTensor],
|
| 298 |
+
tembs: List[torch.FloatTensor],
|
| 299 |
+
adapters: List[str],
|
| 300 |
+
position_embs=None,
|
| 301 |
+
attn_forward=attn_forward,
|
| 302 |
+
**kwargs: dict,
|
| 303 |
+
):
|
| 304 |
+
txt_n = len(text_hidden_states)
|
| 305 |
+
|
| 306 |
+
img_variables, txt_variables = [], []
|
| 307 |
+
|
| 308 |
+
for i, text_h in enumerate(text_hidden_states):
|
| 309 |
+
txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
|
| 310 |
+
|
| 311 |
+
for i, image_h in enumerate(image_hidden_states):
|
| 312 |
+
with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
|
| 313 |
+
img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
|
| 314 |
+
|
| 315 |
+
# Attention.
|
| 316 |
+
img_attn_output, txt_attn_output = attn_forward(
|
| 317 |
+
self.attn,
|
| 318 |
+
hidden_states=[each[0] for each in img_variables],
|
| 319 |
+
hidden_states2=[each[0] for each in txt_variables],
|
| 320 |
+
position_embs=position_embs,
|
| 321 |
+
adapters=adapters,
|
| 322 |
+
**kwargs,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
text_out = []
|
| 326 |
+
for i in range(len(text_hidden_states)):
|
| 327 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
|
| 328 |
+
text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
|
| 329 |
+
norm_h = (
|
| 330 |
+
self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 331 |
+
)
|
| 332 |
+
text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
|
| 333 |
+
text_out.append(clip_hidden_states(text_h))
|
| 334 |
+
|
| 335 |
+
image_out = []
|
| 336 |
+
for i in range(len(image_hidden_states)):
|
| 337 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
|
| 338 |
+
image_h = (
|
| 339 |
+
image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
|
| 340 |
+
).to(image_hidden_states[i].dtype)
|
| 341 |
+
norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 342 |
+
with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
|
| 343 |
+
image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
|
| 344 |
+
image_out.append(clip_hidden_states(image_h))
|
| 345 |
+
return image_out, text_out
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def single_block_forward(
|
| 349 |
+
self,
|
| 350 |
+
hidden_states: List[torch.FloatTensor],
|
| 351 |
+
tembs: List[torch.FloatTensor],
|
| 352 |
+
adapters: List[str],
|
| 353 |
+
position_embs=None,
|
| 354 |
+
attn_forward=attn_forward,
|
| 355 |
+
**kwargs: dict,
|
| 356 |
+
):
|
| 357 |
+
mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
|
| 358 |
+
|
| 359 |
+
hidden_state_norm = []
|
| 360 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 361 |
+
# [NOTE]!: This function's output is slightly DIFFERENT from the original
|
| 362 |
+
# FLUX version. In the original implementation, the gates were computed using
|
| 363 |
+
# the combined hidden states from both the image and text branches. Here, each
|
| 364 |
+
# branch computes its gate using only its own hidden state.
|
| 365 |
+
with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
|
| 366 |
+
h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
|
| 367 |
+
mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
|
| 368 |
+
hidden_state_norm.append(h_norm)
|
| 369 |
+
|
| 370 |
+
attn_outputs = attn_forward(
|
| 371 |
+
self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
h_out = []
|
| 375 |
+
for i in range(len(hidden_states)):
|
| 376 |
+
with specify_lora((self.proj_out,), adapters[i]):
|
| 377 |
+
h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
|
| 378 |
+
h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
|
| 379 |
+
h_out.append(clip_hidden_states(h))
|
| 380 |
+
|
| 381 |
+
return h_out
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def transformer_forward(
|
| 385 |
+
transformer: FluxTransformer2DModel,
|
| 386 |
+
image_features: List[torch.Tensor],
|
| 387 |
+
text_features: List[torch.Tensor] = None,
|
| 388 |
+
img_ids: List[torch.Tensor] = None,
|
| 389 |
+
txt_ids: List[torch.Tensor] = None,
|
| 390 |
+
pooled_projections: List[torch.Tensor] = None,
|
| 391 |
+
timesteps: List[torch.LongTensor] = None,
|
| 392 |
+
guidances: List[torch.Tensor] = None,
|
| 393 |
+
adapters: List[str] = None,
|
| 394 |
+
# Assign the function to be used for the forward pass
|
| 395 |
+
single_block_forward=single_block_forward,
|
| 396 |
+
block_forward=block_forward,
|
| 397 |
+
attn_forward=attn_forward,
|
| 398 |
+
**kwargs: dict,
|
| 399 |
+
):
|
| 400 |
+
self = transformer
|
| 401 |
+
txt_n = len(text_features) if text_features is not None else 0
|
| 402 |
+
|
| 403 |
+
adapters = adapters or [None] * (txt_n + len(image_features))
|
| 404 |
+
assert len(adapters) == len(timesteps)
|
| 405 |
+
|
| 406 |
+
# Preprocess the image_features
|
| 407 |
+
image_hidden_states = []
|
| 408 |
+
for i, image_feature in enumerate(image_features):
|
| 409 |
+
with specify_lora((self.x_embedder,), adapters[i + txt_n]):
|
| 410 |
+
image_hidden_states.append(self.x_embedder(image_feature))
|
| 411 |
+
|
| 412 |
+
# Preprocess the text_features
|
| 413 |
+
text_hidden_states = []
|
| 414 |
+
for text_feature in text_features:
|
| 415 |
+
text_hidden_states.append(self.context_embedder(text_feature))
|
| 416 |
+
|
| 417 |
+
# Prepare embeddings of (timestep, guidance, pooled_projections)
|
| 418 |
+
assert len(timesteps) == len(image_features) + len(text_features)
|
| 419 |
+
|
| 420 |
+
def get_temb(timestep, guidance, pooled_projection):
|
| 421 |
+
timestep = timestep.to(image_hidden_states[0].dtype) * 1000
|
| 422 |
+
if guidance is not None:
|
| 423 |
+
guidance = guidance.to(image_hidden_states[0].dtype) * 1000
|
| 424 |
+
return self.time_text_embed(timestep, guidance, pooled_projection)
|
| 425 |
+
else:
|
| 426 |
+
return self.time_text_embed(timestep, pooled_projection)
|
| 427 |
+
|
| 428 |
+
tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
|
| 429 |
+
|
| 430 |
+
# Prepare position embeddings for each token
|
| 431 |
+
position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
|
| 432 |
+
|
| 433 |
+
# Prepare the gradient checkpointing kwargs
|
| 434 |
+
gckpt_kwargs: Dict[str, Any] = (
|
| 435 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# dual branch blocks
|
| 439 |
+
for block in self.transformer_blocks:
|
| 440 |
+
block_kwargs = {
|
| 441 |
+
"self": block,
|
| 442 |
+
"image_hidden_states": image_hidden_states,
|
| 443 |
+
"text_hidden_states": text_hidden_states,
|
| 444 |
+
"tembs": tembs,
|
| 445 |
+
"position_embs": position_embs,
|
| 446 |
+
"adapters": adapters,
|
| 447 |
+
"attn_forward": attn_forward,
|
| 448 |
+
**kwargs,
|
| 449 |
+
}
|
| 450 |
+
if self.training and self.gradient_checkpointing:
|
| 451 |
+
image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 452 |
+
block_forward, **block_kwargs, **gckpt_kwargs
|
| 453 |
+
)
|
| 454 |
+
else:
|
| 455 |
+
image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
|
| 456 |
+
|
| 457 |
+
# combine image and text hidden states then pass through the single transformer blocks
|
| 458 |
+
all_hidden_states = [*text_hidden_states, *image_hidden_states]
|
| 459 |
+
for block in self.single_transformer_blocks:
|
| 460 |
+
block_kwargs = {
|
| 461 |
+
"self": block,
|
| 462 |
+
"hidden_states": all_hidden_states,
|
| 463 |
+
"tembs": tembs,
|
| 464 |
+
"position_embs": position_embs,
|
| 465 |
+
"adapters": adapters,
|
| 466 |
+
"attn_forward": attn_forward,
|
| 467 |
+
**kwargs,
|
| 468 |
+
}
|
| 469 |
+
if self.training and self.gradient_checkpointing:
|
| 470 |
+
all_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 471 |
+
single_block_forward, **block_kwargs, **gckpt_kwargs
|
| 472 |
+
)
|
| 473 |
+
else:
|
| 474 |
+
all_hidden_states = single_block_forward(**block_kwargs)
|
| 475 |
+
|
| 476 |
+
image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
|
| 477 |
+
output = self.proj_out(image_hidden_states)
|
| 478 |
+
|
| 479 |
+
return (output,)
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
@torch.no_grad()
|
| 483 |
+
def generate(
|
| 484 |
+
pipeline: FluxPipeline,
|
| 485 |
+
prompt: Union[str, List[str]] = None,
|
| 486 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 487 |
+
height: Optional[int] = 512,
|
| 488 |
+
width: Optional[int] = 512,
|
| 489 |
+
num_inference_steps: int = 28,
|
| 490 |
+
timesteps: List[int] = None,
|
| 491 |
+
guidance_scale: float = 3.5,
|
| 492 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 493 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 494 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 495 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 496 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 497 |
+
output_type: Optional[str] = "pil",
|
| 498 |
+
return_dict: bool = True,
|
| 499 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 500 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 501 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 502 |
+
max_sequence_length: int = 512,
|
| 503 |
+
# Condition Parameters (Optional)
|
| 504 |
+
main_adapter: Optional[List[str]] = None,
|
| 505 |
+
conditions: List[Condition] = [],
|
| 506 |
+
image_guidance_scale: float = 1.0,
|
| 507 |
+
transformer_kwargs: Optional[Dict[str, Any]] = {},
|
| 508 |
+
kv_cache=False,
|
| 509 |
+
latent_mask=None,
|
| 510 |
+
global_T_Q=None,
|
| 511 |
+
global_T_K=None,
|
| 512 |
+
global_T_V=None,
|
| 513 |
+
**params: dict,
|
| 514 |
+
):
|
| 515 |
+
|
| 516 |
+
# Set global T_Q, T_K, T_V if provided
|
| 517 |
+
if global_T_Q is not None:
|
| 518 |
+
global T_Q
|
| 519 |
+
T_Q = global_T_Q
|
| 520 |
+
if global_T_K is not None:
|
| 521 |
+
global T_K
|
| 522 |
+
T_K = global_T_K
|
| 523 |
+
if global_T_V is not None:
|
| 524 |
+
global T_V
|
| 525 |
+
T_V = global_T_V
|
| 526 |
+
|
| 527 |
+
self = pipeline
|
| 528 |
+
|
| 529 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 530 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 531 |
+
|
| 532 |
+
# Check inputs. Raise error if not correct
|
| 533 |
+
self.check_inputs(
|
| 534 |
+
prompt,
|
| 535 |
+
prompt_2,
|
| 536 |
+
height,
|
| 537 |
+
width,
|
| 538 |
+
prompt_embeds=prompt_embeds,
|
| 539 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 540 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 541 |
+
max_sequence_length=max_sequence_length,
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
self._guidance_scale = guidance_scale
|
| 545 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 546 |
+
|
| 547 |
+
# Define call parameters
|
| 548 |
+
if prompt is not None and isinstance(prompt, str):
|
| 549 |
+
batch_size = 1
|
| 550 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 551 |
+
batch_size = len(prompt)
|
| 552 |
+
else:
|
| 553 |
+
batch_size = prompt_embeds.shape[0]
|
| 554 |
+
|
| 555 |
+
device = self._execution_device
|
| 556 |
+
|
| 557 |
+
# Prepare prompt embeddings
|
| 558 |
+
(
|
| 559 |
+
prompt_embeds,
|
| 560 |
+
pooled_prompt_embeds,
|
| 561 |
+
text_ids,
|
| 562 |
+
) = self.encode_prompt(
|
| 563 |
+
prompt=prompt,
|
| 564 |
+
prompt_2=prompt_2,
|
| 565 |
+
prompt_embeds=prompt_embeds,
|
| 566 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 567 |
+
device=device,
|
| 568 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 569 |
+
max_sequence_length=max_sequence_length,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# Prepare latent variables
|
| 573 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 574 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 575 |
+
batch_size * num_images_per_prompt,
|
| 576 |
+
num_channels_latents,
|
| 577 |
+
height,
|
| 578 |
+
width,
|
| 579 |
+
prompt_embeds.dtype,
|
| 580 |
+
device,
|
| 581 |
+
generator,
|
| 582 |
+
latents,
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
if latent_mask is not None:
|
| 586 |
+
latent_mask = latent_mask.T.reshape(-1)
|
| 587 |
+
latents = latents[:, latent_mask]
|
| 588 |
+
latent_image_ids = latent_image_ids[latent_mask]
|
| 589 |
+
|
| 590 |
+
# Prepare conditions
|
| 591 |
+
c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
|
| 592 |
+
c_projections, c_guidances, c_adapters = ([], [], [])
|
| 593 |
+
complement_cond = None
|
| 594 |
+
for condition in conditions:
|
| 595 |
+
tokens, ids = condition.encode(self)
|
| 596 |
+
c_latents.append(tokens) # [batch_size, token_n, token_dim]
|
| 597 |
+
# Empty condition for unconditioned image
|
| 598 |
+
if image_guidance_scale != 1.0:
|
| 599 |
+
uc_latents.append(condition.encode(self, empty=True)[0])
|
| 600 |
+
c_ids.append(ids) # [token_n, id_dim(3)]
|
| 601 |
+
c_timesteps.append(torch.zeros([1], device=device))
|
| 602 |
+
c_projections.append(pooled_prompt_embeds)
|
| 603 |
+
c_guidances.append(torch.ones([1], device=device))
|
| 604 |
+
c_adapters.append(condition.adapter)
|
| 605 |
+
# This complement_condition will be combined with the original image.
|
| 606 |
+
# See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
|
| 607 |
+
if condition.is_complement:
|
| 608 |
+
complement_cond = (tokens, ids)
|
| 609 |
+
|
| 610 |
+
# Prepare timesteps
|
| 611 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 612 |
+
image_seq_len = latents.shape[1]
|
| 613 |
+
mu = calculate_shift(
|
| 614 |
+
image_seq_len,
|
| 615 |
+
self.scheduler.config.base_image_seq_len,
|
| 616 |
+
self.scheduler.config.max_image_seq_len,
|
| 617 |
+
self.scheduler.config.base_shift,
|
| 618 |
+
self.scheduler.config.max_shift,
|
| 619 |
+
)
|
| 620 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 621 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
| 622 |
+
)
|
| 623 |
+
num_warmup_steps = max(
|
| 624 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
| 625 |
+
)
|
| 626 |
+
self._num_timesteps = len(timesteps)
|
| 627 |
+
|
| 628 |
+
if kv_cache:
|
| 629 |
+
attn_counter = 0
|
| 630 |
+
for module in self.transformer.modules():
|
| 631 |
+
if isinstance(module, Attention):
|
| 632 |
+
setattr(module, "cache_idx", attn_counter)
|
| 633 |
+
attn_counter += 1
|
| 634 |
+
kv_cond = [[[], []] for _ in range(attn_counter)]
|
| 635 |
+
kv_uncond = [[[], []] for _ in range(attn_counter)]
|
| 636 |
+
|
| 637 |
+
def clear_cache():
|
| 638 |
+
for storage in [kv_cond, kv_uncond]:
|
| 639 |
+
for keys, values in storage:
|
| 640 |
+
keys.clear()
|
| 641 |
+
values.clear()
|
| 642 |
+
|
| 643 |
+
branch_n = len(conditions) + 2
|
| 644 |
+
group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
|
| 645 |
+
# Disable attention across different condition branches
|
| 646 |
+
group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
|
| 647 |
+
# Disable attention from the condition branches to the image and text branches
|
| 648 |
+
if kv_cache:
|
| 649 |
+
group_mask[2:, :2] = False
|
| 650 |
+
|
| 651 |
+
# Denoising loop
|
| 652 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 653 |
+
for i, t in enumerate(timesteps):
|
| 654 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 655 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
|
| 656 |
+
|
| 657 |
+
# handle guidance
|
| 658 |
+
if self.transformer.config.guidance_embeds:
|
| 659 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
| 660 |
+
guidance = guidance.expand(latents.shape[0])
|
| 661 |
+
else:
|
| 662 |
+
guidance, c_guidances = None, [None for _ in c_guidances]
|
| 663 |
+
|
| 664 |
+
if kv_cache:
|
| 665 |
+
mode = "write" if i == 0 else "read"
|
| 666 |
+
if mode == "write":
|
| 667 |
+
clear_cache()
|
| 668 |
+
use_cond = not (kv_cache) or mode == "write"
|
| 669 |
+
|
| 670 |
+
noise_pred = transformer_forward(
|
| 671 |
+
self.transformer,
|
| 672 |
+
image_features=[latents] + (c_latents if use_cond else []),
|
| 673 |
+
text_features=[prompt_embeds],
|
| 674 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 675 |
+
txt_ids=[text_ids],
|
| 676 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 677 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 678 |
+
+ (c_projections if use_cond else []),
|
| 679 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 680 |
+
return_dict=False,
|
| 681 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 682 |
+
cache_mode=mode if kv_cache else None,
|
| 683 |
+
cache_storage=kv_cond if kv_cache else None,
|
| 684 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 685 |
+
group_mask=group_mask,
|
| 686 |
+
**transformer_kwargs,
|
| 687 |
+
)[0]
|
| 688 |
+
|
| 689 |
+
if image_guidance_scale != 1.0:
|
| 690 |
+
unc_pred = transformer_forward(
|
| 691 |
+
self.transformer,
|
| 692 |
+
image_features=[latents] + (uc_latents if use_cond else []),
|
| 693 |
+
text_features=[prompt_embeds],
|
| 694 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 695 |
+
txt_ids=[text_ids],
|
| 696 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 697 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 698 |
+
+ (c_projections if use_cond else []),
|
| 699 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 700 |
+
return_dict=False,
|
| 701 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 702 |
+
cache_mode=mode if kv_cache else None,
|
| 703 |
+
cache_storage=kv_uncond if kv_cache else None,
|
| 704 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 705 |
+
**transformer_kwargs,
|
| 706 |
+
)[0]
|
| 707 |
+
|
| 708 |
+
noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
|
| 709 |
+
|
| 710 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 711 |
+
latents_dtype = latents.dtype
|
| 712 |
+
latents = self.scheduler.step(noise_pred, t, latents)[0]
|
| 713 |
+
|
| 714 |
+
if latents.dtype != latents_dtype:
|
| 715 |
+
if torch.backends.mps.is_available():
|
| 716 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 717 |
+
latents = latents.to(latents_dtype)
|
| 718 |
+
|
| 719 |
+
if callback_on_step_end is not None:
|
| 720 |
+
callback_kwargs = {}
|
| 721 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 722 |
+
callback_kwargs[k] = locals()[k]
|
| 723 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 724 |
+
|
| 725 |
+
latents = callback_outputs.pop("latents", latents)
|
| 726 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 727 |
+
|
| 728 |
+
# call the callback, if provided
|
| 729 |
+
if i == len(timesteps) - 1 or (
|
| 730 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 731 |
+
):
|
| 732 |
+
progress_bar.update()
|
| 733 |
+
|
| 734 |
+
if latent_mask is not None:
|
| 735 |
+
# Combine the generated latents and the complement condition
|
| 736 |
+
assert complement_cond is not None
|
| 737 |
+
comp_latent, comp_ids = complement_cond
|
| 738 |
+
all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
|
| 739 |
+
shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
|
| 740 |
+
H, W = shape[1].item(), shape[2].item()
|
| 741 |
+
B, _, C = latents.shape
|
| 742 |
+
# Create an empty canvas
|
| 743 |
+
canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
|
| 744 |
+
|
| 745 |
+
# Stash the latents and the complement condition
|
| 746 |
+
def _stash(canvas, tokens, ids, H, W) -> None:
|
| 747 |
+
B, T, C = tokens.shape
|
| 748 |
+
ids = ids.to(torch.long)
|
| 749 |
+
flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
|
| 750 |
+
canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
|
| 751 |
+
|
| 752 |
+
_stash(canvas, latents, latent_image_ids, H, W)
|
| 753 |
+
_stash(canvas, comp_latent, comp_ids, H, W)
|
| 754 |
+
latents = canvas.view(B, H * W, C)
|
| 755 |
+
|
| 756 |
+
if output_type == "latent":
|
| 757 |
+
image = latents
|
| 758 |
+
else:
|
| 759 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 760 |
+
latents = (
|
| 761 |
+
latents / self.vae.config.scaling_factor
|
| 762 |
+
) + self.vae.config.shift_factor
|
| 763 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 764 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 765 |
+
|
| 766 |
+
# Offload all models
|
| 767 |
+
self.maybe_free_model_hooks()
|
| 768 |
+
|
| 769 |
+
if not return_dict:
|
| 770 |
+
return (image,)
|
| 771 |
+
|
| 772 |
+
return FluxPipelineOutput(images=image)
|
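The `latent_mask` branch above reassembles the generated tokens and the complement condition onto one full canvas before decoding. Below is a minimal, self-contained sketch of that scatter-by-ids step, mirroring the `_stash` helper; every size and name here is hypothetical and chosen only for illustration, not taken from the repo.

```python
import torch

B, C, H, W = 1, 8, 4, 4  # hypothetical batch, channel, and canvas sizes

# Build (index, row, col) ids for every canvas position, then split them into a
# "generated" half and a "complement" half, each with its own tokens.
rows, cols = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
all_ids = torch.stack([torch.zeros_like(rows), rows, cols], dim=-1).reshape(-1, 3)
left = all_ids[:, 2] < W // 2
gen_ids, comp_ids = all_ids[left], all_ids[~left]
gen_tokens = torch.randn(B, gen_ids.shape[0], C)
comp_tokens = torch.randn(B, comp_ids.shape[0], C)

canvas = gen_tokens.new_zeros(B, H * W, C)

def stash(canvas, tokens, ids, W):
    # Flatten (row, col) ids into positions on the H*W canvas and copy tokens there.
    flat_idx = (ids[:, 1] * W + ids[:, 2]).long()
    canvas.index_copy_(1, flat_idx, tokens)

stash(canvas, gen_tokens, gen_ids, W)
stash(canvas, comp_tokens, comp_ids, W)
print(canvas.shape)  # torch.Size([1, 16, 8]); every canvas position is now filled
```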
omini/pipeline/flux_omini_ablate_scale.py
ADDED
|
@@ -0,0 +1,748 @@
|
| 1 |
+
"""
|
| 2 |
+
This version is for an ablation study on the effect of scaling in LoRA adapters.
|
| 3 |
+
|
| 4 |
+
The `generate` function is modified to include a `global_scale` parameter;
|
| 5 |
+
the `SCALE` variable is set globally at the start of generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
|
| 10 |
+
|
| 11 |
+
from diffusers.pipelines import FluxPipeline
|
| 12 |
+
from diffusers.pipelines.flux.pipeline_flux import (
|
| 13 |
+
FluxPipelineOutput,
|
| 14 |
+
FluxTransformer2DModel,
|
| 15 |
+
calculate_shift,
|
| 16 |
+
retrieve_timesteps,
|
| 17 |
+
np,
|
| 18 |
+
)
|
| 19 |
+
from diffusers.models.attention_processor import Attention, F
|
| 20 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
| 21 |
+
from transformers import pipeline
|
| 22 |
+
|
| 23 |
+
from peft.tuners.tuners_utils import BaseTunerLayer
|
| 24 |
+
from accelerate.utils import is_torch_version
|
| 25 |
+
|
| 26 |
+
from contextlib import contextmanager
|
| 27 |
+
|
| 28 |
+
import cv2
|
| 29 |
+
|
| 30 |
+
from PIL import Image, ImageFilter
|
| 31 |
+
|
| 32 |
+
SCALE = 1.0
|
| 33 |
+
|
| 34 |
+
def seed_everything(seed: int = 42):
|
| 35 |
+
torch.backends.cudnn.deterministic = True
|
| 36 |
+
torch.manual_seed(seed)
|
| 37 |
+
np.random.seed(seed)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
| 41 |
+
if hidden_states.dtype == torch.float16:
|
| 42 |
+
hidden_states = hidden_states.clip(-65504, 65504)
|
| 43 |
+
return hidden_states
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
|
| 47 |
+
"""
|
| 48 |
+
Encodes the images into tokens and ids for the FLUX pipeline.
|
| 49 |
+
"""
|
| 50 |
+
images = pipeline.image_processor.preprocess(images)
|
| 51 |
+
images = images.to(pipeline.device).to(pipeline.dtype)
|
| 52 |
+
images = pipeline.vae.encode(images).latent_dist.sample()
|
| 53 |
+
images = (
|
| 54 |
+
images - pipeline.vae.config.shift_factor
|
| 55 |
+
) * pipeline.vae.config.scaling_factor
|
| 56 |
+
images_tokens = pipeline._pack_latents(images, *images.shape)
|
| 57 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 58 |
+
images.shape[0],
|
| 59 |
+
images.shape[2],
|
| 60 |
+
images.shape[3],
|
| 61 |
+
pipeline.device,
|
| 62 |
+
pipeline.dtype,
|
| 63 |
+
)
|
| 64 |
+
if images_tokens.shape[1] != images_ids.shape[0]:
|
| 65 |
+
images_ids = pipeline._prepare_latent_image_ids(
|
| 66 |
+
images.shape[0],
|
| 67 |
+
images.shape[2] // 2,
|
| 68 |
+
images.shape[3] // 2,
|
| 69 |
+
pipeline.device,
|
| 70 |
+
pipeline.dtype,
|
| 71 |
+
)
|
| 72 |
+
return images_tokens, images_ids
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
depth_pipe = None
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def convert_to_condition(
|
| 79 |
+
condition_type: str,
|
| 80 |
+
raw_img: Union[Image.Image, torch.Tensor],
|
| 81 |
+
blur_radius: Optional[int] = 5,
|
| 82 |
+
) -> Union[Image.Image, torch.Tensor]:
|
| 83 |
+
if condition_type == "depth":
|
| 84 |
+
global depth_pipe
|
| 85 |
+
depth_pipe = depth_pipe or pipeline(
|
| 86 |
+
task="depth-estimation",
|
| 87 |
+
model="LiheYoung/depth-anything-small-hf",
|
| 88 |
+
device="cpu", # Use "cpu" to enable parallel processing
|
| 89 |
+
)
|
| 90 |
+
source_image = raw_img.convert("RGB")
|
| 91 |
+
condition_img = depth_pipe(source_image)["depth"].convert("RGB")
|
| 92 |
+
return condition_img
|
| 93 |
+
elif condition_type == "canny":
|
| 94 |
+
img = np.array(raw_img)
|
| 95 |
+
edges = cv2.Canny(img, 100, 200)
|
| 96 |
+
edges = Image.fromarray(edges).convert("RGB")
|
| 97 |
+
return edges
|
| 98 |
+
elif condition_type == "coloring":
|
| 99 |
+
return raw_img.convert("L").convert("RGB")
|
| 100 |
+
elif condition_type == "deblurring":
|
| 101 |
+
condition_image = (
|
| 102 |
+
raw_img.convert("RGB")
|
| 103 |
+
.filter(ImageFilter.GaussianBlur(blur_radius))
|
| 104 |
+
.convert("RGB")
|
| 105 |
+
)
|
| 106 |
+
return condition_image
|
| 107 |
+
else:
|
| 108 |
+
print("Warning: Returning the raw image.")
|
| 109 |
+
return raw_img.convert("RGB")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Condition(object):
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
condition: Union[Image.Image, torch.Tensor],
|
| 116 |
+
adapter_setting: Union[str, dict],
|
| 117 |
+
position_delta=None,
|
| 118 |
+
position_scale=1.0,
|
| 119 |
+
latent_mask=None,
|
| 120 |
+
is_complement=False,
|
| 121 |
+
) -> None:
|
| 122 |
+
self.condition = condition
|
| 123 |
+
self.adapter = adapter_setting
|
| 124 |
+
self.position_delta = position_delta
|
| 125 |
+
self.position_scale = position_scale
|
| 126 |
+
self.latent_mask = (
|
| 127 |
+
latent_mask.T.reshape(-1) if latent_mask is not None else None
|
| 128 |
+
)
|
| 129 |
+
self.is_complement = is_complement
|
| 130 |
+
|
| 131 |
+
def encode(
|
| 132 |
+
self, pipe: FluxPipeline, empty: bool = False
|
| 133 |
+
) -> Tuple[torch.Tensor, torch.Tensor, int]:
|
| 134 |
+
condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
|
| 135 |
+
tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
|
| 136 |
+
|
| 137 |
+
if self.position_delta is not None:
|
| 138 |
+
ids[:, 1] += self.position_delta[0]
|
| 139 |
+
ids[:, 2] += self.position_delta[1]
|
| 140 |
+
|
| 141 |
+
if self.position_scale != 1.0:
|
| 142 |
+
scale_bias = (self.position_scale - 1.0) / 2
|
| 143 |
+
ids[:, 1:] *= self.position_scale
|
| 144 |
+
ids[:, 1:] += scale_bias
|
| 145 |
+
|
| 146 |
+
if self.latent_mask is not None:
|
| 147 |
+
tokens = tokens[:, self.latent_mask]
|
| 148 |
+
ids = ids[self.latent_mask]
|
| 149 |
+
|
| 150 |
+
return tokens, ids
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@contextmanager
|
| 154 |
+
def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora):
|
| 155 |
+
# Filter valid lora modules
|
| 156 |
+
valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
|
| 157 |
+
# Save original scales
|
| 158 |
+
original_scales = [
|
| 159 |
+
{
|
| 160 |
+
adapter: module.scaling[adapter]
|
| 161 |
+
for adapter in module.active_adapters
|
| 162 |
+
if adapter in module.scaling
|
| 163 |
+
}
|
| 164 |
+
for module in valid_lora_modules
|
| 165 |
+
]
|
| 166 |
+
# Enter context: adjust scaling
|
| 167 |
+
for module in valid_lora_modules:
|
| 168 |
+
for adapter in module.active_adapters:
|
| 169 |
+
if adapter in module.scaling:
|
| 170 |
+
module.scaling[adapter] = SCALE if adapter == specified_lora else 0
|
| 171 |
+
try:
|
| 172 |
+
yield
|
| 173 |
+
finally:
|
| 174 |
+
# Exit context: restore original scales
|
| 175 |
+
for module, scales in zip(valid_lora_modules, original_scales):
|
| 176 |
+
for adapter in module.active_adapters:
|
| 177 |
+
if adapter in module.scaling:
|
| 178 |
+
module.scaling[adapter] = scales[adapter]
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def attn_forward(
|
| 182 |
+
attn: Attention,
|
| 183 |
+
hidden_states: List[torch.FloatTensor],
|
| 184 |
+
adapters: List[str],
|
| 185 |
+
hidden_states2: Optional[List[torch.FloatTensor]] = [],
|
| 186 |
+
position_embs: Optional[List[torch.Tensor]] = None,
|
| 187 |
+
group_mask: Optional[torch.Tensor] = None,
|
| 188 |
+
cache_mode: Optional[str] = None,
|
| 189 |
+
# to determine whether to cache the keys and values for this branch
|
| 190 |
+
to_cache: Optional[List[torch.Tensor]] = None,
|
| 191 |
+
cache_storage: Optional[List[torch.Tensor]] = None,
|
| 192 |
+
**kwargs: dict,
|
| 193 |
+
) -> torch.FloatTensor:
|
| 194 |
+
bs, _, _ = hidden_states[0].shape
|
| 195 |
+
h2_n = len(hidden_states2)
|
| 196 |
+
|
| 197 |
+
queries, keys, values = [], [], []
|
| 198 |
+
|
| 199 |
+
# Prepare query, key, value for each encoder hidden state (text branch)
|
| 200 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 201 |
+
query = attn.add_q_proj(hidden_state)
|
| 202 |
+
key = attn.add_k_proj(hidden_state)
|
| 203 |
+
value = attn.add_v_proj(hidden_state)
|
| 204 |
+
|
| 205 |
+
head_dim = key.shape[-1] // attn.heads
|
| 206 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 207 |
+
|
| 208 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 209 |
+
query, key = attn.norm_added_q(query), attn.norm_added_k(key)
|
| 210 |
+
|
| 211 |
+
queries.append(query)
|
| 212 |
+
keys.append(key)
|
| 213 |
+
values.append(value)
|
| 214 |
+
|
| 215 |
+
# Prepare query, key, value for each hidden state (image branch)
|
| 216 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 217 |
+
with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
|
| 218 |
+
query = attn.to_q(hidden_state)
|
| 219 |
+
key = attn.to_k(hidden_state)
|
| 220 |
+
value = attn.to_v(hidden_state)
|
| 221 |
+
|
| 222 |
+
head_dim = key.shape[-1] // attn.heads
|
| 223 |
+
reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
|
| 224 |
+
|
| 225 |
+
query, key, value = map(reshape_fn, (query, key, value))
|
| 226 |
+
query, key = attn.norm_q(query), attn.norm_k(key)
|
| 227 |
+
|
| 228 |
+
queries.append(query)
|
| 229 |
+
keys.append(key)
|
| 230 |
+
values.append(value)
|
| 231 |
+
|
| 232 |
+
# Apply rotary embedding
|
| 233 |
+
if position_embs is not None:
|
| 234 |
+
queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
|
| 235 |
+
keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
|
| 236 |
+
|
| 237 |
+
if cache_mode == "write":
|
| 238 |
+
for i, (k, v) in enumerate(zip(keys, values)):
|
| 239 |
+
if to_cache[i]:
|
| 240 |
+
cache_storage[attn.cache_idx][0].append(k)
|
| 241 |
+
cache_storage[attn.cache_idx][1].append(v)
|
| 242 |
+
|
| 243 |
+
attn_outputs = []
|
| 244 |
+
for i, query in enumerate(queries):
|
| 245 |
+
keys_, values_ = [], []
|
| 246 |
+
# Add keys and values from other branches
|
| 247 |
+
for j, (k, v) in enumerate(zip(keys, values)):
|
| 248 |
+
if (group_mask is not None) and not (group_mask[i][j].item()):
|
| 249 |
+
continue
|
| 250 |
+
keys_.append(k)
|
| 251 |
+
values_.append(v)
|
| 252 |
+
if cache_mode == "read":
|
| 253 |
+
keys_.extend(cache_storage[attn.cache_idx][0])
|
| 254 |
+
values_.extend(cache_storage[attn.cache_idx][1])
|
| 255 |
+
# Add keys and values from cache TODO
|
| 256 |
+
# Attention computation
|
| 257 |
+
attn_output = F.scaled_dot_product_attention(
|
| 258 |
+
query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
|
| 259 |
+
).to(query.dtype)
|
| 260 |
+
attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
|
| 261 |
+
attn_outputs.append(attn_output)
|
| 262 |
+
|
| 263 |
+
# Reshape attention output to match the original hidden states
|
| 264 |
+
h_out, h2_out = [], []
|
| 265 |
+
|
| 266 |
+
for i, hidden_state in enumerate(hidden_states2):
|
| 267 |
+
h2_out.append(attn.to_add_out(attn_outputs[i]))
|
| 268 |
+
|
| 269 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 270 |
+
h = attn_outputs[i + h2_n]
|
| 271 |
+
if getattr(attn, "to_out", None) is not None:
|
| 272 |
+
with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
|
| 273 |
+
h = attn.to_out[0](h)
|
| 274 |
+
h_out.append(h)
|
| 275 |
+
|
| 276 |
+
return (h_out, h2_out) if h2_n else h_out
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def block_forward(
|
| 280 |
+
self,
|
| 281 |
+
image_hidden_states: List[torch.FloatTensor],
|
| 282 |
+
text_hidden_states: List[torch.FloatTensor],
|
| 283 |
+
tembs: List[torch.FloatTensor],
|
| 284 |
+
adapters: List[str],
|
| 285 |
+
position_embs=None,
|
| 286 |
+
attn_forward=attn_forward,
|
| 287 |
+
**kwargs: dict,
|
| 288 |
+
):
|
| 289 |
+
txt_n = len(text_hidden_states)
|
| 290 |
+
|
| 291 |
+
img_variables, txt_variables = [], []
|
| 292 |
+
|
| 293 |
+
for i, text_h in enumerate(text_hidden_states):
|
| 294 |
+
txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
|
| 295 |
+
|
| 296 |
+
for i, image_h in enumerate(image_hidden_states):
|
| 297 |
+
with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
|
| 298 |
+
img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
|
| 299 |
+
|
| 300 |
+
# Attention.
|
| 301 |
+
img_attn_output, txt_attn_output = attn_forward(
|
| 302 |
+
self.attn,
|
| 303 |
+
hidden_states=[each[0] for each in img_variables],
|
| 304 |
+
hidden_states2=[each[0] for each in txt_variables],
|
| 305 |
+
position_embs=position_embs,
|
| 306 |
+
adapters=adapters,
|
| 307 |
+
**kwargs,
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
text_out = []
|
| 311 |
+
for i in range(len(text_hidden_states)):
|
| 312 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
|
| 313 |
+
text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
|
| 314 |
+
norm_h = (
|
| 315 |
+
self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 316 |
+
)
|
| 317 |
+
text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
|
| 318 |
+
text_out.append(clip_hidden_states(text_h))
|
| 319 |
+
|
| 320 |
+
image_out = []
|
| 321 |
+
for i in range(len(image_hidden_states)):
|
| 322 |
+
_, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
|
| 323 |
+
image_h = (
|
| 324 |
+
image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
|
| 325 |
+
).to(image_hidden_states[i].dtype)
|
| 326 |
+
norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
| 327 |
+
with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
|
| 328 |
+
image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
|
| 329 |
+
image_out.append(clip_hidden_states(image_h))
|
| 330 |
+
return image_out, text_out
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def single_block_forward(
|
| 334 |
+
self,
|
| 335 |
+
hidden_states: List[torch.FloatTensor],
|
| 336 |
+
tembs: List[torch.FloatTensor],
|
| 337 |
+
adapters: List[str],
|
| 338 |
+
position_embs=None,
|
| 339 |
+
attn_forward=attn_forward,
|
| 340 |
+
**kwargs: dict,
|
| 341 |
+
):
|
| 342 |
+
mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
|
| 343 |
+
|
| 344 |
+
hidden_state_norm = []
|
| 345 |
+
for i, hidden_state in enumerate(hidden_states):
|
| 346 |
+
# [NOTE]!: This function's output is slightly DIFFERENT from the original
|
| 347 |
+
# FLUX version. In the original implementation, the gates were computed using
|
| 348 |
+
# the combined hidden states from both the image and text branches. Here, each
|
| 349 |
+
# branch computes its gate using only its own hidden state.
|
| 350 |
+
with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
|
| 351 |
+
h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
|
| 352 |
+
mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
|
| 353 |
+
hidden_state_norm.append(h_norm)
|
| 354 |
+
|
| 355 |
+
attn_outputs = attn_forward(
|
| 356 |
+
self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
h_out = []
|
| 360 |
+
for i in range(len(hidden_states)):
|
| 361 |
+
with specify_lora((self.proj_out,), adapters[i]):
|
| 362 |
+
h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
|
| 363 |
+
h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
|
| 364 |
+
h_out.append(clip_hidden_states(h))
|
| 365 |
+
|
| 366 |
+
return h_out
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def transformer_forward(
|
| 370 |
+
transformer: FluxTransformer2DModel,
|
| 371 |
+
image_features: List[torch.Tensor],
|
| 372 |
+
text_features: List[torch.Tensor] = None,
|
| 373 |
+
img_ids: List[torch.Tensor] = None,
|
| 374 |
+
txt_ids: List[torch.Tensor] = None,
|
| 375 |
+
pooled_projections: List[torch.Tensor] = None,
|
| 376 |
+
timesteps: List[torch.LongTensor] = None,
|
| 377 |
+
guidances: List[torch.Tensor] = None,
|
| 378 |
+
adapters: List[str] = None,
|
| 379 |
+
# Assign the function to be used for the forward pass
|
| 380 |
+
single_block_forward=single_block_forward,
|
| 381 |
+
block_forward=block_forward,
|
| 382 |
+
attn_forward=attn_forward,
|
| 383 |
+
**kwargs: dict,
|
| 384 |
+
):
|
| 385 |
+
self = transformer
|
| 386 |
+
txt_n = len(text_features) if text_features is not None else 0
|
| 387 |
+
|
| 388 |
+
adapters = adapters or [None] * (txt_n + len(image_features))
|
| 389 |
+
assert len(adapters) == len(timesteps)
|
| 390 |
+
|
| 391 |
+
# Preprocess the image_features
|
| 392 |
+
image_hidden_states = []
|
| 393 |
+
for i, image_feature in enumerate(image_features):
|
| 394 |
+
with specify_lora((self.x_embedder,), adapters[i + txt_n]):
|
| 395 |
+
image_hidden_states.append(self.x_embedder(image_feature))
|
| 396 |
+
|
| 397 |
+
# Preprocess the text_features
|
| 398 |
+
text_hidden_states = []
|
| 399 |
+
for text_feature in text_features:
|
| 400 |
+
text_hidden_states.append(self.context_embedder(text_feature))
|
| 401 |
+
|
| 402 |
+
# Prepare embeddings of (timestep, guidance, pooled_projections)
|
| 403 |
+
assert len(timesteps) == len(image_features) + len(text_features)
|
| 404 |
+
|
| 405 |
+
def get_temb(timestep, guidance, pooled_projection):
|
| 406 |
+
timestep = timestep.to(image_hidden_states[0].dtype) * 1000
|
| 407 |
+
if guidance is not None:
|
| 408 |
+
guidance = guidance.to(image_hidden_states[0].dtype) * 1000
|
| 409 |
+
return self.time_text_embed(timestep, guidance, pooled_projection)
|
| 410 |
+
else:
|
| 411 |
+
return self.time_text_embed(timestep, pooled_projection)
|
| 412 |
+
|
| 413 |
+
tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
|
| 414 |
+
|
| 415 |
+
# Prepare position embeddings for each token
|
| 416 |
+
position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
|
| 417 |
+
|
| 418 |
+
# Prepare the gradient checkpointing kwargs
|
| 419 |
+
gckpt_kwargs: Dict[str, Any] = (
|
| 420 |
+
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# dual branch blocks
|
| 424 |
+
for block in self.transformer_blocks:
|
| 425 |
+
block_kwargs = {
|
| 426 |
+
"self": block,
|
| 427 |
+
"image_hidden_states": image_hidden_states,
|
| 428 |
+
"text_hidden_states": text_hidden_states,
|
| 429 |
+
"tembs": tembs,
|
| 430 |
+
"position_embs": position_embs,
|
| 431 |
+
"adapters": adapters,
|
| 432 |
+
"attn_forward": attn_forward,
|
| 433 |
+
**kwargs,
|
| 434 |
+
}
|
| 435 |
+
if self.training and self.gradient_checkpointing:
|
| 436 |
+
image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 437 |
+
block_forward, **block_kwargs, **gckpt_kwargs
|
| 438 |
+
)
|
| 439 |
+
else:
|
| 440 |
+
image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
|
| 441 |
+
|
| 442 |
+
# combine image and text hidden states then pass through the single transformer blocks
|
| 443 |
+
all_hidden_states = [*text_hidden_states, *image_hidden_states]
|
| 444 |
+
for block in self.single_transformer_blocks:
|
| 445 |
+
block_kwargs = {
|
| 446 |
+
"self": block,
|
| 447 |
+
"hidden_states": all_hidden_states,
|
| 448 |
+
"tembs": tembs,
|
| 449 |
+
"position_embs": position_embs,
|
| 450 |
+
"adapters": adapters,
|
| 451 |
+
"attn_forward": attn_forward,
|
| 452 |
+
**kwargs,
|
| 453 |
+
}
|
| 454 |
+
if self.training and self.gradient_checkpointing:
|
| 455 |
+
all_hidden_states = torch.utils.checkpoint.checkpoint(
|
| 456 |
+
single_block_forward, **block_kwargs, **gckpt_kwargs
|
| 457 |
+
)
|
| 458 |
+
else:
|
| 459 |
+
all_hidden_states = single_block_forward(**block_kwargs)
|
| 460 |
+
|
| 461 |
+
image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
|
| 462 |
+
output = self.proj_out(image_hidden_states)
|
| 463 |
+
|
| 464 |
+
return (output,)
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
@torch.no_grad()
|
| 468 |
+
def generate(
|
| 469 |
+
pipeline: FluxPipeline,
|
| 470 |
+
prompt: Union[str, List[str]] = None,
|
| 471 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 472 |
+
height: Optional[int] = 512,
|
| 473 |
+
width: Optional[int] = 512,
|
| 474 |
+
num_inference_steps: int = 28,
|
| 475 |
+
timesteps: List[int] = None,
|
| 476 |
+
guidance_scale: float = 3.5,
|
| 477 |
+
num_images_per_prompt: Optional[int] = 1,
|
| 478 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 479 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 480 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 481 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 482 |
+
output_type: Optional[str] = "pil",
|
| 483 |
+
return_dict: bool = True,
|
| 484 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 485 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 486 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 487 |
+
max_sequence_length: int = 512,
|
| 488 |
+
# Condition Parameters (Optional)
|
| 489 |
+
main_adapter: Optional[List[str]] = None,
|
| 490 |
+
conditions: List[Condition] = [],
|
| 491 |
+
image_guidance_scale: float = 1.0,
|
| 492 |
+
transformer_kwargs: Optional[Dict[str, Any]] = {},
|
| 493 |
+
kv_cache=False,
|
| 494 |
+
latent_mask=None,
|
| 495 |
+
global_scale=None,
|
| 496 |
+
**params: dict,
|
| 497 |
+
):
|
| 498 |
+
|
| 499 |
+
if global_scale is not None:
|
| 500 |
+
global SCALE
|
| 501 |
+
SCALE = global_scale
|
| 502 |
+
|
| 503 |
+
self = pipeline
|
| 504 |
+
|
| 505 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
| 506 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
| 507 |
+
|
| 508 |
+
# Check inputs. Raise error if not correct
|
| 509 |
+
self.check_inputs(
|
| 510 |
+
prompt,
|
| 511 |
+
prompt_2,
|
| 512 |
+
height,
|
| 513 |
+
width,
|
| 514 |
+
prompt_embeds=prompt_embeds,
|
| 515 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 516 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 517 |
+
max_sequence_length=max_sequence_length,
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
self._guidance_scale = guidance_scale
|
| 521 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 522 |
+
|
| 523 |
+
# Define call parameters
|
| 524 |
+
if prompt is not None and isinstance(prompt, str):
|
| 525 |
+
batch_size = 1
|
| 526 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 527 |
+
batch_size = len(prompt)
|
| 528 |
+
else:
|
| 529 |
+
batch_size = prompt_embeds.shape[0]
|
| 530 |
+
|
| 531 |
+
device = self._execution_device
|
| 532 |
+
|
| 533 |
+
# Prepare prompt embeddings
|
| 534 |
+
(
|
| 535 |
+
prompt_embeds,
|
| 536 |
+
pooled_prompt_embeds,
|
| 537 |
+
text_ids,
|
| 538 |
+
) = self.encode_prompt(
|
| 539 |
+
prompt=prompt,
|
| 540 |
+
prompt_2=prompt_2,
|
| 541 |
+
prompt_embeds=prompt_embeds,
|
| 542 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
| 543 |
+
device=device,
|
| 544 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 545 |
+
max_sequence_length=max_sequence_length,
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
# Prepare latent variables
|
| 549 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
| 550 |
+
latents, latent_image_ids = self.prepare_latents(
|
| 551 |
+
batch_size * num_images_per_prompt,
|
| 552 |
+
num_channels_latents,
|
| 553 |
+
height,
|
| 554 |
+
width,
|
| 555 |
+
prompt_embeds.dtype,
|
| 556 |
+
device,
|
| 557 |
+
generator,
|
| 558 |
+
latents,
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
if latent_mask is not None:
|
| 562 |
+
latent_mask = latent_mask.T.reshape(-1)
|
| 563 |
+
latents = latents[:, latent_mask]
|
| 564 |
+
latent_image_ids = latent_image_ids[latent_mask]
|
| 565 |
+
|
| 566 |
+
# Prepare conditions
|
| 567 |
+
c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
|
| 568 |
+
c_projections, c_guidances, c_adapters = ([], [], [])
|
| 569 |
+
complement_cond = None
|
| 570 |
+
for condition in conditions:
|
| 571 |
+
tokens, ids = condition.encode(self)
|
| 572 |
+
c_latents.append(tokens) # [batch_size, token_n, token_dim]
|
| 573 |
+
# Empty condition for unconditioned image
|
| 574 |
+
if image_guidance_scale != 1.0:
|
| 575 |
+
uc_latents.append(condition.encode(self, empty=True)[0])
|
| 576 |
+
c_ids.append(ids) # [token_n, id_dim(3)]
|
| 577 |
+
c_timesteps.append(torch.zeros([1], device=device))
|
| 578 |
+
c_projections.append(pooled_prompt_embeds)
|
| 579 |
+
c_guidances.append(torch.ones([1], device=device))
|
| 580 |
+
c_adapters.append(condition.adapter)
|
| 581 |
+
# This complement_condition will be combined with the original image.
|
| 582 |
+
# See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
|
| 583 |
+
if condition.is_complement:
|
| 584 |
+
complement_cond = (tokens, ids)
|
| 585 |
+
|
| 586 |
+
# Prepare timesteps
|
| 587 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
| 588 |
+
image_seq_len = latents.shape[1]
|
| 589 |
+
mu = calculate_shift(
|
| 590 |
+
image_seq_len,
|
| 591 |
+
self.scheduler.config.base_image_seq_len,
|
| 592 |
+
self.scheduler.config.max_image_seq_len,
|
| 593 |
+
self.scheduler.config.base_shift,
|
| 594 |
+
self.scheduler.config.max_shift,
|
| 595 |
+
)
|
| 596 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 597 |
+
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
| 598 |
+
)
|
| 599 |
+
num_warmup_steps = max(
|
| 600 |
+
len(timesteps) - num_inference_steps * self.scheduler.order, 0
|
| 601 |
+
)
|
| 602 |
+
self._num_timesteps = len(timesteps)
|
| 603 |
+
|
| 604 |
+
if kv_cache:
|
| 605 |
+
attn_counter = 0
|
| 606 |
+
for module in self.transformer.modules():
|
| 607 |
+
if isinstance(module, Attention):
|
| 608 |
+
setattr(module, "cache_idx", attn_counter)
|
| 609 |
+
attn_counter += 1
|
| 610 |
+
kv_cond = [[[], []] for _ in range(attn_counter)]
|
| 611 |
+
kv_uncond = [[[], []] for _ in range(attn_counter)]
|
| 612 |
+
|
| 613 |
+
def clear_cache():
|
| 614 |
+
for storage in [kv_cond, kv_uncond]:
|
| 615 |
+
for keys, values in storage:
|
| 616 |
+
keys.clear()
|
| 617 |
+
values.clear()
|
| 618 |
+
|
| 619 |
+
branch_n = len(conditions) + 2
|
| 620 |
+
group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
|
| 621 |
+
# Disable attention across different condition branches
|
| 622 |
+
group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
|
| 623 |
+
# Disable the attention from condition branches to image branch and text branch
|
| 624 |
+
if kv_cache:
|
| 625 |
+
group_mask[2:, :2] = False
|
| 626 |
+
|
| 627 |
+
# Denoising loop
|
| 628 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 629 |
+
for i, t in enumerate(timesteps):
|
| 630 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 631 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
|
| 632 |
+
|
| 633 |
+
# handle guidance
|
| 634 |
+
if self.transformer.config.guidance_embeds:
|
| 635 |
+
guidance = torch.tensor([guidance_scale], device=device)
|
| 636 |
+
guidance = guidance.expand(latents.shape[0])
|
| 637 |
+
else:
|
| 638 |
+
guidance, c_guidances = None, [None for _ in c_guidances]
|
| 639 |
+
|
| 640 |
+
if kv_cache:
|
| 641 |
+
mode = "write" if i == 0 else "read"
|
| 642 |
+
if mode == "write":
|
| 643 |
+
clear_cache()
|
| 644 |
+
use_cond = not (kv_cache) or mode == "write"
|
| 645 |
+
|
| 646 |
+
noise_pred = transformer_forward(
|
| 647 |
+
self.transformer,
|
| 648 |
+
image_features=[latents] + (c_latents if use_cond else []),
|
| 649 |
+
text_features=[prompt_embeds],
|
| 650 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 651 |
+
txt_ids=[text_ids],
|
| 652 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 653 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 654 |
+
+ (c_projections if use_cond else []),
|
| 655 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 656 |
+
return_dict=False,
|
| 657 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 658 |
+
cache_mode=mode if kv_cache else None,
|
| 659 |
+
cache_storage=kv_cond if kv_cache else None,
|
| 660 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 661 |
+
group_mask=group_mask,
|
| 662 |
+
**transformer_kwargs,
|
| 663 |
+
)[0]
|
| 664 |
+
|
| 665 |
+
if image_guidance_scale != 1.0:
|
| 666 |
+
unc_pred = transformer_forward(
|
| 667 |
+
self.transformer,
|
| 668 |
+
image_features=[latents] + (uc_latents if use_cond else []),
|
| 669 |
+
text_features=[prompt_embeds],
|
| 670 |
+
img_ids=[latent_image_ids] + (c_ids if use_cond else []),
|
| 671 |
+
txt_ids=[text_ids],
|
| 672 |
+
timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
|
| 673 |
+
pooled_projections=[pooled_prompt_embeds] * 2
|
| 674 |
+
+ (c_projections if use_cond else []),
|
| 675 |
+
guidances=[guidance] * 2 + (c_guidances if use_cond else []),
|
| 676 |
+
return_dict=False,
|
| 677 |
+
adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
|
| 678 |
+
cache_mode=mode if kv_cache else None,
|
| 679 |
+
cache_storage=kv_uncond if kv_cache else None,
|
| 680 |
+
to_cache=[False, False, *[True] * len(c_latents)],
|
| 681 |
+
**transformer_kwargs,
|
| 682 |
+
)[0]
|
| 683 |
+
|
| 684 |
+
noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
|
| 685 |
+
|
| 686 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 687 |
+
latents_dtype = latents.dtype
|
| 688 |
+
latents = self.scheduler.step(noise_pred, t, latents)[0]
|
| 689 |
+
|
| 690 |
+
if latents.dtype != latents_dtype:
|
| 691 |
+
if torch.backends.mps.is_available():
|
| 692 |
+
# some platforms (e.g. Apple MPS) misbehave due to a PyTorch bug: https://github.com/pytorch/pytorch/pull/99272
|
| 693 |
+
latents = latents.to(latents_dtype)
|
| 694 |
+
|
| 695 |
+
if callback_on_step_end is not None:
|
| 696 |
+
callback_kwargs = {}
|
| 697 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 698 |
+
callback_kwargs[k] = locals()[k]
|
| 699 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 700 |
+
|
| 701 |
+
latents = callback_outputs.pop("latents", latents)
|
| 702 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 703 |
+
|
| 704 |
+
# call the callback, if provided
|
| 705 |
+
if i == len(timesteps) - 1 or (
|
| 706 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 707 |
+
):
|
| 708 |
+
progress_bar.update()
|
| 709 |
+
|
| 710 |
+
if latent_mask is not None:
|
| 711 |
+
# Combine the generated latents and the complement condition
|
| 712 |
+
assert complement_cond is not None
|
| 713 |
+
comp_latent, comp_ids = complement_cond
|
| 714 |
+
all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
|
| 715 |
+
shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
|
| 716 |
+
H, W = shape[1].item(), shape[2].item()
|
| 717 |
+
B, _, C = latents.shape
|
| 718 |
+
# Create an empty canvas
|
| 719 |
+
canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
|
| 720 |
+
|
| 721 |
+
# Stash the latents and the complement condition
|
| 722 |
+
def _stash(canvas, tokens, ids, H, W) -> None:
|
| 723 |
+
B, T, C = tokens.shape
|
| 724 |
+
ids = ids.to(torch.long)
|
| 725 |
+
flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
|
| 726 |
+
canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
|
| 727 |
+
|
| 728 |
+
_stash(canvas, latents, latent_image_ids, H, W)
|
| 729 |
+
_stash(canvas, comp_latent, comp_ids, H, W)
|
| 730 |
+
latents = canvas.view(B, H * W, C)
|
| 731 |
+
|
| 732 |
+
if output_type == "latent":
|
| 733 |
+
image = latents
|
| 734 |
+
else:
|
| 735 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
| 736 |
+
latents = (
|
| 737 |
+
latents / self.vae.config.scaling_factor
|
| 738 |
+
) + self.vae.config.shift_factor
|
| 739 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 740 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 741 |
+
|
| 742 |
+
# Offload all models
|
| 743 |
+
self.maybe_free_model_hooks()
|
| 744 |
+
|
| 745 |
+
if not return_dict:
|
| 746 |
+
return (image,)
|
| 747 |
+
|
| 748 |
+
return FluxPipelineOutput(images=image)
|
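The ablation in this file works by having `generate(..., global_scale=...)` rebind the module-level `SCALE`, which `specify_lora` then assigns to the one adapter a branch is routed through while all other active adapters are set to 0. The standalone sketch below reproduces that override pattern on a toy object; `ToyLoraModule` and `specify_adapter` are illustrative stand-ins for the PEFT layer and the context manager, not the repo's API.

```python
from contextlib import contextmanager

class ToyLoraModule:
    """Stand-in for a PEFT layer: only the pieces the scale ablation touches."""
    def __init__(self):
        self.active_adapters = ["subject", "depth"]
        self.scaling = {"subject": 1.0, "depth": 1.0}

SCALE = 0.5  # the global scale being ablated

@contextmanager
def specify_adapter(modules, specified):
    # Save original scales, activate only `specified` at strength SCALE, restore on exit.
    saved = [dict(m.scaling) for m in modules]
    for m in modules:
        for name in m.active_adapters:
            m.scaling[name] = SCALE if name == specified else 0.0
    try:
        yield
    finally:
        for m, scales in zip(modules, saved):
            m.scaling.update(scales)

mod = ToyLoraModule()
with specify_adapter([mod], "depth"):
    assert mod.scaling == {"subject": 0.0, "depth": 0.5}
assert mod.scaling == {"subject": 1.0, "depth": 1.0}  # original scales restored
```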
omini/rotation/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
from .rotation_config import RotationConfig
|
| 2 |
+
from .layer import RotationLayer
|
| 3 |
+
from .model import RotationTuner
|
omini/rotation/layer.py
ADDED
|
@@ -0,0 +1,313 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from typing import Optional, Set
|
| 4 |
+
|
| 5 |
+
from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
|
| 6 |
+
|
| 7 |
+
def inverse_2x2(matrices):
|
| 8 |
+
|
| 9 |
+
# Extract matrix elements
|
| 10 |
+
# matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
|
| 11 |
+
a = matrices[..., 0, 0]
|
| 12 |
+
b = matrices[..., 0, 1]
|
| 13 |
+
c = matrices[..., 1, 0]
|
| 14 |
+
d = matrices[..., 1, 1]
|
| 15 |
+
|
| 16 |
+
# Compute determinant
|
| 17 |
+
det = a * d - b * c
|
| 18 |
+
|
| 19 |
+
# Compute inverse using the formula:
|
| 20 |
+
# inv = (1/det) * [[d, -b], [-c, a]]
|
| 21 |
+
inv_det = 1.0 / det
|
| 22 |
+
|
| 23 |
+
# Create output tensor
|
| 24 |
+
inv_matrices = torch.empty_like(matrices)
|
| 25 |
+
inv_matrices[..., 0, 0] = d * inv_det
|
| 26 |
+
inv_matrices[..., 0, 1] = -b * inv_det
|
| 27 |
+
inv_matrices[..., 1, 0] = -c * inv_det
|
| 28 |
+
inv_matrices[..., 1, 1] = a * inv_det
|
| 29 |
+
|
| 30 |
+
return inv_matrices
|
| 31 |
+
|
| 32 |
+
class Rotation(nn.Module):
|
| 33 |
+
"""
|
| 34 |
+
Rotation layer based on Cayley transformation for parameter-efficient fine-tuning.
|
| 35 |
+
|
| 36 |
+
This layer implements orthogonal fine-tuning through Cayley transformation:
|
| 37 |
+
h(x) = (I - A)^{-1} (I + A) x
|
| 38 |
+
|
| 39 |
+
where A = XY^T with X = [U; -V] and Y = [V; U]
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, r, dim, T=1.0, num_rotations=4):
|
| 43 |
+
super().__init__()
|
| 44 |
+
self.r = r
|
| 45 |
+
self.T = T
|
| 46 |
+
self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
|
| 47 |
+
self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
|
| 48 |
+
self.num_rotations = num_rotations
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def forward(self, x):
|
| 52 |
+
"""
|
| 53 |
+
Apply Cayley transformation to input x.
|
| 54 |
+
|
| 55 |
+
A = XY^T where X = [U; -V], Y = [V; U]
|
| 56 |
+
Cayley transformation: h(x) = (I - A)^{-1} (I + A) x
|
| 57 |
+
|
| 58 |
+
Uses Woodbury identity for efficient computation:
|
| 59 |
+
(I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
x: Input tensor of shape (..., dim)
|
| 63 |
+
|
| 64 |
+
Returns:
|
| 65 |
+
Transformed tensor of shape (..., dim)
|
| 66 |
+
"""
|
| 67 |
+
x_dtype = x.dtype
|
| 68 |
+
X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
|
| 69 |
+
Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
|
| 70 |
+
|
| 71 |
+
Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
|
| 72 |
+
I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
|
| 73 |
+
I_minus_YX = I_2r - Y_T_X
|
| 74 |
+
|
| 75 |
+
if self.r == 1:
|
| 76 |
+
I_minus_YX_inv = inverse_2x2(I_minus_YX)
|
| 77 |
+
else:
|
| 78 |
+
# make it float32
|
| 79 |
+
I_minus_YX = I_minus_YX.to(torch.float32)
|
| 80 |
+
I_minus_YX_inv = torch.linalg.inv(I_minus_YX) # Shape: (num_rotations, 2r, 2r)
|
| 81 |
+
I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)
|
| 82 |
+
|
| 83 |
+
Yx = torch.einsum("...d,nrd->...nr", x, Y) # Shape: (batch*seq_len, num_rotations, 2r)
|
| 84 |
+
I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)  # per-rotation matrix-vector product (repeated "rr" would only take the diagonal)
|
| 85 |
+
|
| 86 |
+
second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X) # Shape: (batch*seq_len, num_rotations, dim)
|
| 87 |
+
second_term = second_term.sum(dim=-2) # Sum over rotations
|
| 88 |
+
|
| 89 |
+
output = x + 2 * second_term # Shape: (batch*seq_len, dim)
|
| 90 |
+
|
| 91 |
+
return output
|
| 92 |
+
|
| 93 |
+
def get_delta_weight(self):
|
| 94 |
+
"""
|
| 95 |
+
Compute the delta weight matrix induced by the rotation layer.
|
| 96 |
+
|
| 97 |
+
Returns:
|
| 98 |
+
Delta weight matrix of shape (dim, dim)
|
| 99 |
+
"""
|
| 100 |
+
X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
|
| 101 |
+
Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
|
| 102 |
+
|
| 103 |
+
Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
|
| 104 |
+
I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
|
| 105 |
+
I_minus_YX = I_2r - Y_T_X
|
| 106 |
+
|
| 107 |
+
if self.r == 1:
|
| 108 |
+
I_minus_YX_inv = inverse_2x2(I_minus_YX)
|
| 109 |
+
I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
|
| 110 |
+
else:
|
| 111 |
+
I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32)) # Shape: (num_rotations, 2r, dim)
|
| 112 |
+
I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)
|
| 113 |
+
|
| 114 |
+
# I_minus_YX_float = I_minus_YX.float()
|
| 115 |
+
# I_minus_YX_inv = torch.linalg.inv(I_minus_YX_float) # Shape: (num_rotations, 2r, 2r)
|
| 116 |
+
# I_minus_YX_inv = I_minus_YX_inv.to(X.dtype)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
|
| 120 |
+
second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y) # Shape: (num_rotations, dim, dim)
|
| 121 |
+
second_term = second_term.sum(dim=0)
|
| 122 |
+
total_delta_weight = 2 * second_term
|
| 123 |
+
return total_delta_weight
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class RotationLayer(BaseTunerLayer):
|
| 127 |
+
"""
|
| 128 |
+
Adapter-like wrapper that attaches Rotation modules to a base linear layer.
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
adapter_layer_names: tuple[str, ...] = ("rotation",)
|
| 132 |
+
other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")
|
| 133 |
+
|
| 134 |
+
def __init__(self, base_layer: nn.Module, **kwargs):
|
| 135 |
+
# Let BaseTunerLayer do its init (it usually subclasses nn.Module)
|
| 136 |
+
super().__init__()
|
| 137 |
+
# store base layer and adapter containers
|
| 138 |
+
self.base_layer = base_layer
|
| 139 |
+
self.rotation = nn.ModuleDict() # mapping adapter_name -> Rotation module
|
| 140 |
+
self.scaling={} # default scaling per adapter
|
| 141 |
+
self._adapter_config = {} # store r, T, num_rotations per adapter
|
| 142 |
+
|
| 143 |
+
# flags (exposed in a simple way)
|
| 144 |
+
self._disable_adapters = False
|
| 145 |
+
self.merged_adapters: list[str] = []
|
| 146 |
+
self._cast_input_dtype_enabled = True
|
| 147 |
+
self.kwargs = kwargs
|
| 148 |
+
|
| 149 |
+
if isinstance(base_layer, nn.Linear):
|
| 150 |
+
self.in_features = base_layer.in_features
|
| 151 |
+
self.out_features = base_layer.out_features
|
| 152 |
+
else:
|
| 153 |
+
raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def _available_adapters(self) -> set[str]:
|
| 157 |
+
return set(self.rotation.keys())
|
| 158 |
+
|
| 159 |
+
@property
|
| 160 |
+
def disable_adapters(self) -> bool:
|
| 161 |
+
return self._disable_adapters
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def merged(self) -> bool:
|
| 165 |
+
return bool(self.merged_adapters)
|
| 166 |
+
|
| 167 |
+
@property
|
| 168 |
+
def active_adapters(self) -> list[str]:
|
| 169 |
+
# If some external mechanism sets active adapters, prefer it; else use all added adapters.
|
| 170 |
+
return getattr(self, "_active_adapters", list(self.rotation.keys()))
|
| 171 |
+
|
| 172 |
+
def get_base_layer(self) -> nn.Module:
|
| 173 |
+
return self.base_layer
|
| 174 |
+
|
| 175 |
+
def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 176 |
+
if not self._cast_input_dtype_enabled:
|
| 177 |
+
return x
|
| 178 |
+
return x.to(dtype)
|
| 179 |
+
|
| 180 |
+
def update_layer(
|
| 181 |
+
self,
|
| 182 |
+
adapter_name: str,
|
| 183 |
+
r: int,
|
| 184 |
+
T: float,
|
| 185 |
+
num_rotations: int,
|
| 186 |
+
**kwargs,
|
| 187 |
+
):
|
| 188 |
+
"""
|
| 189 |
+
Add / update a rotation adapter for this layer.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
if r <= 0:
|
| 193 |
+
raise ValueError(f"r must be positive, got {r}")
|
| 194 |
+
if num_rotations <= 0:
|
| 195 |
+
raise ValueError(f"num_rotations must be positive, got {num_rotations}")
|
| 196 |
+
|
| 197 |
+
rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
|
| 198 |
+
self.rotation[adapter_name] = rot
|
| 199 |
+
self.scaling[adapter_name] = 1.0
|
| 200 |
+
self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}
|
| 201 |
+
|
| 202 |
+
# (optional) helper to set currently active adapters externally
|
| 203 |
+
def set_active_adapters(self, adapters: Optional[list[str]]):
|
| 204 |
+
if adapters is None:
|
| 205 |
+
if hasattr(self, "_active_adapters"):
|
| 206 |
+
delattr(self, "_active_adapters")
|
| 207 |
+
else:
|
| 208 |
+
self._active_adapters = adapters
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class Linear(nn.Module, RotationLayer):
|
| 212 |
+
"""
|
| 213 |
+
A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(self,
|
| 217 |
+
base_layer: nn.Linear,
|
| 218 |
+
adapter_name: str,
|
| 219 |
+
r: int,
|
| 220 |
+
T: float,
|
| 221 |
+
num_rotations: int,
|
| 222 |
+
**kwargs):
|
| 223 |
+
|
| 224 |
+
super().__init__()
|
| 225 |
+
RotationLayer.__init__(self, base_layer=base_layer, **kwargs)
|
| 226 |
+
|
| 227 |
+
self._active_adapter = adapter_name
|
| 228 |
+
|
| 229 |
+
self.update_layer(
|
| 230 |
+
adapter_name=adapter_name,
|
| 231 |
+
r=r,
|
| 232 |
+
T=T,
|
| 233 |
+
num_rotations=num_rotations,
|
| 234 |
+
**kwargs,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
def merge(self, safe_merge: bool = False, adapter_names: Optional[str] = None):
|
| 238 |
+
"""
|
| 239 |
+
Merge the adapter effect into the base layer weights:
|
| 240 |
+
W_merged = W @ R
|
| 241 |
+
where R = I + delta (delta returned by get_delta_weight()).
|
| 242 |
+
"""
|
| 243 |
+
adapter_names = check_adapters_to_merge(self, adapter_names)
|
| 244 |
+
|
| 245 |
+
if not adapter_names:
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
base_layer = self.get_base_layer()
|
| 249 |
+
orig_dtype = base_layer.weight.dtype
|
| 250 |
+
# base_layer.weight shape: (out_features, in_features)
|
| 251 |
+
W = base_layer.weight.data # (out, in)
|
| 252 |
+
|
| 253 |
+
for active_adapter in adapter_names:
|
| 254 |
+
if active_adapter not in self._available_adapters:
|
| 255 |
+
continue
|
| 256 |
+
delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
|
| 257 |
+
R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R # (in, in)
|
| 258 |
+
# merged W = W @ R
|
| 259 |
+
merged_W = W.to(R.dtype) @ R
|
| 260 |
+
if safe_merge and not torch.isfinite(merged_W).all():
|
| 261 |
+
raise ValueError("Merging resulted in non-finite weights. Aborting merge.")
|
| 262 |
+
|
| 263 |
+
base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
|
| 264 |
+
# mark merged (so unmerge can restore by inverse)
|
| 265 |
+
self.merged_adapters.append(active_adapter)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def unmerge(self):
|
| 269 |
+
"""
|
| 270 |
+
Reverse merges in LIFO order (pop merged adapters and invert R).
|
| 271 |
+
"""
|
| 272 |
+
base_layer = self.get_base_layer()
|
| 273 |
+
orig_dtype = base_layer.weight.dtype
|
| 274 |
+
|
| 275 |
+
while self.merged_adapters:
|
| 276 |
+
active_adapter = self.merged_adapters.pop()
|
| 277 |
+
if active_adapter not in self._available_adapters:
|
| 278 |
+
continue
|
| 279 |
+
delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
|
| 280 |
+
R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
|
| 281 |
+
R_inv = torch.linalg.inv(R)
|
| 282 |
+
merged_W = base_layer.weight.data.to(R.dtype)
|
| 283 |
+
unmerged_W = merged_W @ R_inv
|
| 284 |
+
base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 288 |
+
x_dtype = x.dtype
|
| 289 |
+
base_layer = self.get_base_layer()
|
| 290 |
+
|
| 291 |
+
if self.disable_adapters:
|
| 292 |
+
# if merged, unmerge to ensure base_layer produces original behavior
|
| 293 |
+
if self.merged:
|
| 294 |
+
self.unmerge()
|
| 295 |
+
return base_layer(x, *args, **kwargs).to(x_dtype)
|
| 296 |
+
|
| 297 |
+
if self.merged:
|
| 298 |
+
# if merged into base layer, just forward
|
| 299 |
+
return base_layer(x, *args, **kwargs).to(x_dtype)
|
| 300 |
+
|
| 301 |
+
# otherwise apply active adapters (transform inputs) then call base layer
|
| 302 |
+
for active_adapter in self.active_adapters:
|
| 303 |
+
if active_adapter not in self.rotation:
|
| 304 |
+
continue
|
| 305 |
+
rotation = self.rotation[active_adapter]
|
| 306 |
+
x = self._cast_input_dtype(x, rotation.U.dtype)
|
| 307 |
+
x = rotation(x)
|
| 308 |
+
|
| 309 |
+
return base_layer(x, *args, **kwargs).to(x_dtype)
|
| 310 |
+
|
| 311 |
+
def __repr__(self):
|
| 312 |
+
return f"rotation.{super().__repr__()}"
|
| 313 |
+
|
omini/rotation/layer_test.py
ADDED
|
@@ -0,0 +1,296 @@
|
import torch
import torch.nn as nn
from omini.rotation.layer import Linear, Rotation


def test_rotation_merge():
    """
    Test that merging the rotation adapter produces the same output as the unmerged version.
    """
    print("=" * 60)
    print("Testing Rotation Layer Merge")
    print("=" * 60)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Configuration
    in_features = 512
    out_features = 1024
    r = 4
    num_rotations = 4
    T = 1.0
    batch_size = 8
    seq_len = 16

    # Create base linear layer
    base_layer = nn.Linear(in_features, out_features, bias=True)

    # Create rotation layer
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=T,
        num_rotations=num_rotations,
    )

    # Create random input
    x = torch.randn(batch_size, seq_len, in_features)

    # Test 1: Forward pass before merge
    print("\n" + "-" * 60)
    print("Test 1: Computing output BEFORE merge")
    print("-" * 60)
    rotation_layer.eval()
    with torch.no_grad():
        output_before = rotation_layer(x)

    print(f"Output shape: {output_before.shape}")
    print(f"Output mean: {output_before.mean().item():.6f}")
    print(f"Output std: {output_before.std().item():.6f}")
    print(f"Output min: {output_before.min().item():.6f}")
    print(f"Output max: {output_before.max().item():.6f}")

    # Save original weight for verification
    original_weight = base_layer.weight.data.clone()

    # Test 2: Merge adapter
    print("\n" + "-" * 60)
    print("Test 2: Merging adapter")
    print("-" * 60)
    rotation_layer.merge(safe_merge=True, adapter_names=["default"])
    print("✓ Adapter merged successfully")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    # Check that weights have changed
    weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight change: {weight_diff:.6e}")

    # Test 3: Forward pass after merge
    print("\n" + "-" * 60)
    print("Test 3: Computing output AFTER merge")
    print("-" * 60)
    with torch.no_grad():
        output_after = rotation_layer(x)

    print(f"Output shape: {output_after.shape}")
    print(f"Output mean: {output_after.mean().item():.6f}")
    print(f"Output std: {output_after.std().item():.6f}")
    print(f"Output min: {output_after.min().item():.6f}")
    print(f"Output max: {output_after.max().item():.6f}")

    # Test 4: Compare outputs
    print("\n" + "-" * 60)
    print("Test 4: Comparing outputs")
    print("-" * 60)

    # Compute differences
    abs_diff = (output_after - output_before).abs()
    rel_diff = abs_diff / (output_before.abs() + 1e-8)

    max_abs_diff = abs_diff.max().item()
    mean_abs_diff = abs_diff.mean().item()
    max_rel_diff = rel_diff.max().item()
    mean_rel_diff = rel_diff.mean().item()

    print(f"Max absolute difference: {max_abs_diff:.6e}")
    print(f"Mean absolute difference: {mean_abs_diff:.6e}")
    print(f"Max relative difference: {max_rel_diff:.6e}")
    print(f"Mean relative difference: {mean_rel_diff:.6e}")

    # Check if outputs are close
    atol = 1e-4  # Absolute tolerance
    rtol = 1e-3  # Relative tolerance

    are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

    if are_close:
        print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
    else:
        print("\n❌ FAIL: Outputs differ significantly")
        print(f"   Expected: atol < {atol}, rtol < {rtol}")
        print(f"   Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")

    # Test 5: Unmerge and verify
    print("\n" + "-" * 60)
    print("Test 5: Testing unmerge")
    print("-" * 60)
    rotation_layer.unmerge()
    print("✓ Adapter unmerged")
    print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")

    with torch.no_grad():
        output_unmerged = rotation_layer(x)

    unmerge_diff = (output_unmerged - output_before).abs().max().item()
    print(f"Max difference after unmerge: {unmerge_diff:.6e}")

    unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
    if unmerge_close:
        print("✅ PASS: Unmerge restored original behavior")
    else:
        print("❌ FAIL: Unmerge did not restore original behavior")

    # Test 6: Verify weight restoration
    weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
    print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")

    weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
    if weight_restored:
        print("✅ PASS: Original weights restored")
    else:
        print("❌ FAIL: Original weights not fully restored")

    print("\n" + "=" * 60)
    print("Test Summary")
    print("=" * 60)
    return are_close and unmerge_close and weight_restored


def test_multiple_merges():
    """
    Test merging and unmerging multiple times.
    """
    print("\n" + "=" * 60)
    print("Testing Multiple Merge/Unmerge Cycles")
    print("=" * 60)

    torch.manual_seed(42)

    in_features = 256
    out_features = 512
    r = 4
    num_rotations = 4

    base_layer = nn.Linear(in_features, out_features, bias=True)
    rotation_layer = Linear(
        base_layer=base_layer,
        adapter_name="default",
        r=r,
        T=1.0,
        num_rotations=num_rotations,
    )

    x = torch.randn(4, 8, in_features)
    rotation_layer.eval()

    # Get original output
    with torch.no_grad():
        original_output = rotation_layer(x)

    # Test multiple cycles
    all_passed = True
    for cycle in range(3):
        print(f"\nCycle {cycle + 1}:")

        # Merge
        rotation_layer.merge(safe_merge=True)
        with torch.no_grad():
            merged_output = rotation_layer(x)

        merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
        print(f"  Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")

        # Unmerge
        rotation_layer.unmerge()
        with torch.no_grad():
            unmerged_output = rotation_layer(x)

        unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
        print(f"  Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")

        all_passed = all_passed and merge_close and unmerge_close

    return all_passed


def test_with_different_dtypes():
    """
    Test merging with different data types.
    """
    print("\n" + "=" * 60)
    print("Testing Different Data Types")
    print("=" * 60)

    torch.manual_seed(42)

    dtypes = [torch.float32, torch.float16, torch.bfloat16]
    all_passed = True

    for dtype in dtypes:
        print(f"\nTesting with dtype: {dtype}")

        in_features = 256
        out_features = 512
        r = 4
        num_rotations = 4

        base_layer = nn.Linear(in_features, out_features, bias=True)
        base_layer = base_layer.to(dtype)

        rotation_layer = Linear(
            base_layer=base_layer,
            adapter_name="default",
            r=r,
            T=1.0,
            num_rotations=num_rotations,
        )
        rotation_layer = rotation_layer.to(dtype)

        x = torch.randn(4, 8, in_features, dtype=dtype)
        rotation_layer.eval()

        with torch.no_grad():
            output_before = rotation_layer(x)
            rotation_layer.merge(safe_merge=True)
            output_after = rotation_layer(x)

        # Adjust tolerances based on dtype
        if dtype == torch.float32:
            atol, rtol = 1e-5, 1e-4
        elif dtype == torch.float16:
            atol, rtol = 1e-2, 1e-2
        else:  # bfloat16
            atol, rtol = 1e-2, 1e-2

        are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)

        if are_close:
            print("  ✅ PASS")
        else:
            max_diff = (output_after - output_before).abs().max().item()
            print(f"  ❌ FAIL (max diff: {max_diff:.6e})")

        all_passed = all_passed and are_close

    return all_passed


if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("ROTATION LAYER MERGE TEST SUITE")
    print("=" * 60)

    results = {}

    # Run all tests
    results["basic_merge"] = test_rotation_merge()
    results["multiple_cycles"] = test_multiple_merges()
    results["different_dtypes"] = test_with_different_dtypes()

    # Print summary
    print("\n" + "=" * 60)
    print("FINAL SUMMARY")
    print("=" * 60)

    for test_name, passed in results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{test_name}: {status}")

    all_passed = all(results.values())
    print("\n" + "=" * 60)
    if all_passed:
        print("🎉 ALL TESTS PASSED!")
    else:
        print("⚠️ SOME TESTS FAILED")
    print("=" * 60)
omini/rotation/model.py
ADDED
@@ -0,0 +1,390 @@
from typing import Optional

import torch
import torch.nn as nn
from enum import Enum
from dataclasses import asdict
from tqdm import tqdm

from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules

from .layer import RotationLayer, Linear

TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()


class RotationTuner(BaseTuner):

    prefix: str = "rotation_"
    tuner_layer_class = RotationLayer
    target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING

    @staticmethod
    def _check_target_module_exists(rotation_config, key: str) -> bool:
        return check_target_module_exists(rotation_config, key)

    def _create_and_replace(
        self,
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        target_name: str,
        parent: nn.Module,
        current_key: str,
        **optional_kwargs,
    ) -> None:
        """
        Create and replace a target module with a rotation-augmented version.

        If the target is already a RotationLayer, a new adapter is added to it;
        otherwise the target is replaced with a new rotation-augmented module.

        Args:
            rotation_config: Configuration for the rotation adapter
            adapter_name: Name of the adapter to add
            target: The target module to augment
            target_name: Name of the target module
            parent: Parent module containing the target
            current_key: Full key path to the current module
            **optional_kwargs: Additional optional arguments

        Raises:
            ValueError: If current_key is not provided
        """
        if current_key is None:
            raise ValueError("current_key must be provided to create Rotation layer")

        # Check if target is already a RotationLayer
        if isinstance(target, RotationLayer):
            target.update_layer(
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
            )
        else:
            # Create new rotation layer
            new_module = self._create_new_module(
                rotation_config=rotation_config,
                adapter_name=adapter_name,
                target=target,
                **optional_kwargs,
            )
            if new_module is not None:
                self._replace_module(parent, target_name, new_module, target)

    def _replace_module(self, parent, child_name, new_module, child):
        setattr(parent, child_name, new_module)

        # child layer wraps the original module, unpack it
        if hasattr(child, "base_layer"):
            child = child.base_layer

        meta = torch.device("meta")
        # dispatch to correct device
        for name, module in new_module.named_modules():
            if (self.prefix in name) or ("ranknum" in name):
                if hasattr(child, "qweight"):
                    weight = child.qweight
                elif hasattr(child, "W_q"):
                    weight = child.W_q
                elif hasattr(child, "weight"):
                    weight = child.weight
                elif getattr(child, "in_proj_weight", None) is not None:  # MHA
                    weight = child.in_proj_weight
                else:
                    weight = next(child.parameters())
                if not any(p.device == meta for p in module.parameters()):
                    module.to(weight.device)

    def _mark_only_adapters_as_trainable(self, model):
        # First, freeze everything except the rotation adapter parameters
        for n, p in model.named_parameters():
            if self.prefix not in n:
                p.requires_grad = False
            else:
                p.requires_grad = True

        # Handle bias parameters based on config
        for active_adapter in self.active_adapters:
            bias_config = self.peft_config[active_adapter].bias

            if bias_config == "none":
                continue
            elif bias_config == "all":
                # Enable all bias parameters
                for n, p in model.named_parameters():
                    if "bias" in n:
                        p.requires_grad = True
            elif bias_config == "rotation_only":
                # Enable only bias in rotation layers
                for name, m in model.named_modules():
                    if isinstance(m, RotationLayer):
                        if hasattr(m, "bias") and m.bias is not None:
                            m.bias.requires_grad = True
            else:
                raise NotImplementedError(
                    f"Requested bias configuration '{bias_config}' is not implemented. "
                    f"Supported values: 'none', 'all', 'rotation_only'"
                )

    @staticmethod
    def _create_new_module(
        rotation_config,
        adapter_name: str,
        target: nn.Module,
        **kwargs,
    ) -> Optional[nn.Module]:
        """
        Create a new rotation-augmented module.

        Args:
            rotation_config: Configuration for the rotation adapter
            adapter_name: Name of the adapter
            target: Base module to augment
            **kwargs: Additional arguments

        Returns:
            New RotationLayer module wrapping the target, or None if unsupported
        """
        if isinstance(target, nn.Linear):
            return Linear(
                base_layer=target,
                adapter_name=adapter_name,
                r=rotation_config.r,
                T=rotation_config.T,
                num_rotations=rotation_config.num_rotations,
                **kwargs,
            )
        else:
            # Unsupported layer type
            print(
                f"Rotation layer does not support {type(target).__name__} yet. "
                f"Skipping this module."
            )
            return None

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            if name == "model":  # see #1892: prevent infinite recursion if class is not initialized
                raise
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        return config_dict

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
                module.enable_adapters(enabled)

    def enable_adapter_layers(self) -> None:
        """Enable all adapters.

        Call this if you have previously disabled all adapters and want to re-enable them.
        """
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        for active_adapter in self.active_adapters:
            val = self.peft_config[active_adapter].bias
            if val != "none":
                msg = (
                    f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
                    "output as the base model would without adaptation."
                )
                print(msg)
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        """Set the active adapter(s).

        Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
        not desired, use the following code.

        ```py
        >>> for name, param in model_peft.named_parameters():
        ...     if ...:  # some check on name (ex. if 'lora' in name)
        ...         param.requires_grad = False
        ```

        Args:
            adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                if module.merged:
                    print("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.set_adapter(adapter_name)
        self.active_adapter = adapter_name

    def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
        """
        Merge adapter weights into the base model weights.

        This can speed up inference by eliminating the need for runtime
        rotation computations.

        Args:
            adapter_names: List of adapter names to merge. If None, merges all
                active adapters.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.merge(safe_merge=False, adapter_names=adapter_names)

    def unmerge_adapter(self) -> None:
        """
        Unmerge adapter weights from the base model weights.

        This reverses the merge operation, restoring dynamic adapter behavior.
        """
        for module in self.model.modules():
            if isinstance(module, RotationLayer):
                module.unmerge()

    @staticmethod
    def _prepare_adapter_config(peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = set(
                TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
            )
        return peft_config

    def _check_new_adapter_config(self, config) -> None:
        """
        Check the validity of a new adapter configuration.

        Args:
            config: Configuration to validate

        Raises:
            ValueError: If configuration is invalid
        """
        # Validate rank
        if config.r <= 0:
            raise ValueError(f"r must be positive, got {config.r}")

        # Validate num_rotations
        if config.num_rotations <= 0:
            raise ValueError(
                f"num_rotations must be positive, got {config.num_rotations}"
            )

        # Validate bias configuration
        valid_bias_configs = ["none", "all", "rotation_only"]
        if hasattr(config, "bias") and config.bias not in valid_bias_configs:
            raise ValueError(
                f"Invalid bias configuration '{config.bias}'. "
                f"Must be one of {valid_bias_configs}"
            )

    def _unload_and_optionally_merge(
        self,
        merge=True,
        progressbar: bool = False,
        safe_merge: bool = False,
        adapter_names: Optional[list[str]] = None,
    ):
        if merge:
            self._check_merge_allowed()

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        desc = "Unloading " + ("and merging " if merge else "") + "model"
        for key in tqdm(key_list, disable=not progressbar, desc=desc):
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
            with onload_layer(target):
                if hasattr(target, "unload_and_optionally_merge_module"):
                    # if layers have a special unloading method, like MultiheadAttention, use that
                    unloaded_module = target.unload_and_optionally_merge_module(
                        merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
                    )
                    self._replace_module(parent, target_name, unloaded_module, target)
                elif hasattr(target, "base_layer"):
                    if merge:
                        target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
                    self._replace_module(parent, target_name, target.get_base_layer(), target)

        return self.model

    def delete_adapter(self, adapter_name: str) -> None:
        """
        Deletes an existing adapter.

        Args:
            adapter_name (str): Name of the adapter to be deleted.
        """
        if adapter_name not in list(self.peft_config.keys()):
            raise ValueError(f"Adapter {adapter_name} does not exist")
        del self.peft_config[adapter_name]

        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
        new_adapter = None
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, RotationLayer):
                target.delete_adapter(adapter_name)
                if new_adapter is None:
                    new_adapter = target.active_adapters[:]

        self.active_adapter = new_adapter or []
        self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)

    def merge_and_unload(
        self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
    ) -> torch.nn.Module:
        r"""
        This method merges the rotation layers into the base model. This is needed if someone wants to use the base
        model as a standalone model.

        Args:
            progressbar (`bool`):
                whether to show a progressbar indicating the unload and merge process
            safe_merge (`bool`):
                whether to activate the safe merging check to check if there is any potential NaN in the adapter
                weights
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
                to `None`.
        """
        return self._unload_and_optionally_merge(
            progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
        )

    def unload(self) -> torch.nn.Module:
        """
        Gets back the base model by removing all the rotation modules without merging. This gives back the original
        base model.
        """
        return self._unload_and_optionally_merge(merge=False)
omini/rotation/rotation_config.py
ADDED
@@ -0,0 +1,81 @@
from dataclasses import dataclass, field
from typing import List, Optional
from peft.config import PeftConfig


@dataclass
class RotationConfig(PeftConfig):
    """
    Configuration class for Rotation-based Parameter-Efficient Fine-Tuning.

    This configuration stores all parameters needed to apply the Rotation method
    (based on the Cayley transformation) to a model's linear layers.

    Args:
        r (`int`):
            The rank parameter for the low-rank approximation in rotation matrices.
        T (`float`, *optional*, defaults to 1.0):
            Temperature parameter for the Cayley transformation.
        num_rotations (`int`, *optional*, defaults to 4):
            Number of rotation matrices to use in parallel.
        target_modules (`Union[List[str], str]`):
            Module names to apply rotation to (e.g., ["q_proj", "v_proj"]).
        target_modules_to_skip (`Union[List[str], str]`, *optional*):
            Module names to skip when applying rotation.
        modules_to_save (`Union[List[str], str]`, *optional*):
            Modules to save in addition to rotation parameters.
        bias (`str`, *optional*, defaults to "none"):
            Bias training configuration. One of 'none', 'all', 'rotation_only'.
        layers_to_transform (`Union[List[int], int]`, *optional*):
            Layers to transform. If None, all layers matching target_modules are transformed.
    """

    peft_type: str = field(default="ROTATION", init=False)
    target_modules: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of module names to apply rotation to (e.g., ['q_proj', 'v_proj', 'linear'])"
        },
    )
    target_modules_to_skip: Optional[List[str]] = field(
        default=None,
        metadata={"help": "List of module names to skip when applying rotation"},
    )
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={"help": "List of modules to save in addition to rotation parameters"},
    )
    r: int = field(
        default=8,
        metadata={"help": "Rank parameter for low-rank approximation"},
    )
    T: float = field(
        default=1.0,
        metadata={"help": "Temperature parameter for Cayley transformation"},
    )
    num_rotations: int = field(
        default=4,
        metadata={"help": "Number of rotation matrices to use in parallel"},
    )
    bias: str = field(
        default="none",
        metadata={
            "help": "Bias training configuration. Options: 'none', 'all', 'rotation_only'"
        },
    )
    layers_to_transform: Optional[List[int]] = field(
        default=None,
        metadata={"help": "Layers to transform. If None, all matching layers are transformed"},
    )

    def __post_init__(self):
        self.peft_type = "ROTATION"
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
        self.target_modules_to_skip = (
            set(self.target_modules_to_skip)
            if isinstance(self.target_modules_to_skip, list)
            else self.target_modules_to_skip
        )
omini/train_flux/train_custom.py
ADDED
@@ -0,0 +1,50 @@
import torch
import os
import random
from torch.utils.data import DataLoader, Dataset

from PIL import Image

from datasets import load_dataset

from .trainer import OminiModel, get_config, train
from ..pipeline.flux_omini import Condition, generate


class CustomDataset(Dataset):
    def __len__(self):
        # TODO: Return the number of samples in your custom dataset
        raise NotImplementedError("Custom dataset length not implemented")

    def __getitem__(self, idx):
        # TODO: Implement the logic to load your custom dataset
        raise NotImplementedError("Custom dataset loading not implemented")


@torch.no_grad()
def test_function(model, save_path, file_name):
    # TODO: Implement the logic to generate a sample using the model
    raise NotImplementedError("Sample generation not implemented")


def main():
    # Initialize
    config = get_config()
    training_config = config["train"]
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Initialize custom dataset
    dataset = CustomDataset()

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
    )

    train(dataset, trainable_model, config, test_function)


if __name__ == "__main__":
    main()
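The two TODOs above are the only pieces a user has to fill in. Judging from the fields returned by `ImageMultiConditionDataset` in `train_multi_condition.py` below, a single-condition `__getitem__` could be sketched as follows; the file-list layout, the prompt source, the `"canny"` condition type, and the exact key names the trainer expects are assumptions, not part of the upload.

```python
# Hypothetical filling of the CustomDataset stub. Field names follow the
# return_dict produced by ImageMultiConditionDataset; everything else is assumed.
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class MyImageConditionDataset(Dataset):
    def __init__(self, samples, target_size=(512, 512), condition_size=(512, 512)):
        # samples: list of (image_path, condition_path, prompt) triples
        self.samples = samples
        self.target_size = target_size
        self.condition_size = condition_size
        self.to_tensor = T.ToTensor()

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, condition_path, prompt = self.samples[idx]
        image = Image.open(image_path).convert("RGB").resize(self.target_size)
        condition = Image.open(condition_path).convert("RGB").resize(self.condition_size)
        return {
            "image": self.to_tensor(image),
            "description": prompt,
            "condition_0": self.to_tensor(condition),
            "condition_type_0": "canny",
            "position_delta_0": [0, 0],
            "position_scale_0": 1.0,
        }
```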
omini/train_flux/train_multi_condition.py
ADDED
@@ -0,0 +1,160 @@
import torch
import os
import random

from PIL import Image, ImageDraw

from datasets import load_dataset

from .trainer import OminiModel, get_config, train
from ..pipeline.flux_omini import Condition, convert_to_condition, generate
from .train_spatial_alignment import ImageConditionDataset


class ImageMultiConditionDataset(ImageConditionDataset):
    def __getitem__(self, idx):
        image = self.base_dataset[idx]["jpg"]
        image = image.resize(self.target_size).convert("RGB")
        description = self.base_dataset[idx]["json"]["prompt"]

        condition_size = self.condition_size
        position_scale = self.position_scale

        condition_imgs, position_deltas = [], []
        for c_type in self.condition_type:
            condition_img, position_delta = self.__get_condition__(image, c_type)
            condition_imgs.append(condition_img.convert("RGB"))
            position_deltas.append(position_delta)

        # Randomly drop text or image (for training)
        drop_text = random.random() < self.drop_text_prob
        drop_image = random.random() < self.drop_image_prob

        if drop_text:
            description = ""
        if drop_image:
            condition_imgs = [
                Image.new("RGB", condition_size)
                for _ in range(len(self.condition_type))
            ]

        return_dict = {
            "image": self.to_tensor(image),
            "description": description,
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
        }

        for i, c_type in enumerate(self.condition_type):
            return_dict[f"condition_{i}"] = self.to_tensor(condition_imgs[i])
            return_dict[f"condition_type_{i}"] = self.condition_type[i]
            return_dict[f"position_delta_{i}"] = position_deltas[i]
            return_dict[f"position_scale_{i}"] = position_scale

        return return_dict


@torch.no_grad()
def test_function(model, save_path, file_name):
    condition_size = model.training_config["dataset"]["condition_size"]
    target_size = model.training_config["dataset"]["target_size"]

    position_delta = model.training_config["dataset"].get("position_delta", [0, 0])
    position_scale = model.training_config["dataset"].get("position_scale", 1.0)

    condition_type = model.training_config["condition_type"]
    test_list = []

    condition_list = []
    for i, c_type in enumerate(condition_type):
        if c_type in ["canny", "coloring", "deblurring", "depth"]:
            image = Image.open("assets/vase_hq.jpg")
            image = image.resize(condition_size)
            condition_img = convert_to_condition(c_type, image, 5)
        elif c_type == "fill":
            # load the reference image here as well, so "fill" also works as the first condition
            image = Image.open("assets/vase_hq.jpg")
            condition_img = image.resize(condition_size).convert("RGB")
            w, h = image.size
            x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
            y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
            mask = Image.new("L", image.size, 0)
            draw = ImageDraw.Draw(mask)
            draw.rectangle([x1, y1, x2, y2], fill=255)
            if random.random() > 0.5:
                mask = Image.eval(mask, lambda a: 255 - a)
            condition_img = Image.composite(
                image, Image.new("RGB", image.size, (0, 0, 0)), mask
            )
        else:
            raise NotImplementedError
        condition = Condition(
            condition_img,
            model.adapter_names[i + 2],
            position_delta,
            position_scale,
        )
        condition_list.append(condition)
    test_list.append((condition_list, "A beautiful vase on a table."))
    os.makedirs(save_path, exist_ok=True)
    for i, (condition, prompt) in enumerate(test_list):
        generator = torch.Generator(device=model.device)
        generator.manual_seed(42)

        res = generate(
            model.flux_pipe,
            prompt=prompt,
            conditions=condition_list,
            height=target_size[0],
            width=target_size[1],
            generator=generator,
            model_config=model.model_config,
            kv_cache=model.model_config.get("independent_condition", False),
        )
        file_path = os.path.join(
            save_path, f"{file_name}_{'|'.join(condition_type)}_{i}.jpg"
        )
        res.images[0].save(file_path)


def main():
    # Initialize
    config = get_config()
    training_config = config["train"]
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # Initialize dataset
    dataset = load_dataset(
        "webdataset",
        data_files={"train": training_config["dataset"]["urls"]},
        split="train",
        cache_dir="cache/t2i2m",
        num_proc=32,
    )
    dataset = ImageMultiConditionDataset(
        dataset,
        condition_size=training_config["dataset"]["condition_size"],
        target_size=training_config["dataset"]["target_size"],
        condition_type=training_config["condition_type"],
        drop_text_prob=training_config["dataset"]["drop_text_prob"],
        drop_image_prob=training_config["dataset"]["drop_image_prob"],
        position_scale=training_config["dataset"].get("position_scale", 1.0),
    )

    cond_n = len(training_config["condition_type"])

    # Initialize model
    trainable_model = OminiModel(
        flux_pipe_id=config["flux_path"],
        lora_config=training_config["lora_config"],
        device="cuda",
        dtype=getattr(torch, config["dtype"]),
        optimizer_config=training_config["optimizer"],
        model_config=config.get("model", {}),
        gradient_checkpointing=training_config.get("gradient_checkpointing", False),
        adapter_names=[None, None, *["default"] * cond_n],
        # In this setting, all the conditions are using the same LoRA adapter
    )

    train(dataset, trainable_model, config, test_function)


if __name__ == "__main__":
    main()
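For reference, `main()` and `test_function()` above read a fixed set of keys through `get_config()`. The sketch below lays out that structure as a plain Python dict; only the key names are taken from the accesses in the script, while the concrete values, the FLUX model id, and the assumption that `get_config()` loads this from a YAML file are placeholders.

```python
# Key layout inferred from the config accesses in train_multi_condition.py;
# values are placeholders, and get_config() presumably reads this from YAML.
config = {
    "flux_path": "black-forest-labs/FLUX.1-dev",  # assumed pipeline id
    "dtype": "bfloat16",
    "model": {"independent_condition": False},
    "train": {
        "condition_type": ["canny", "fill"],
        "lora_config": {},   # adapter settings passed through to OminiModel
        "optimizer": {},     # optimizer settings passed through to OminiModel
        "gradient_checkpointing": True,
        "dataset": {
            "urls": "path/to/shards-{0000..0099}.tar",  # webdataset shard pattern
            "condition_size": [512, 512],
            "target_size": [512, 512],
            "drop_text_prob": 0.1,
            "drop_image_prob": 0.1,
            "position_scale": 1.0,
            "position_delta": [0, 0],
        },
    },
}
```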