nvan15 committed on
Commit b816a2c · verified · 1 parent: 16e46c5

Batch upload part 19

Files changed (50)
  1. nl_tasks/expsBOFT/seed42/ft/special_tokens_map.json +24 -0
  2. nl_tasks/expsBOFT/seed42/ft/tokenizer.json +0 -0
  3. nl_tasks/expsBOFT/seed42/ft/tokenizer.model +3 -0
  4. nl_tasks/expsBOFT/seed42/ft/tokenizer_config.json +43 -0
  5. nl_tasks/expsBOFT/seed42/ft2/README.md +205 -0
  6. nl_tasks/expsBOFT/seed42/ft2/adapter_config.json +27 -0
  7. nl_tasks/expsBOFT/seed42/ft2/adapter_model.safetensors +3 -0
  8. nl_tasks/expsBOFT/seed42/trainer_state.json +218 -0
  9. nl_tasks/expsBOFT/seed43/ft/special_tokens_map.json +24 -0
  10. nl_tasks/expsBOFT/seed43/ft/tokenizer.json +0 -0
  11. nl_tasks/expsBOFT/seed43/ft/tokenizer.model +3 -0
  12. nl_tasks/expsBOFT/seed43/ft/tokenizer_config.json +43 -0
  13. nl_tasks/expsBOFT/seed43/ft2/README.md +205 -0
  14. nl_tasks/expsBOFT/seed43/ft2/adapter_config.json +27 -0
  15. nl_tasks/expsBOFT/seed43/ft2/adapter_model.safetensors +3 -0
  16. nl_tasks/expsOFT/seed42/ft/special_tokens_map.json +24 -0
  17. nl_tasks/expsOFT/seed42/ft/tokenizer.json +0 -0
  18. nl_tasks/expsOFT/seed42/ft/tokenizer.model +3 -0
  19. nl_tasks/expsOFT/seed42/ft/tokenizer_config.json +43 -0
  20. nl_tasks/expsOFT/seed42/ft2/README.md +205 -0
  21. nl_tasks/expsOFT/seed42/ft2/adapter_config.json +31 -0
  22. nl_tasks/expsOFT/seed42/ft2/adapter_model.safetensors +3 -0
  23. nl_tasks/expsOFT/seed42/trainer_state.json +218 -0
  24. nl_tasks/expsOFT/seed43/ft/special_tokens_map.json +24 -0
  25. nl_tasks/expsOFT/seed43/ft/tokenizer.json +0 -0
  26. nl_tasks/expsOFT/seed43/ft/tokenizer.model +3 -0
  27. nl_tasks/expsOFT/seed43/ft/tokenizer_config.json +43 -0
  28. nl_tasks/expsOFT/seed43/ft2/README.md +205 -0
  29. nl_tasks/expsOFT/seed43/ft2/adapter_config.json +31 -0
  30. nl_tasks/expsOFT/seed43/ft2/adapter_model.safetensors +3 -0
  31. nl_tasks/expsOFT/seed43/trainer_state.json +218 -0
  32. nl_tasks/expsOFT/seed44/ft/special_tokens_map.json +24 -0
  33. nl_tasks/expsOFT/seed44/ft/tokenizer.json +0 -0
  34. nl_tasks/expsOFT/seed44/ft/tokenizer.model +3 -0
  35. nl_tasks/expsOFT/seed44/ft/tokenizer_config.json +43 -0
  36. nl_tasks/expsOFT/seed44/ft2/README.md +205 -0
  37. nl_tasks/expsOFT/seed44/ft2/adapter_config.json +31 -0
  38. nl_tasks/expsOFT/seed44/ft2/adapter_model.safetensors +3 -0
  39. nl_tasks/expsOFT/seed44/trainer_state.json +218 -0
  40. omini/__init__.py +0 -0
  41. omini/pipeline/flux_omini.py +734 -0
  42. omini/pipeline/flux_omini_ablate_qkv.py +772 -0
  43. omini/pipeline/flux_omini_ablate_scale.py +748 -0
  44. omini/rotation/__init__.py +3 -0
  45. omini/rotation/layer.py +313 -0
  46. omini/rotation/layer_test.py +296 -0
  47. omini/rotation/model.py +390 -0
  48. omini/rotation/rotation_config.py +81 -0
  49. omini/train_flux/train_custom.py +50 -0
  50. omini/train_flux/train_multi_condition.py +160 -0
nl_tasks/expsBOFT/seed42/ft/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<unk>",
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
nl_tasks/expsBOFT/seed42/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/expsBOFT/seed42/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
nl_tasks/expsBOFT/seed42/ft/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<s>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "</s>",
+ "extra_special_tokens": {},
+ "legacy": false,
+ "model_max_length": 512,
+ "pad_token": "<unk>",
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
nl_tasks/expsBOFT/seed42/ft2/README.md ADDED
@@ -0,0 +1,205 @@
+ ---
+ base_model: meta-llama/Llama-2-7b-hf
+ library_name: peft
+ tags:
+ - base_model:adapter:meta-llama/Llama-2-7b-hf
+ - transformers
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
+ ### Framework versions
+
+ - PEFT 0.18.0
nl_tasks/expsBOFT/seed42/ft2/adapter_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "auto_mapping": {
+ "base_model_class": "LlamaForCausalLM",
+ "parent_library": "transformers.models.llama.modeling_llama"
+ },
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
+ "bias": "none",
+ "boft_block_num": 0,
+ "boft_block_size": 16,
+ "boft_dropout": 0.05,
+ "boft_n_butterfly_factor": 2,
+ "exclude_modules": null,
+ "fan_in_fan_out": false,
+ "inference_mode": true,
+ "init_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "modules_to_save": null,
+ "peft_type": "BOFT",
+ "peft_version": "0.18.0",
+ "revision": null,
+ "target_modules": [
+ "v_proj",
+ "q_proj"
+ ],
+ "task_type": null
+ }
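Note: this adapter_config.json is a standard PEFT BOFT adapter configuration targeting the q_proj/v_proj modules of Llama-2-7B. A minimal loading sketch is shown below; the local paths, dtype, and device placement are assumptions for illustration and are not part of the upload itself.

```python
# Minimal sketch (assumed paths/settings): load the uploaded BOFT adapter on top of its base model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # base_model_name_or_path from adapter_config.json (gated model)
    torch_dtype=torch.bfloat16,   # assumption; choose what your hardware supports
    device_map="auto",
)
# Tokenizer files in this upload sit under .../ft (assumed relative path).
tokenizer = AutoTokenizer.from_pretrained("nl_tasks/expsBOFT/seed42/ft")
# adapter_config.json and adapter_model.safetensors sit under .../ft2.
model = PeftModel.from_pretrained(base, "nl_tasks/expsBOFT/seed42/ft2")
model.eval()
```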
nl_tasks/expsBOFT/seed42/ft2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:584526a06a1f45f2f77e6a89a7201b05aa25a3d6be60f231b255a32c48c4b261
+ size 34619504
nl_tasks/expsBOFT/seed42/trainer_state.json ADDED
@@ -0,0 +1,218 @@
+ {
+ "best_global_step": null,
+ "best_metric": null,
+ "best_model_checkpoint": null,
+ "epoch": 2.0,
+ "eval_steps": 500,
+ "global_step": 1250,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.08,
+ "grad_norm": 0.08375173062086105,
+ "learning_rate": 0.000392,
+ "loss": 0.5193,
+ "step": 50
+ },
+ {
+ "epoch": 0.16,
+ "grad_norm": 0.09268203377723694,
+ "learning_rate": 0.0007920000000000001,
+ "loss": 0.3316,
+ "step": 100
+ },
+ {
+ "epoch": 0.24,
+ "grad_norm": 0.08198747783899307,
+ "learning_rate": 0.0007964216926581925,
+ "loss": 0.304,
+ "step": 150
+ },
+ {
+ "epoch": 0.32,
+ "grad_norm": 0.0816216915845871,
+ "learning_rate": 0.0007854602918076551,
+ "loss": 0.2918,
+ "step": 200
+ },
+ {
+ "epoch": 0.4,
+ "grad_norm": 0.07457849383354187,
+ "learning_rate": 0.0007673184950396212,
+ "loss": 0.274,
+ "step": 250
+ },
+ {
+ "epoch": 0.48,
+ "grad_norm": 0.07685171067714691,
+ "learning_rate": 0.0007423342497022817,
+ "loss": 0.2687,
+ "step": 300
+ },
+ {
+ "epoch": 0.56,
+ "grad_norm": 0.07849128544330597,
+ "learning_rate": 0.0007109729650142636,
+ "loss": 0.2651,
+ "step": 350
+ },
+ {
+ "epoch": 0.64,
+ "grad_norm": 0.07266736030578613,
+ "learning_rate": 0.0006738188423714755,
+ "loss": 0.2575,
+ "step": 400
+ },
+ {
+ "epoch": 0.72,
+ "grad_norm": 0.06927025318145752,
+ "learning_rate": 0.0006315639927804526,
+ "loss": 0.2525,
+ "step": 450
+ },
+ {
+ "epoch": 0.8,
+ "grad_norm": 0.08536054193973541,
+ "learning_rate": 0.00058499554413983,
+ "loss": 0.2494,
+ "step": 500
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 0.07602768391370773,
+ "learning_rate": 0.000534980978536894,
+ "loss": 0.2429,
+ "step": 550
+ },
+ {
+ "epoch": 0.96,
+ "grad_norm": 0.07055249065160751,
+ "learning_rate": 0.00048245197269763485,
+ "loss": 0.2457,
+ "step": 600
+ },
+ {
+ "epoch": 1.04,
+ "grad_norm": 0.07144515216350555,
+ "learning_rate": 0.00042838704261214224,
+ "loss": 0.2292,
+ "step": 650
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 0.07937044650316238,
+ "learning_rate": 0.00037379331563313267,
+ "loss": 0.2169,
+ "step": 700
+ },
+ {
+ "epoch": 1.2,
+ "grad_norm": 0.07409252226352692,
+ "learning_rate": 0.00031968776959892677,
+ "loss": 0.2098,
+ "step": 750
+ },
+ {
+ "epoch": 1.28,
+ "grad_norm": 0.07844420522451401,
+ "learning_rate": 0.00026707828846051743,
+ "loss": 0.2145,
+ "step": 800
+ },
+ {
+ "epoch": 1.3599999999999999,
+ "grad_norm": 0.07791652530431747,
+ "learning_rate": 0.00021694488731055218,
+ "loss": 0.2082,
+ "step": 850
+ },
+ {
+ "epoch": 1.44,
+ "grad_norm": 0.0782908946275711,
+ "learning_rate": 0.00017022145655641685,
+ "loss": 0.2077,
+ "step": 900
+ },
+ {
+ "epoch": 1.52,
+ "grad_norm": 0.0826650932431221,
+ "learning_rate": 0.00012777836530893536,
+ "loss": 0.2137,
+ "step": 950
+ },
+ {
+ "epoch": 1.6,
+ "grad_norm": 0.0696156919002533,
+ "learning_rate": 9.040624805263558e-05,
+ "loss": 0.2076,
+ "step": 1000
+ },
+ {
+ "epoch": 1.6800000000000002,
+ "grad_norm": 0.06966507434844971,
+ "learning_rate": 5.880127662124091e-05,
+ "loss": 0.2108,
+ "step": 1050
+ },
+ {
+ "epoch": 1.76,
+ "grad_norm": 0.08326321095228195,
+ "learning_rate": 3.355219183361582e-05,
+ "loss": 0.2106,
+ "step": 1100
+ },
+ {
+ "epoch": 1.8399999999999999,
+ "grad_norm": 0.0792745053768158,
+ "learning_rate": 1.512933636625089e-05,
+ "loss": 0.2073,
+ "step": 1150
+ },
+ {
+ "epoch": 1.92,
+ "grad_norm": 0.07648582756519318,
+ "learning_rate": 3.8758931591217575e-06,
+ "loss": 0.209,
+ "step": 1200
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 0.0787830799818039,
+ "learning_rate": 1.4925668450960217e-09,
+ "loss": 0.2124,
+ "step": 1250
+ },
+ {
+ "epoch": 2.0,
+ "step": 1250,
+ "total_flos": 1.62594677587968e+18,
+ "train_loss": 0.25041088790893556,
+ "train_runtime": 3370.9131,
+ "train_samples_per_second": 23.732,
+ "train_steps_per_second": 0.371
+ }
+ ],
+ "logging_steps": 50,
+ "max_steps": 1250,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 2,
+ "save_steps": 0,
+ "stateful_callbacks": {
+ "TrainerControl": {
+ "args": {
+ "should_epoch_stop": false,
+ "should_evaluate": false,
+ "should_log": false,
+ "should_save": false,
+ "should_training_stop": false
+ },
+ "attributes": {}
+ }
+ },
+ "total_flos": 1.62594677587968e+18,
+ "train_batch_size": 32,
+ "trial_name": null,
+ "trial_params": null
+ }
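Note: trainer_state.json stores the Hugging Face Trainer's logged metrics under "log_history" (one record per logging step, plus a final summary record). A minimal sketch for inspecting the loss curve is below; the file path is an assumption based on this upload's layout, and the plotting choice is illustrative only.

```python
# Minimal sketch (assumed path): read log_history from trainer_state.json and plot training loss.
import json
import matplotlib.pyplot as plt

with open("nl_tasks/expsBOFT/seed42/trainer_state.json") as f:
    state = json.load(f)

# Per-step records carry a "loss" key; the final summary record carries "train_loss" instead.
records = [entry for entry in state["log_history"] if "loss" in entry]
steps = [entry["step"] for entry in records]
losses = [entry["loss"] for entry in records]

plt.plot(steps, losses)
plt.xlabel("step")
plt.ylabel("training loss")
plt.title("BOFT seed42 (logging_steps = 50)")
plt.show()
```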
nl_tasks/expsBOFT/seed43/ft/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<unk>",
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
nl_tasks/expsBOFT/seed43/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/expsBOFT/seed43/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
nl_tasks/expsBOFT/seed43/ft/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<s>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "</s>",
+ "extra_special_tokens": {},
+ "legacy": false,
+ "model_max_length": 512,
+ "pad_token": "<unk>",
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
nl_tasks/expsBOFT/seed43/ft2/README.md ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ tags:
5
+ - base_model:adapter:meta-llama/Llama-2-7b-hf
6
+ - transformers
7
+ ---
8
+
9
+ # Model Card for Model ID
10
+
11
+ <!-- Provide a quick summary of what the model is/does. -->
12
+
13
+
14
+
15
+ ## Model Details
16
+
17
+ ### Model Description
18
+
19
+ <!-- Provide a longer summary of what this model is. -->
20
+
21
+
22
+
23
+ - **Developed by:** [More Information Needed]
24
+ - **Funded by [optional]:** [More Information Needed]
25
+ - **Shared by [optional]:** [More Information Needed]
26
+ - **Model type:** [More Information Needed]
27
+ - **Language(s) (NLP):** [More Information Needed]
28
+ - **License:** [More Information Needed]
29
+ - **Finetuned from model [optional]:** [More Information Needed]
30
+
31
+ ### Model Sources [optional]
32
+
33
+ <!-- Provide the basic links for the model. -->
34
+
35
+ - **Repository:** [More Information Needed]
36
+ - **Paper [optional]:** [More Information Needed]
37
+ - **Demo [optional]:** [More Information Needed]
38
+
39
+ ## Uses
40
+
41
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
42
+
43
+ ### Direct Use
44
+
45
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
46
+
47
+ [More Information Needed]
48
+
49
+ ### Downstream Use [optional]
50
+
51
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
52
+
53
+ [More Information Needed]
54
+
55
+ ### Out-of-Scope Use
56
+
57
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
58
+
59
+ [More Information Needed]
60
+
61
+ ## Bias, Risks, and Limitations
62
+
63
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
64
+
65
+ [More Information Needed]
66
+
67
+ ### Recommendations
68
+
69
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
70
+
71
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
72
+
73
+ ## How to Get Started with the Model
74
+
75
+ Use the code below to get started with the model.
76
+
77
+ [More Information Needed]
78
+
79
+ ## Training Details
80
+
81
+ ### Training Data
82
+
83
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
84
+
85
+ [More Information Needed]
86
+
87
+ ### Training Procedure
88
+
89
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
90
+
91
+ #### Preprocessing [optional]
92
+
93
+ [More Information Needed]
94
+
95
+
96
+ #### Training Hyperparameters
97
+
98
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
99
+
100
+ #### Speeds, Sizes, Times [optional]
101
+
102
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
103
+
104
+ [More Information Needed]
105
+
106
+ ## Evaluation
107
+
108
+ <!-- This section describes the evaluation protocols and provides the results. -->
109
+
110
+ ### Testing Data, Factors & Metrics
111
+
112
+ #### Testing Data
113
+
114
+ <!-- This should link to a Dataset Card if possible. -->
115
+
116
+ [More Information Needed]
117
+
118
+ #### Factors
119
+
120
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
121
+
122
+ [More Information Needed]
123
+
124
+ #### Metrics
125
+
126
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
127
+
128
+ [More Information Needed]
129
+
130
+ ### Results
131
+
132
+ [More Information Needed]
133
+
134
+ #### Summary
135
+
136
+
137
+
138
+ ## Model Examination [optional]
139
+
140
+ <!-- Relevant interpretability work for the model goes here -->
141
+
142
+ [More Information Needed]
143
+
144
+ ## Environmental Impact
145
+
146
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
147
+
148
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
149
+
150
+ - **Hardware Type:** [More Information Needed]
151
+ - **Hours used:** [More Information Needed]
152
+ - **Cloud Provider:** [More Information Needed]
153
+ - **Compute Region:** [More Information Needed]
154
+ - **Carbon Emitted:** [More Information Needed]
155
+
156
+ ## Technical Specifications [optional]
157
+
158
+ ### Model Architecture and Objective
159
+
160
+ [More Information Needed]
161
+
162
+ ### Compute Infrastructure
163
+
164
+ [More Information Needed]
165
+
166
+ #### Hardware
167
+
168
+ [More Information Needed]
169
+
170
+ #### Software
171
+
172
+ [More Information Needed]
173
+
174
+ ## Citation [optional]
175
+
176
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
177
+
178
+ **BibTeX:**
179
+
180
+ [More Information Needed]
181
+
182
+ **APA:**
183
+
184
+ [More Information Needed]
185
+
186
+ ## Glossary [optional]
187
+
188
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
189
+
190
+ [More Information Needed]
191
+
192
+ ## More Information [optional]
193
+
194
+ [More Information Needed]
195
+
196
+ ## Model Card Authors [optional]
197
+
198
+ [More Information Needed]
199
+
200
+ ## Model Card Contact
201
+
202
+ [More Information Needed]
203
+ ### Framework versions
204
+
205
+ - PEFT 0.18.0
nl_tasks/expsBOFT/seed43/ft2/adapter_config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "auto_mapping": {
+ "base_model_class": "LlamaForCausalLM",
+ "parent_library": "transformers.models.llama.modeling_llama"
+ },
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
+ "bias": "none",
+ "boft_block_num": 0,
+ "boft_block_size": 16,
+ "boft_dropout": 0.05,
+ "boft_n_butterfly_factor": 2,
+ "exclude_modules": null,
+ "fan_in_fan_out": false,
+ "inference_mode": true,
+ "init_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "modules_to_save": null,
+ "peft_type": "BOFT",
+ "peft_version": "0.18.0",
+ "revision": null,
+ "target_modules": [
+ "v_proj",
+ "q_proj"
+ ],
+ "task_type": null
+ }
nl_tasks/expsBOFT/seed43/ft2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:584526a06a1f45f2f77e6a89a7201b05aa25a3d6be60f231b255a32c48c4b261
+ size 34619504
nl_tasks/expsOFT/seed42/ft/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<unk>",
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
nl_tasks/expsOFT/seed42/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/expsOFT/seed42/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
nl_tasks/expsOFT/seed42/ft/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<s>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "</s>",
+ "extra_special_tokens": {},
+ "legacy": false,
+ "model_max_length": 512,
+ "pad_token": "<unk>",
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
nl_tasks/expsOFT/seed42/ft2/README.md ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ tags:
5
+ - base_model:adapter:meta-llama/Llama-2-7b-hf
6
+ - transformers
7
+ ---
8
+
9
+ # Model Card for Model ID
10
+
11
+ <!-- Provide a quick summary of what the model is/does. -->
12
+
13
+
14
+
15
+ ## Model Details
16
+
17
+ ### Model Description
18
+
19
+ <!-- Provide a longer summary of what this model is. -->
20
+
21
+
22
+
23
+ - **Developed by:** [More Information Needed]
24
+ - **Funded by [optional]:** [More Information Needed]
25
+ - **Shared by [optional]:** [More Information Needed]
26
+ - **Model type:** [More Information Needed]
27
+ - **Language(s) (NLP):** [More Information Needed]
28
+ - **License:** [More Information Needed]
29
+ - **Finetuned from model [optional]:** [More Information Needed]
30
+
31
+ ### Model Sources [optional]
32
+
33
+ <!-- Provide the basic links for the model. -->
34
+
35
+ - **Repository:** [More Information Needed]
36
+ - **Paper [optional]:** [More Information Needed]
37
+ - **Demo [optional]:** [More Information Needed]
38
+
39
+ ## Uses
40
+
41
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
42
+
43
+ ### Direct Use
44
+
45
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
46
+
47
+ [More Information Needed]
48
+
49
+ ### Downstream Use [optional]
50
+
51
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
52
+
53
+ [More Information Needed]
54
+
55
+ ### Out-of-Scope Use
56
+
57
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
58
+
59
+ [More Information Needed]
60
+
61
+ ## Bias, Risks, and Limitations
62
+
63
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
64
+
65
+ [More Information Needed]
66
+
67
+ ### Recommendations
68
+
69
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
70
+
71
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
72
+
73
+ ## How to Get Started with the Model
74
+
75
+ Use the code below to get started with the model.
76
+
77
+ [More Information Needed]
78
+
79
+ ## Training Details
80
+
81
+ ### Training Data
82
+
83
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
84
+
85
+ [More Information Needed]
86
+
87
+ ### Training Procedure
88
+
89
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
90
+
91
+ #### Preprocessing [optional]
92
+
93
+ [More Information Needed]
94
+
95
+
96
+ #### Training Hyperparameters
97
+
98
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
99
+
100
+ #### Speeds, Sizes, Times [optional]
101
+
102
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
103
+
104
+ [More Information Needed]
105
+
106
+ ## Evaluation
107
+
108
+ <!-- This section describes the evaluation protocols and provides the results. -->
109
+
110
+ ### Testing Data, Factors & Metrics
111
+
112
+ #### Testing Data
113
+
114
+ <!-- This should link to a Dataset Card if possible. -->
115
+
116
+ [More Information Needed]
117
+
118
+ #### Factors
119
+
120
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
121
+
122
+ [More Information Needed]
123
+
124
+ #### Metrics
125
+
126
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
127
+
128
+ [More Information Needed]
129
+
130
+ ### Results
131
+
132
+ [More Information Needed]
133
+
134
+ #### Summary
135
+
136
+
137
+
138
+ ## Model Examination [optional]
139
+
140
+ <!-- Relevant interpretability work for the model goes here -->
141
+
142
+ [More Information Needed]
143
+
144
+ ## Environmental Impact
145
+
146
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
147
+
148
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
149
+
150
+ - **Hardware Type:** [More Information Needed]
151
+ - **Hours used:** [More Information Needed]
152
+ - **Cloud Provider:** [More Information Needed]
153
+ - **Compute Region:** [More Information Needed]
154
+ - **Carbon Emitted:** [More Information Needed]
155
+
156
+ ## Technical Specifications [optional]
157
+
158
+ ### Model Architecture and Objective
159
+
160
+ [More Information Needed]
161
+
162
+ ### Compute Infrastructure
163
+
164
+ [More Information Needed]
165
+
166
+ #### Hardware
167
+
168
+ [More Information Needed]
169
+
170
+ #### Software
171
+
172
+ [More Information Needed]
173
+
174
+ ## Citation [optional]
175
+
176
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
177
+
178
+ **BibTeX:**
179
+
180
+ [More Information Needed]
181
+
182
+ **APA:**
183
+
184
+ [More Information Needed]
185
+
186
+ ## Glossary [optional]
187
+
188
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
189
+
190
+ [More Information Needed]
191
+
192
+ ## More Information [optional]
193
+
194
+ [More Information Needed]
195
+
196
+ ## Model Card Authors [optional]
197
+
198
+ [More Information Needed]
199
+
200
+ ## Model Card Contact
201
+
202
+ [More Information Needed]
203
+ ### Framework versions
204
+
205
+ - PEFT 0.18.0
nl_tasks/expsOFT/seed42/ft2/adapter_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "auto_mapping": {
+ "base_model_class": "LlamaForCausalLM",
+ "parent_library": "transformers.models.llama.modeling_llama"
+ },
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
+ "bias": "none",
+ "block_share": false,
+ "coft": false,
+ "eps": 6e-05,
+ "exclude_modules": null,
+ "fan_in_fan_out": false,
+ "inference_mode": true,
+ "init_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "module_dropout": 0.05,
+ "modules_to_save": null,
+ "num_cayley_neumann_terms": 5,
+ "oft_block_size": 64,
+ "peft_type": "OFT",
+ "peft_version": "0.18.0",
+ "r": 0,
+ "revision": null,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": null,
+ "use_cayley_neumann": true
+ }
nl_tasks/expsOFT/seed42/ft2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
+ size 33038696
nl_tasks/expsOFT/seed42/trainer_state.json ADDED
@@ -0,0 +1,218 @@
+ {
+ "best_global_step": null,
+ "best_metric": null,
+ "best_model_checkpoint": null,
+ "epoch": 2.0,
+ "eval_steps": 500,
+ "global_step": 1250,
+ "is_hyper_param_search": false,
+ "is_local_process_zero": true,
+ "is_world_process_zero": true,
+ "log_history": [
+ {
+ "epoch": 0.08,
+ "grad_norm": 0.15338309109210968,
+ "learning_rate": 0.000392,
+ "loss": 0.4726,
+ "step": 50
+ },
+ {
+ "epoch": 0.16,
+ "grad_norm": 0.1656411737203598,
+ "learning_rate": 0.0007920000000000001,
+ "loss": 0.3098,
+ "step": 100
+ },
+ {
+ "epoch": 0.24,
+ "grad_norm": 0.161162331700325,
+ "learning_rate": 0.0007964216926581925,
+ "loss": 0.2883,
+ "step": 150
+ },
+ {
+ "epoch": 0.32,
+ "grad_norm": 0.14719629287719727,
+ "learning_rate": 0.0007854602918076551,
+ "loss": 0.2773,
+ "step": 200
+ },
+ {
+ "epoch": 0.4,
+ "grad_norm": 0.1362672597169876,
+ "learning_rate": 0.0007673184950396212,
+ "loss": 0.2606,
+ "step": 250
+ },
+ {
+ "epoch": 0.48,
+ "grad_norm": 0.1420401930809021,
+ "learning_rate": 0.0007423342497022817,
+ "loss": 0.2549,
+ "step": 300
+ },
+ {
+ "epoch": 0.56,
+ "grad_norm": 0.15255458652973175,
+ "learning_rate": 0.0007109729650142636,
+ "loss": 0.2516,
+ "step": 350
+ },
+ {
+ "epoch": 0.64,
+ "grad_norm": 0.13546934723854065,
+ "learning_rate": 0.0006738188423714755,
+ "loss": 0.2439,
+ "step": 400
+ },
+ {
+ "epoch": 0.72,
+ "grad_norm": 0.1296033263206482,
+ "learning_rate": 0.0006315639927804526,
+ "loss": 0.2383,
+ "step": 450
+ },
+ {
+ "epoch": 0.8,
+ "grad_norm": 0.14936736226081848,
+ "learning_rate": 0.00058499554413983,
+ "loss": 0.2348,
+ "step": 500
+ },
+ {
+ "epoch": 0.88,
+ "grad_norm": 0.12654532492160797,
+ "learning_rate": 0.000534980978536894,
+ "loss": 0.2274,
+ "step": 550
+ },
+ {
+ "epoch": 0.96,
+ "grad_norm": 0.1250297725200653,
+ "learning_rate": 0.00048245197269763485,
+ "loss": 0.2298,
+ "step": 600
+ },
+ {
+ "epoch": 1.04,
+ "grad_norm": 0.1344439834356308,
+ "learning_rate": 0.00042838704261214224,
+ "loss": 0.2065,
+ "step": 650
+ },
+ {
+ "epoch": 1.12,
+ "grad_norm": 0.12664927542209625,
+ "learning_rate": 0.00037379331563313267,
+ "loss": 0.1907,
+ "step": 700
+ },
+ {
+ "epoch": 1.2,
+ "grad_norm": 0.1543550342321396,
+ "learning_rate": 0.00031968776959892677,
+ "loss": 0.1887,
+ "step": 750
+ },
+ {
+ "epoch": 1.28,
+ "grad_norm": 0.13837428390979767,
+ "learning_rate": 0.00026707828846051743,
+ "loss": 0.185,
+ "step": 800
+ },
+ {
+ "epoch": 1.3599999999999999,
+ "grad_norm": 0.12324073910713196,
+ "learning_rate": 0.00021694488731055218,
+ "loss": 0.1787,
+ "step": 850
+ },
+ {
+ "epoch": 1.44,
+ "grad_norm": 0.14447391033172607,
+ "learning_rate": 0.00017022145655641685,
+ "loss": 0.1779,
+ "step": 900
+ },
+ {
+ "epoch": 1.52,
+ "grad_norm": 0.13559409976005554,
+ "learning_rate": 0.00012777836530893536,
+ "loss": 0.1785,
+ "step": 950
+ },
+ {
+ "epoch": 1.6,
+ "grad_norm": 0.13572397828102112,
+ "learning_rate": 9.040624805263558e-05,
+ "loss": 0.176,
+ "step": 1000
+ },
+ {
+ "epoch": 1.6800000000000002,
+ "grad_norm": 0.13348858058452606,
+ "learning_rate": 5.880127662124091e-05,
+ "loss": 0.1743,
+ "step": 1050
+ },
+ {
+ "epoch": 1.76,
+ "grad_norm": 0.1402943730354309,
+ "learning_rate": 3.355219183361582e-05,
+ "loss": 0.1755,
+ "step": 1100
+ },
+ {
+ "epoch": 1.8399999999999999,
+ "grad_norm": 0.14928816258907318,
+ "learning_rate": 1.512933636625089e-05,
+ "loss": 0.1729,
+ "step": 1150
+ },
+ {
+ "epoch": 1.92,
+ "grad_norm": 0.14678366482257843,
+ "learning_rate": 3.8758931591217575e-06,
+ "loss": 0.1785,
+ "step": 1200
+ },
+ {
+ "epoch": 2.0,
+ "grad_norm": 0.13319681584835052,
+ "learning_rate": 1.4925668450960217e-09,
+ "loss": 0.1739,
+ "step": 1250
+ },
+ {
+ "epoch": 2.0,
+ "step": 1250,
+ "total_flos": 1.62585013911552e+18,
+ "train_loss": 0.2258549835205078,
+ "train_runtime": 2135.866,
+ "train_samples_per_second": 37.456,
+ "train_steps_per_second": 0.585
+ }
+ ],
+ "logging_steps": 50,
+ "max_steps": 1250,
+ "num_input_tokens_seen": 0,
+ "num_train_epochs": 2,
+ "save_steps": 0,
+ "stateful_callbacks": {
+ "TrainerControl": {
+ "args": {
+ "should_epoch_stop": false,
+ "should_evaluate": false,
+ "should_log": false,
+ "should_save": false,
+ "should_training_stop": false
+ },
+ "attributes": {}
+ }
+ },
+ "total_flos": 1.62585013911552e+18,
+ "train_batch_size": 64,
+ "trial_name": null,
+ "trial_params": null
+ }
nl_tasks/expsOFT/seed43/ft/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
+ {
+ "bos_token": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<unk>",
+ "unk_token": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
nl_tasks/expsOFT/seed43/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/expsOFT/seed43/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
+ size 499723
nl_tasks/expsOFT/seed43/ft/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
+ {
+ "add_bos_token": true,
+ "add_eos_token": false,
+ "add_prefix_space": null,
+ "added_tokens_decoder": {
+ "0": {
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "1": {
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "2": {
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "bos_token": "<s>",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "</s>",
+ "extra_special_tokens": {},
+ "legacy": false,
+ "model_max_length": 512,
+ "pad_token": "<unk>",
+ "padding_side": "right",
+ "sp_model_kwargs": {},
+ "tokenizer_class": "LlamaTokenizer",
+ "unk_token": "<unk>",
+ "use_default_system_prompt": false
+ }
nl_tasks/expsOFT/seed43/ft2/README.md ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ tags:
5
+ - base_model:adapter:meta-llama/Llama-2-7b-hf
6
+ - transformers
7
+ ---
8
+
9
+ # Model Card for Model ID
10
+
11
+ <!-- Provide a quick summary of what the model is/does. -->
12
+
13
+
14
+
15
+ ## Model Details
16
+
17
+ ### Model Description
18
+
19
+ <!-- Provide a longer summary of what this model is. -->
20
+
21
+
22
+
23
+ - **Developed by:** [More Information Needed]
24
+ - **Funded by [optional]:** [More Information Needed]
25
+ - **Shared by [optional]:** [More Information Needed]
26
+ - **Model type:** [More Information Needed]
27
+ - **Language(s) (NLP):** [More Information Needed]
28
+ - **License:** [More Information Needed]
29
+ - **Finetuned from model [optional]:** [More Information Needed]
30
+
31
+ ### Model Sources [optional]
32
+
33
+ <!-- Provide the basic links for the model. -->
34
+
35
+ - **Repository:** [More Information Needed]
36
+ - **Paper [optional]:** [More Information Needed]
37
+ - **Demo [optional]:** [More Information Needed]
38
+
39
+ ## Uses
40
+
41
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
42
+
43
+ ### Direct Use
44
+
45
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
46
+
47
+ [More Information Needed]
48
+
49
+ ### Downstream Use [optional]
50
+
51
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
52
+
53
+ [More Information Needed]
54
+
55
+ ### Out-of-Scope Use
56
+
57
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
58
+
59
+ [More Information Needed]
60
+
61
+ ## Bias, Risks, and Limitations
62
+
63
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
64
+
65
+ [More Information Needed]
66
+
67
+ ### Recommendations
68
+
69
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
70
+
71
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
72
+
73
+ ## How to Get Started with the Model
74
+
75
+ Use the code below to get started with the model.
76
+
77
+ [More Information Needed]
78
+
79
+ ## Training Details
80
+
81
+ ### Training Data
82
+
83
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
84
+
85
+ [More Information Needed]
86
+
87
+ ### Training Procedure
88
+
89
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
90
+
91
+ #### Preprocessing [optional]
92
+
93
+ [More Information Needed]
94
+
95
+
96
+ #### Training Hyperparameters
97
+
98
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
99
+
100
+ #### Speeds, Sizes, Times [optional]
101
+
102
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
103
+
104
+ [More Information Needed]
105
+
106
+ ## Evaluation
107
+
108
+ <!-- This section describes the evaluation protocols and provides the results. -->
109
+
110
+ ### Testing Data, Factors & Metrics
111
+
112
+ #### Testing Data
113
+
114
+ <!-- This should link to a Dataset Card if possible. -->
115
+
116
+ [More Information Needed]
117
+
118
+ #### Factors
119
+
120
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
121
+
122
+ [More Information Needed]
123
+
124
+ #### Metrics
125
+
126
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
127
+
128
+ [More Information Needed]
129
+
130
+ ### Results
131
+
132
+ [More Information Needed]
133
+
134
+ #### Summary
135
+
136
+
137
+
138
+ ## Model Examination [optional]
139
+
140
+ <!-- Relevant interpretability work for the model goes here -->
141
+
142
+ [More Information Needed]
143
+
144
+ ## Environmental Impact
145
+
146
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
147
+
148
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
149
+
150
+ - **Hardware Type:** [More Information Needed]
151
+ - **Hours used:** [More Information Needed]
152
+ - **Cloud Provider:** [More Information Needed]
153
+ - **Compute Region:** [More Information Needed]
154
+ - **Carbon Emitted:** [More Information Needed]
155
+
156
+ ## Technical Specifications [optional]
157
+
158
+ ### Model Architecture and Objective
159
+
160
+ [More Information Needed]
161
+
162
+ ### Compute Infrastructure
163
+
164
+ [More Information Needed]
165
+
166
+ #### Hardware
167
+
168
+ [More Information Needed]
169
+
170
+ #### Software
171
+
172
+ [More Information Needed]
173
+
174
+ ## Citation [optional]
175
+
176
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
177
+
178
+ **BibTeX:**
179
+
180
+ [More Information Needed]
181
+
182
+ **APA:**
183
+
184
+ [More Information Needed]
185
+
186
+ ## Glossary [optional]
187
+
188
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
189
+
190
+ [More Information Needed]
191
+
192
+ ## More Information [optional]
193
+
194
+ [More Information Needed]
195
+
196
+ ## Model Card Authors [optional]
197
+
198
+ [More Information Needed]
199
+
200
+ ## Model Card Contact
201
+
202
+ [More Information Needed]
203
+ ### Framework versions
204
+
205
+ - PEFT 0.18.0
nl_tasks/expsOFT/seed43/ft2/adapter_config.json ADDED
@@ -0,0 +1,31 @@
+ {
+ "auto_mapping": {
+ "base_model_class": "LlamaForCausalLM",
+ "parent_library": "transformers.models.llama.modeling_llama"
+ },
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
+ "bias": "none",
+ "block_share": false,
+ "coft": false,
+ "eps": 6e-05,
+ "exclude_modules": null,
+ "fan_in_fan_out": false,
+ "inference_mode": true,
+ "init_weights": true,
+ "layers_pattern": null,
+ "layers_to_transform": null,
+ "module_dropout": 0.05,
+ "modules_to_save": null,
+ "num_cayley_neumann_terms": 5,
+ "oft_block_size": 64,
+ "peft_type": "OFT",
+ "peft_version": "0.18.0",
+ "r": 0,
+ "revision": null,
+ "target_modules": [
+ "q_proj",
+ "v_proj"
+ ],
+ "task_type": null,
+ "use_cayley_neumann": true
+ }
nl_tasks/expsOFT/seed43/ft2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
+ size 33038696
nl_tasks/expsOFT/seed43/trainer_state.json ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 2.0,
6
+ "eval_steps": 500,
7
+ "global_step": 1250,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.08,
14
+ "grad_norm": 0.15338309109210968,
15
+ "learning_rate": 0.000392,
16
+ "loss": 0.4726,
17
+ "step": 50
18
+ },
19
+ {
20
+ "epoch": 0.16,
21
+ "grad_norm": 0.1656411737203598,
22
+ "learning_rate": 0.0007920000000000001,
23
+ "loss": 0.3098,
24
+ "step": 100
25
+ },
26
+ {
27
+ "epoch": 0.24,
28
+ "grad_norm": 0.161162331700325,
29
+ "learning_rate": 0.0007964216926581925,
30
+ "loss": 0.2883,
31
+ "step": 150
32
+ },
33
+ {
34
+ "epoch": 0.32,
35
+ "grad_norm": 0.14719629287719727,
36
+ "learning_rate": 0.0007854602918076551,
37
+ "loss": 0.2773,
38
+ "step": 200
39
+ },
40
+ {
41
+ "epoch": 0.4,
42
+ "grad_norm": 0.1362672597169876,
43
+ "learning_rate": 0.0007673184950396212,
44
+ "loss": 0.2606,
45
+ "step": 250
46
+ },
47
+ {
48
+ "epoch": 0.48,
49
+ "grad_norm": 0.1420401930809021,
50
+ "learning_rate": 0.0007423342497022817,
51
+ "loss": 0.2549,
52
+ "step": 300
53
+ },
54
+ {
55
+ "epoch": 0.56,
56
+ "grad_norm": 0.15255458652973175,
57
+ "learning_rate": 0.0007109729650142636,
58
+ "loss": 0.2516,
59
+ "step": 350
60
+ },
61
+ {
62
+ "epoch": 0.64,
63
+ "grad_norm": 0.13546934723854065,
64
+ "learning_rate": 0.0006738188423714755,
65
+ "loss": 0.2439,
66
+ "step": 400
67
+ },
68
+ {
69
+ "epoch": 0.72,
70
+ "grad_norm": 0.1296033263206482,
71
+ "learning_rate": 0.0006315639927804526,
72
+ "loss": 0.2383,
73
+ "step": 450
74
+ },
75
+ {
76
+ "epoch": 0.8,
77
+ "grad_norm": 0.14936736226081848,
78
+ "learning_rate": 0.00058499554413983,
79
+ "loss": 0.2348,
80
+ "step": 500
81
+ },
82
+ {
83
+ "epoch": 0.88,
84
+ "grad_norm": 0.12654532492160797,
85
+ "learning_rate": 0.000534980978536894,
86
+ "loss": 0.2274,
87
+ "step": 550
88
+ },
89
+ {
90
+ "epoch": 0.96,
91
+ "grad_norm": 0.1250297725200653,
92
+ "learning_rate": 0.00048245197269763485,
93
+ "loss": 0.2298,
94
+ "step": 600
95
+ },
96
+ {
97
+ "epoch": 1.04,
98
+ "grad_norm": 0.1344439834356308,
99
+ "learning_rate": 0.00042838704261214224,
100
+ "loss": 0.2065,
101
+ "step": 650
102
+ },
103
+ {
104
+ "epoch": 1.12,
105
+ "grad_norm": 0.12664927542209625,
106
+ "learning_rate": 0.00037379331563313267,
107
+ "loss": 0.1907,
108
+ "step": 700
109
+ },
110
+ {
111
+ "epoch": 1.2,
112
+ "grad_norm": 0.1543550342321396,
113
+ "learning_rate": 0.00031968776959892677,
114
+ "loss": 0.1887,
115
+ "step": 750
116
+ },
117
+ {
118
+ "epoch": 1.28,
119
+ "grad_norm": 0.13837428390979767,
120
+ "learning_rate": 0.00026707828846051743,
121
+ "loss": 0.185,
122
+ "step": 800
123
+ },
124
+ {
125
+ "epoch": 1.3599999999999999,
126
+ "grad_norm": 0.12324073910713196,
127
+ "learning_rate": 0.00021694488731055218,
128
+ "loss": 0.1787,
129
+ "step": 850
130
+ },
131
+ {
132
+ "epoch": 1.44,
133
+ "grad_norm": 0.14447391033172607,
134
+ "learning_rate": 0.00017022145655641685,
135
+ "loss": 0.1779,
136
+ "step": 900
137
+ },
138
+ {
139
+ "epoch": 1.52,
140
+ "grad_norm": 0.13559409976005554,
141
+ "learning_rate": 0.00012777836530893536,
142
+ "loss": 0.1785,
143
+ "step": 950
144
+ },
145
+ {
146
+ "epoch": 1.6,
147
+ "grad_norm": 0.13572397828102112,
148
+ "learning_rate": 9.040624805263558e-05,
149
+ "loss": 0.176,
150
+ "step": 1000
151
+ },
152
+ {
153
+ "epoch": 1.6800000000000002,
154
+ "grad_norm": 0.13348858058452606,
155
+ "learning_rate": 5.880127662124091e-05,
156
+ "loss": 0.1743,
157
+ "step": 1050
158
+ },
159
+ {
160
+ "epoch": 1.76,
161
+ "grad_norm": 0.1402943730354309,
162
+ "learning_rate": 3.355219183361582e-05,
163
+ "loss": 0.1755,
164
+ "step": 1100
165
+ },
166
+ {
167
+ "epoch": 1.8399999999999999,
168
+ "grad_norm": 0.14928816258907318,
169
+ "learning_rate": 1.512933636625089e-05,
170
+ "loss": 0.1729,
171
+ "step": 1150
172
+ },
173
+ {
174
+ "epoch": 1.92,
175
+ "grad_norm": 0.14678366482257843,
176
+ "learning_rate": 3.8758931591217575e-06,
177
+ "loss": 0.1785,
178
+ "step": 1200
179
+ },
180
+ {
181
+ "epoch": 2.0,
182
+ "grad_norm": 0.13319681584835052,
183
+ "learning_rate": 1.4925668450960217e-09,
184
+ "loss": 0.1739,
185
+ "step": 1250
186
+ },
187
+ {
188
+ "epoch": 2.0,
189
+ "step": 1250,
190
+ "total_flos": 1.62585013911552e+18,
191
+ "train_loss": 0.2258549835205078,
192
+ "train_runtime": 2134.8975,
193
+ "train_samples_per_second": 37.473,
194
+ "train_steps_per_second": 0.586
195
+ }
196
+ ],
197
+ "logging_steps": 50,
198
+ "max_steps": 1250,
199
+ "num_input_tokens_seen": 0,
200
+ "num_train_epochs": 2,
201
+ "save_steps": 0,
202
+ "stateful_callbacks": {
203
+ "TrainerControl": {
204
+ "args": {
205
+ "should_epoch_stop": false,
206
+ "should_evaluate": false,
207
+ "should_log": false,
208
+ "should_save": false,
209
+ "should_training_stop": false
210
+ },
211
+ "attributes": {}
212
+ }
213
+ },
214
+ "total_flos": 1.62585013911552e+18,
215
+ "train_batch_size": 64,
216
+ "trial_name": null,
217
+ "trial_params": null
218
+ }
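
As a reference, a small sketch for reading a trainer_state.json like the one above and cross-checking its totals (assumes a local clone of this repo; the per-device batch size of 64 ignores any gradient accumulation):

import json

with open("nl_tasks/expsOFT/seed43/trainer_state.json") as f:
    state = json.load(f)

summary = state["log_history"][-1]   # the final entry holds the run totals
steps = state["global_step"]         # 1250 optimizer steps over 2 epochs
batch = state["train_batch_size"]    # 64 (per device)
print("examples seen (steps * batch):", steps * batch)
print("examples from throughput:",
      round(summary["train_samples_per_second"] * summary["train_runtime"]))
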
nl_tasks/expsOFT/seed44/ft/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
nl_tasks/expsOFT/seed44/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
nl_tasks/expsOFT/seed44/ft/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
nl_tasks/expsOFT/seed44/ft/tokenizer_config.json ADDED
@@ -0,0 +1,43 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ }
30
+ },
31
+ "bos_token": "<s>",
32
+ "clean_up_tokenization_spaces": false,
33
+ "eos_token": "</s>",
34
+ "extra_special_tokens": {},
35
+ "legacy": false,
36
+ "model_max_length": 512,
37
+ "pad_token": "<unk>",
38
+ "padding_side": "right",
39
+ "sp_model_kwargs": {},
40
+ "tokenizer_class": "LlamaTokenizer",
41
+ "unk_token": "<unk>",
42
+ "use_default_system_prompt": false
43
+ }
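
For reference, a minimal sketch of loading the tokenizer files saved above with transformers (assumes a local clone of this repo):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("nl_tasks/expsOFT/seed44/ft")
print(tok.pad_token, tok.padding_side, tok.model_max_length)  # "<unk>", "right", 512
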
nl_tasks/expsOFT/seed44/ft2/README.md ADDED
@@ -0,0 +1,205 @@
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ tags:
5
+ - base_model:adapter:meta-llama/Llama-2-7b-hf
6
+ - transformers
7
+ ---
8
+
9
+ # Model Card for Model ID
10
+
11
+ <!-- Provide a quick summary of what the model is/does. -->
12
+
13
+
14
+
15
+ ## Model Details
16
+
17
+ ### Model Description
18
+
19
+ <!-- Provide a longer summary of what this model is. -->
20
+
21
+
22
+
23
+ - **Developed by:** [More Information Needed]
24
+ - **Funded by [optional]:** [More Information Needed]
25
+ - **Shared by [optional]:** [More Information Needed]
26
+ - **Model type:** [More Information Needed]
27
+ - **Language(s) (NLP):** [More Information Needed]
28
+ - **License:** [More Information Needed]
29
+ - **Finetuned from model [optional]:** [More Information Needed]
30
+
31
+ ### Model Sources [optional]
32
+
33
+ <!-- Provide the basic links for the model. -->
34
+
35
+ - **Repository:** [More Information Needed]
36
+ - **Paper [optional]:** [More Information Needed]
37
+ - **Demo [optional]:** [More Information Needed]
38
+
39
+ ## Uses
40
+
41
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
42
+
43
+ ### Direct Use
44
+
45
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
46
+
47
+ [More Information Needed]
48
+
49
+ ### Downstream Use [optional]
50
+
51
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
52
+
53
+ [More Information Needed]
54
+
55
+ ### Out-of-Scope Use
56
+
57
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
58
+
59
+ [More Information Needed]
60
+
61
+ ## Bias, Risks, and Limitations
62
+
63
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
64
+
65
+ [More Information Needed]
66
+
67
+ ### Recommendations
68
+
69
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
70
+
71
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
72
+
73
+ ## How to Get Started with the Model
74
+
75
+ Use the code below to get started with the model.
76
+
77
+ [More Information Needed]
78
+
79
+ ## Training Details
80
+
81
+ ### Training Data
82
+
83
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
84
+
85
+ [More Information Needed]
86
+
87
+ ### Training Procedure
88
+
89
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
90
+
91
+ #### Preprocessing [optional]
92
+
93
+ [More Information Needed]
94
+
95
+
96
+ #### Training Hyperparameters
97
+
98
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
99
+
100
+ #### Speeds, Sizes, Times [optional]
101
+
102
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
103
+
104
+ [More Information Needed]
105
+
106
+ ## Evaluation
107
+
108
+ <!-- This section describes the evaluation protocols and provides the results. -->
109
+
110
+ ### Testing Data, Factors & Metrics
111
+
112
+ #### Testing Data
113
+
114
+ <!-- This should link to a Dataset Card if possible. -->
115
+
116
+ [More Information Needed]
117
+
118
+ #### Factors
119
+
120
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
121
+
122
+ [More Information Needed]
123
+
124
+ #### Metrics
125
+
126
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
127
+
128
+ [More Information Needed]
129
+
130
+ ### Results
131
+
132
+ [More Information Needed]
133
+
134
+ #### Summary
135
+
136
+
137
+
138
+ ## Model Examination [optional]
139
+
140
+ <!-- Relevant interpretability work for the model goes here -->
141
+
142
+ [More Information Needed]
143
+
144
+ ## Environmental Impact
145
+
146
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
147
+
148
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
149
+
150
+ - **Hardware Type:** [More Information Needed]
151
+ - **Hours used:** [More Information Needed]
152
+ - **Cloud Provider:** [More Information Needed]
153
+ - **Compute Region:** [More Information Needed]
154
+ - **Carbon Emitted:** [More Information Needed]
155
+
156
+ ## Technical Specifications [optional]
157
+
158
+ ### Model Architecture and Objective
159
+
160
+ [More Information Needed]
161
+
162
+ ### Compute Infrastructure
163
+
164
+ [More Information Needed]
165
+
166
+ #### Hardware
167
+
168
+ [More Information Needed]
169
+
170
+ #### Software
171
+
172
+ [More Information Needed]
173
+
174
+ ## Citation [optional]
175
+
176
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
177
+
178
+ **BibTeX:**
179
+
180
+ [More Information Needed]
181
+
182
+ **APA:**
183
+
184
+ [More Information Needed]
185
+
186
+ ## Glossary [optional]
187
+
188
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
189
+
190
+ [More Information Needed]
191
+
192
+ ## More Information [optional]
193
+
194
+ [More Information Needed]
195
+
196
+ ## Model Card Authors [optional]
197
+
198
+ [More Information Needed]
199
+
200
+ ## Model Card Contact
201
+
202
+ [More Information Needed]
203
+ ### Framework versions
204
+
205
+ - PEFT 0.18.0
nl_tasks/expsOFT/seed44/ft2/adapter_config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "auto_mapping": {
3
+ "base_model_class": "LlamaForCausalLM",
4
+ "parent_library": "transformers.models.llama.modeling_llama"
5
+ },
6
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
7
+ "bias": "none",
8
+ "block_share": false,
9
+ "coft": false,
10
+ "eps": 6e-05,
11
+ "exclude_modules": null,
12
+ "fan_in_fan_out": false,
13
+ "inference_mode": true,
14
+ "init_weights": true,
15
+ "layers_pattern": null,
16
+ "layers_to_transform": null,
17
+ "module_dropout": 0.05,
18
+ "modules_to_save": null,
19
+ "num_cayley_neumann_terms": 5,
20
+ "oft_block_size": 64,
21
+ "peft_type": "OFT",
22
+ "peft_version": "0.18.0",
23
+ "r": 0,
24
+ "revision": null,
25
+ "target_modules": [
26
+ "q_proj",
27
+ "v_proj"
28
+ ],
29
+ "task_type": null,
30
+ "use_cayley_neumann": true
31
+ }
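
A hedged sketch of attaching this OFT adapter to its base model with peft (assumes access to meta-llama/Llama-2-7b-hf and a local clone of this repo; the dtype and device placement are illustrative choices, not settings from this upload):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
model = PeftModel.from_pretrained(base, "nl_tasks/expsOFT/seed44/ft2")
model.eval()  # the adapter_config.json above was saved with "inference_mode": true
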
nl_tasks/expsOFT/seed44/ft2/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d16378461c75d46a179539ea2223803c3af83b5ebb2dcc6face78c64e3ac4f9c
3
+ size 33038696
nl_tasks/expsOFT/seed44/trainer_state.json ADDED
@@ -0,0 +1,218 @@
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 2.0,
6
+ "eval_steps": 500,
7
+ "global_step": 1250,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.08,
14
+ "grad_norm": 0.15338309109210968,
15
+ "learning_rate": 0.000392,
16
+ "loss": 0.4726,
17
+ "step": 50
18
+ },
19
+ {
20
+ "epoch": 0.16,
21
+ "grad_norm": 0.1656411737203598,
22
+ "learning_rate": 0.0007920000000000001,
23
+ "loss": 0.3098,
24
+ "step": 100
25
+ },
26
+ {
27
+ "epoch": 0.24,
28
+ "grad_norm": 0.161162331700325,
29
+ "learning_rate": 0.0007964216926581925,
30
+ "loss": 0.2883,
31
+ "step": 150
32
+ },
33
+ {
34
+ "epoch": 0.32,
35
+ "grad_norm": 0.14719629287719727,
36
+ "learning_rate": 0.0007854602918076551,
37
+ "loss": 0.2773,
38
+ "step": 200
39
+ },
40
+ {
41
+ "epoch": 0.4,
42
+ "grad_norm": 0.1362672597169876,
43
+ "learning_rate": 0.0007673184950396212,
44
+ "loss": 0.2606,
45
+ "step": 250
46
+ },
47
+ {
48
+ "epoch": 0.48,
49
+ "grad_norm": 0.1420401930809021,
50
+ "learning_rate": 0.0007423342497022817,
51
+ "loss": 0.2549,
52
+ "step": 300
53
+ },
54
+ {
55
+ "epoch": 0.56,
56
+ "grad_norm": 0.15255458652973175,
57
+ "learning_rate": 0.0007109729650142636,
58
+ "loss": 0.2516,
59
+ "step": 350
60
+ },
61
+ {
62
+ "epoch": 0.64,
63
+ "grad_norm": 0.13546934723854065,
64
+ "learning_rate": 0.0006738188423714755,
65
+ "loss": 0.2439,
66
+ "step": 400
67
+ },
68
+ {
69
+ "epoch": 0.72,
70
+ "grad_norm": 0.1296033263206482,
71
+ "learning_rate": 0.0006315639927804526,
72
+ "loss": 0.2383,
73
+ "step": 450
74
+ },
75
+ {
76
+ "epoch": 0.8,
77
+ "grad_norm": 0.14936736226081848,
78
+ "learning_rate": 0.00058499554413983,
79
+ "loss": 0.2348,
80
+ "step": 500
81
+ },
82
+ {
83
+ "epoch": 0.88,
84
+ "grad_norm": 0.12654532492160797,
85
+ "learning_rate": 0.000534980978536894,
86
+ "loss": 0.2274,
87
+ "step": 550
88
+ },
89
+ {
90
+ "epoch": 0.96,
91
+ "grad_norm": 0.1250297725200653,
92
+ "learning_rate": 0.00048245197269763485,
93
+ "loss": 0.2298,
94
+ "step": 600
95
+ },
96
+ {
97
+ "epoch": 1.04,
98
+ "grad_norm": 0.1344439834356308,
99
+ "learning_rate": 0.00042838704261214224,
100
+ "loss": 0.2065,
101
+ "step": 650
102
+ },
103
+ {
104
+ "epoch": 1.12,
105
+ "grad_norm": 0.12664927542209625,
106
+ "learning_rate": 0.00037379331563313267,
107
+ "loss": 0.1907,
108
+ "step": 700
109
+ },
110
+ {
111
+ "epoch": 1.2,
112
+ "grad_norm": 0.1543550342321396,
113
+ "learning_rate": 0.00031968776959892677,
114
+ "loss": 0.1887,
115
+ "step": 750
116
+ },
117
+ {
118
+ "epoch": 1.28,
119
+ "grad_norm": 0.13837428390979767,
120
+ "learning_rate": 0.00026707828846051743,
121
+ "loss": 0.185,
122
+ "step": 800
123
+ },
124
+ {
125
+ "epoch": 1.3599999999999999,
126
+ "grad_norm": 0.12324073910713196,
127
+ "learning_rate": 0.00021694488731055218,
128
+ "loss": 0.1787,
129
+ "step": 850
130
+ },
131
+ {
132
+ "epoch": 1.44,
133
+ "grad_norm": 0.14447391033172607,
134
+ "learning_rate": 0.00017022145655641685,
135
+ "loss": 0.1779,
136
+ "step": 900
137
+ },
138
+ {
139
+ "epoch": 1.52,
140
+ "grad_norm": 0.13559409976005554,
141
+ "learning_rate": 0.00012777836530893536,
142
+ "loss": 0.1785,
143
+ "step": 950
144
+ },
145
+ {
146
+ "epoch": 1.6,
147
+ "grad_norm": 0.13572397828102112,
148
+ "learning_rate": 9.040624805263558e-05,
149
+ "loss": 0.176,
150
+ "step": 1000
151
+ },
152
+ {
153
+ "epoch": 1.6800000000000002,
154
+ "grad_norm": 0.13348858058452606,
155
+ "learning_rate": 5.880127662124091e-05,
156
+ "loss": 0.1743,
157
+ "step": 1050
158
+ },
159
+ {
160
+ "epoch": 1.76,
161
+ "grad_norm": 0.1402943730354309,
162
+ "learning_rate": 3.355219183361582e-05,
163
+ "loss": 0.1755,
164
+ "step": 1100
165
+ },
166
+ {
167
+ "epoch": 1.8399999999999999,
168
+ "grad_norm": 0.14928816258907318,
169
+ "learning_rate": 1.512933636625089e-05,
170
+ "loss": 0.1729,
171
+ "step": 1150
172
+ },
173
+ {
174
+ "epoch": 1.92,
175
+ "grad_norm": 0.14678366482257843,
176
+ "learning_rate": 3.8758931591217575e-06,
177
+ "loss": 0.1785,
178
+ "step": 1200
179
+ },
180
+ {
181
+ "epoch": 2.0,
182
+ "grad_norm": 0.13319681584835052,
183
+ "learning_rate": 1.4925668450960217e-09,
184
+ "loss": 0.1739,
185
+ "step": 1250
186
+ },
187
+ {
188
+ "epoch": 2.0,
189
+ "step": 1250,
190
+ "total_flos": 1.62585013911552e+18,
191
+ "train_loss": 0.2258549835205078,
192
+ "train_runtime": 2124.0047,
193
+ "train_samples_per_second": 37.665,
194
+ "train_steps_per_second": 0.589
195
+ }
196
+ ],
197
+ "logging_steps": 50,
198
+ "max_steps": 1250,
199
+ "num_input_tokens_seen": 0,
200
+ "num_train_epochs": 2,
201
+ "save_steps": 0,
202
+ "stateful_callbacks": {
203
+ "TrainerControl": {
204
+ "args": {
205
+ "should_epoch_stop": false,
206
+ "should_evaluate": false,
207
+ "should_log": false,
208
+ "should_save": false,
209
+ "should_training_stop": false
210
+ },
211
+ "attributes": {}
212
+ }
213
+ },
214
+ "total_flos": 1.62585013911552e+18,
215
+ "train_batch_size": 64,
216
+ "trial_name": null,
217
+ "trial_params": null
218
+ }
omini/__init__.py ADDED
File without changes
omini/pipeline/flux_omini.py ADDED
@@ -0,0 +1,734 @@
1
+ import torch
2
+ from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
3
+
4
+ from diffusers.pipelines import FluxPipeline
5
+ from diffusers.pipelines.flux.pipeline_flux import (
6
+ FluxPipelineOutput,
7
+ FluxTransformer2DModel,
8
+ calculate_shift,
9
+ retrieve_timesteps,
10
+ np,
11
+ )
12
+ from diffusers.models.attention_processor import Attention, F
13
+ from diffusers.models.embeddings import apply_rotary_emb
14
+ from transformers import pipeline
15
+
16
+ from peft.tuners.tuners_utils import BaseTunerLayer
17
+ from accelerate.utils import is_torch_version
18
+
19
+ from contextlib import contextmanager
20
+
21
+ import cv2
22
+
23
+ from PIL import Image, ImageFilter
24
+
25
+
26
+ def seed_everything(seed: int = 42):
27
+ torch.backends.cudnn.deterministic = True
28
+ torch.manual_seed(seed)
29
+ np.random.seed(seed)
30
+
31
+
32
+ def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
33
+ if hidden_states.dtype == torch.float16:
34
+ hidden_states = hidden_states.clip(-65504, 65504)
35
+ return hidden_states
36
+
37
+
38
+ def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
39
+ """
40
+ Encodes the images into tokens and ids for FLUX pipeline.
41
+ """
42
+ images = pipeline.image_processor.preprocess(images)
43
+ images = images.to(pipeline.device).to(pipeline.dtype)
44
+ images = pipeline.vae.encode(images).latent_dist.sample()
45
+ images = (
46
+ images - pipeline.vae.config.shift_factor
47
+ ) * pipeline.vae.config.scaling_factor
48
+ images_tokens = pipeline._pack_latents(images, *images.shape)
49
+ images_ids = pipeline._prepare_latent_image_ids(
50
+ images.shape[0],
51
+ images.shape[2],
52
+ images.shape[3],
53
+ pipeline.device,
54
+ pipeline.dtype,
55
+ )
56
+ if images_tokens.shape[1] != images_ids.shape[0]:
57
+ images_ids = pipeline._prepare_latent_image_ids(
58
+ images.shape[0],
59
+ images.shape[2] // 2,
60
+ images.shape[3] // 2,
61
+ pipeline.device,
62
+ pipeline.dtype,
63
+ )
64
+ return images_tokens, images_ids
65
+
66
+
67
+ depth_pipe = None
68
+
69
+
70
+ def convert_to_condition(
71
+ condition_type: str,
72
+ raw_img: Union[Image.Image, torch.Tensor],
73
+ blur_radius: Optional[int] = 5,
74
+ ) -> Union[Image.Image, torch.Tensor]:
75
+ if condition_type == "depth":
76
+ global depth_pipe
77
+ depth_pipe = depth_pipe or pipeline(
78
+ task="depth-estimation",
79
+ model="LiheYoung/depth-anything-small-hf",
80
+ device="cpu", # Use "cpu" to enable parallel processing
81
+ )
82
+ source_image = raw_img.convert("RGB")
83
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
84
+ return condition_img
85
+ elif condition_type == "canny":
86
+ img = np.array(raw_img)
87
+ edges = cv2.Canny(img, 100, 200)
88
+ edges = Image.fromarray(edges).convert("RGB")
89
+ return edges
90
+ elif condition_type == "coloring":
91
+ return raw_img.convert("L").convert("RGB")
92
+ elif condition_type == "deblurring":
93
+ condition_image = (
94
+ raw_img.convert("RGB")
95
+ .filter(ImageFilter.GaussianBlur(blur_radius))
96
+ .convert("RGB")
97
+ )
98
+ return condition_image
99
+ else:
100
+ print("Warning: Returning the raw image.")
101
+ return raw_img.convert("RGB")
102
+
103
+
104
+ class Condition(object):
105
+ def __init__(
106
+ self,
107
+ condition: Union[Image.Image, torch.Tensor],
108
+ adapter_setting: Union[str, dict],
109
+ position_delta=None,
110
+ position_scale=1.0,
111
+ latent_mask=None,
112
+ is_complement=False,
113
+ ) -> None:
114
+ self.condition = condition
115
+ self.adapter = adapter_setting
116
+ self.position_delta = position_delta
117
+ self.position_scale = position_scale
118
+ self.latent_mask = (
119
+ latent_mask.T.reshape(-1) if latent_mask is not None else None
120
+ )
121
+ self.is_complement = is_complement
122
+
123
+ def encode(
124
+ self, pipe: FluxPipeline, empty: bool = False
125
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
126
+ condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
127
+ tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
128
+
129
+ if self.position_delta is not None:
130
+ ids[:, 1] += self.position_delta[0]
131
+ ids[:, 2] += self.position_delta[1]
132
+
133
+ if self.position_scale != 1.0:
134
+ scale_bias = (self.position_scale - 1.0) / 2
135
+ ids[:, 1:] *= self.position_scale
136
+ ids[:, 1:] += scale_bias
137
+
138
+ if self.latent_mask is not None:
139
+ tokens = tokens[:, self.latent_mask]
140
+ ids = ids[self.latent_mask]
141
+
142
+ return tokens, ids
143
+
144
+
145
+ @contextmanager
146
+ def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora):
147
+ # Filter valid lora modules
148
+ valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
149
+ # Save original scales
150
+ original_scales = [
151
+ {
152
+ adapter: module.scaling[adapter]
153
+ for adapter in module.active_adapters
154
+ if adapter in module.scaling
155
+ }
156
+ for module in valid_lora_modules
157
+ ]
158
+ # Enter context: adjust scaling
159
+ for module in valid_lora_modules:
160
+ for adapter in module.active_adapters:
161
+ if adapter in module.scaling:
162
+ module.scaling[adapter] = 1 if adapter == specified_lora else 0
163
+ try:
164
+ yield
165
+ finally:
166
+ # Exit context: restore original scales
167
+ for module, scales in zip(valid_lora_modules, original_scales):
168
+ for adapter in module.active_adapters:
169
+ if adapter in module.scaling:
170
+ module.scaling[adapter] = scales[adapter]
171
+
172
+
173
+ def attn_forward(
174
+ attn: Attention,
175
+ hidden_states: List[torch.FloatTensor],
176
+ adapters: List[str],
177
+ hidden_states2: Optional[List[torch.FloatTensor]] = [],
178
+ position_embs: Optional[List[torch.Tensor]] = None,
179
+ group_mask: Optional[torch.Tensor] = None,
180
+ cache_mode: Optional[str] = None,
181
+ # to determine whether to cache the keys and values for this branch
182
+ to_cache: Optional[List[torch.Tensor]] = None,
183
+ cache_storage: Optional[List[torch.Tensor]] = None,
184
+ **kwargs: dict,
185
+ ) -> torch.FloatTensor:
186
+ bs, _, _ = hidden_states[0].shape
187
+ h2_n = len(hidden_states2)
188
+
189
+ queries, keys, values = [], [], []
190
+
191
+ # Prepare query, key, value for each encoder hidden state (text branch)
192
+ for i, hidden_state in enumerate(hidden_states2):
193
+ query = attn.add_q_proj(hidden_state)
194
+ key = attn.add_k_proj(hidden_state)
195
+ value = attn.add_v_proj(hidden_state)
196
+
197
+ head_dim = key.shape[-1] // attn.heads
198
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
199
+
200
+ query, key, value = map(reshape_fn, (query, key, value))
201
+ query, key = attn.norm_added_q(query), attn.norm_added_k(key)
202
+
203
+ queries.append(query)
204
+ keys.append(key)
205
+ values.append(value)
206
+
207
+ # Prepare query, key, value for each hidden state (image branch)
208
+ for i, hidden_state in enumerate(hidden_states):
209
+ with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
210
+ query = attn.to_q(hidden_state)
211
+ key = attn.to_k(hidden_state)
212
+ value = attn.to_v(hidden_state)
213
+
214
+ head_dim = key.shape[-1] // attn.heads
215
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
216
+
217
+ query, key, value = map(reshape_fn, (query, key, value))
218
+ query, key = attn.norm_q(query), attn.norm_k(key)
219
+
220
+ queries.append(query)
221
+ keys.append(key)
222
+ values.append(value)
223
+
224
+ # Apply rotary embedding
225
+ if position_embs is not None:
226
+ queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
227
+ keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
228
+
229
+ if cache_mode == "write":
230
+ for i, (k, v) in enumerate(zip(keys, values)):
231
+ if to_cache[i]:
232
+ cache_storage[attn.cache_idx][0].append(k)
233
+ cache_storage[attn.cache_idx][1].append(v)
234
+
235
+ attn_outputs = []
236
+ for i, query in enumerate(queries):
237
+ keys_, values_ = [], []
238
+ # Add keys and values from other branches
239
+ for j, (k, v) in enumerate(zip(keys, values)):
240
+ if (group_mask is not None) and not (group_mask[i][j].item()):
241
+ continue
242
+ keys_.append(k)
243
+ values_.append(v)
244
+ if cache_mode == "read":
245
+ keys_.extend(cache_storage[attn.cache_idx][0])
246
+ values_.extend(cache_storage[attn.cache_idx][1])
247
+ # Add keys and values from cache TODO
248
+ # Attention computation
249
+ attn_output = F.scaled_dot_product_attention(
250
+ query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
251
+ ).to(query.dtype)
252
+ attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
253
+ attn_outputs.append(attn_output)
254
+
255
+ # Reshape attention output to match the original hidden states
256
+ h_out, h2_out = [], []
257
+
258
+ for i, hidden_state in enumerate(hidden_states2):
259
+ h2_out.append(attn.to_add_out(attn_outputs[i]))
260
+
261
+ for i, hidden_state in enumerate(hidden_states):
262
+ h = attn_outputs[i + h2_n]
263
+ if getattr(attn, "to_out", None) is not None:
264
+ with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
265
+ h = attn.to_out[0](h)
266
+ h_out.append(h)
267
+
268
+ return (h_out, h2_out) if h2_n else h_out
269
+
270
+
271
+ def block_forward(
272
+ self,
273
+ image_hidden_states: List[torch.FloatTensor],
274
+ text_hidden_states: List[torch.FloatTensor],
275
+ tembs: List[torch.FloatTensor],
276
+ adapters: List[str],
277
+ position_embs=None,
278
+ attn_forward=attn_forward,
279
+ **kwargs: dict,
280
+ ):
281
+ txt_n = len(text_hidden_states)
282
+
283
+ img_variables, txt_variables = [], []
284
+
285
+ for i, text_h in enumerate(text_hidden_states):
286
+ txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
287
+
288
+ for i, image_h in enumerate(image_hidden_states):
289
+ with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
290
+ img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
291
+
292
+ # Attention.
293
+ img_attn_output, txt_attn_output = attn_forward(
294
+ self.attn,
295
+ hidden_states=[each[0] for each in img_variables],
296
+ hidden_states2=[each[0] for each in txt_variables],
297
+ position_embs=position_embs,
298
+ adapters=adapters,
299
+ **kwargs,
300
+ )
301
+
302
+ text_out = []
303
+ for i in range(len(text_hidden_states)):
304
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
305
+ text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
306
+ norm_h = (
307
+ self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
308
+ )
309
+ text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
310
+ text_out.append(clip_hidden_states(text_h))
311
+
312
+ image_out = []
313
+ for i in range(len(image_hidden_states)):
314
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
315
+ image_h = (
316
+ image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
317
+ ).to(image_hidden_states[i].dtype)
318
+ norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
319
+ with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
320
+ image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
321
+ image_out.append(clip_hidden_states(image_h))
322
+ return image_out, text_out
323
+
324
+
325
+ def single_block_forward(
326
+ self,
327
+ hidden_states: List[torch.FloatTensor],
328
+ tembs: List[torch.FloatTensor],
329
+ adapters: List[str],
330
+ position_embs=None,
331
+ attn_forward=attn_forward,
332
+ **kwargs: dict,
333
+ ):
334
+ mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
335
+
336
+ hidden_state_norm = []
337
+ for i, hidden_state in enumerate(hidden_states):
338
+ # [NOTE]!: This function's output is slightly DIFFERENT from the original
339
+ # FLUX version. In the original implementation, the gates were computed using
340
+ # the combined hidden states from both the image and text branches. Here, each
341
+ # branch computes its gate using only its own hidden state.
342
+ with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
343
+ h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
344
+ mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
345
+ hidden_state_norm.append(h_norm)
346
+
347
+ attn_outputs = attn_forward(
348
+ self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
349
+ )
350
+
351
+ h_out = []
352
+ for i in range(len(hidden_states)):
353
+ with specify_lora((self.proj_out,), adapters[i]):
354
+ h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
355
+ h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
356
+ h_out.append(clip_hidden_states(h))
357
+
358
+ return h_out
359
+
360
+
361
+ def transformer_forward(
362
+ transformer: FluxTransformer2DModel,
363
+ image_features: List[torch.Tensor],
364
+ text_features: List[torch.Tensor] = None,
365
+ img_ids: List[torch.Tensor] = None,
366
+ txt_ids: List[torch.Tensor] = None,
367
+ pooled_projections: List[torch.Tensor] = None,
368
+ timesteps: List[torch.LongTensor] = None,
369
+ guidances: List[torch.Tensor] = None,
370
+ adapters: List[str] = None,
371
+ # Assign the function to be used for the forward pass
372
+ single_block_forward=single_block_forward,
373
+ block_forward=block_forward,
374
+ attn_forward=attn_forward,
375
+ **kwargs: dict,
376
+ ):
377
+ self = transformer
378
+ txt_n = len(text_features) if text_features is not None else 0
379
+
380
+ adapters = adapters or [None] * (txt_n + len(image_features))
381
+ assert len(adapters) == len(timesteps)
382
+
383
+ # Preprocess the image_features
384
+ image_hidden_states = []
385
+ for i, image_feature in enumerate(image_features):
386
+ with specify_lora((self.x_embedder,), adapters[i + txt_n]):
387
+ image_hidden_states.append(self.x_embedder(image_feature))
388
+
389
+ # Preprocess the text_features
390
+ text_hidden_states = []
391
+ for text_feature in text_features:
392
+ text_hidden_states.append(self.context_embedder(text_feature))
393
+
394
+ # Prepare embeddings of (timestep, guidance, pooled_projections)
395
+ assert len(timesteps) == len(image_features) + len(text_features)
396
+
397
+ def get_temb(timestep, guidance, pooled_projection):
398
+ timestep = timestep.to(image_hidden_states[0].dtype) * 1000
399
+ if guidance is not None:
400
+ guidance = guidance.to(image_hidden_states[0].dtype) * 1000
401
+ return self.time_text_embed(timestep, guidance, pooled_projection)
402
+ else:
403
+ return self.time_text_embed(timestep, pooled_projection)
404
+
405
+ tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
406
+
407
+ # Prepare position embeddings for each token
408
+ position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
409
+
410
+ # Prepare the gradient checkpointing kwargs
411
+ gckpt_kwargs: Dict[str, Any] = (
412
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
413
+ )
414
+
415
+ # dual branch blocks
416
+ for block in self.transformer_blocks:
417
+ block_kwargs = {
418
+ "self": block,
419
+ "image_hidden_states": image_hidden_states,
420
+ "text_hidden_states": text_hidden_states,
421
+ "tembs": tembs,
422
+ "position_embs": position_embs,
423
+ "adapters": adapters,
424
+ "attn_forward": attn_forward,
425
+ **kwargs,
426
+ }
427
+ if self.training and self.gradient_checkpointing:
428
+ image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
429
+ block_forward, **block_kwargs, **gckpt_kwargs
430
+ )
431
+ else:
432
+ image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
433
+
434
+ # combine image and text hidden states then pass through the single transformer blocks
435
+ all_hidden_states = [*text_hidden_states, *image_hidden_states]
436
+ for block in self.single_transformer_blocks:
437
+ block_kwargs = {
438
+ "self": block,
439
+ "hidden_states": all_hidden_states,
440
+ "tembs": tembs,
441
+ "position_embs": position_embs,
442
+ "adapters": adapters,
443
+ "attn_forward": attn_forward,
444
+ **kwargs,
445
+ }
446
+ if self.training and self.gradient_checkpointing:
447
+ all_hidden_states = torch.utils.checkpoint.checkpoint(
448
+ single_block_forward, **block_kwargs, **gckpt_kwargs
449
+ )
450
+ else:
451
+ all_hidden_states = single_block_forward(**block_kwargs)
452
+
453
+ image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
454
+ output = self.proj_out(image_hidden_states)
455
+
456
+ return (output,)
457
+
458
+
459
+ @torch.no_grad()
460
+ def generate(
461
+ pipeline: FluxPipeline,
462
+ prompt: Union[str, List[str]] = None,
463
+ prompt_2: Optional[Union[str, List[str]]] = None,
464
+ height: Optional[int] = 512,
465
+ width: Optional[int] = 512,
466
+ num_inference_steps: int = 28,
467
+ timesteps: List[int] = None,
468
+ guidance_scale: float = 3.5,
469
+ num_images_per_prompt: Optional[int] = 1,
470
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
471
+ latents: Optional[torch.FloatTensor] = None,
472
+ prompt_embeds: Optional[torch.FloatTensor] = None,
473
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
474
+ output_type: Optional[str] = "pil",
475
+ return_dict: bool = True,
476
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
477
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
478
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
479
+ max_sequence_length: int = 512,
480
+ # Condition Parameters (Optional)
481
+ main_adapter: Optional[List[str]] = None,
482
+ conditions: List[Condition] = [],
483
+ image_guidance_scale: float = 1.0,
484
+ transformer_kwargs: Optional[Dict[str, Any]] = {},
485
+ kv_cache=False,
486
+ latent_mask=None,
487
+ **params: dict,
488
+ ):
489
+ self = pipeline
490
+
491
+ height = height or self.default_sample_size * self.vae_scale_factor
492
+ width = width or self.default_sample_size * self.vae_scale_factor
493
+
494
+ # Check inputs. Raise error if not correct
495
+ self.check_inputs(
496
+ prompt,
497
+ prompt_2,
498
+ height,
499
+ width,
500
+ prompt_embeds=prompt_embeds,
501
+ pooled_prompt_embeds=pooled_prompt_embeds,
502
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
503
+ max_sequence_length=max_sequence_length,
504
+ )
505
+
506
+ self._guidance_scale = guidance_scale
507
+ self._joint_attention_kwargs = joint_attention_kwargs
508
+
509
+ # Define call parameters
510
+ if prompt is not None and isinstance(prompt, str):
511
+ batch_size = 1
512
+ elif prompt is not None and isinstance(prompt, list):
513
+ batch_size = len(prompt)
514
+ else:
515
+ batch_size = prompt_embeds.shape[0]
516
+
517
+ device = self._execution_device
518
+
519
+ # Prepare prompt embeddings
520
+ (
521
+ prompt_embeds,
522
+ pooled_prompt_embeds,
523
+ text_ids,
524
+ ) = self.encode_prompt(
525
+ prompt=prompt,
526
+ prompt_2=prompt_2,
527
+ prompt_embeds=prompt_embeds,
528
+ pooled_prompt_embeds=pooled_prompt_embeds,
529
+ device=device,
530
+ num_images_per_prompt=num_images_per_prompt,
531
+ max_sequence_length=max_sequence_length,
532
+ )
533
+
534
+ # Prepare latent variables
535
+ num_channels_latents = self.transformer.config.in_channels // 4
536
+ latents, latent_image_ids = self.prepare_latents(
537
+ batch_size * num_images_per_prompt,
538
+ num_channels_latents,
539
+ height,
540
+ width,
541
+ prompt_embeds.dtype,
542
+ device,
543
+ generator,
544
+ latents,
545
+ )
546
+
547
+ if latent_mask is not None:
548
+ latent_mask = latent_mask.T.reshape(-1)
549
+ latents = latents[:, latent_mask]
550
+ latent_image_ids = latent_image_ids[latent_mask]
551
+
552
+ # Prepare conditions
553
+ c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
554
+ c_projections, c_guidances, c_adapters = ([], [], [])
555
+ complement_cond = None
556
+ for condition in conditions:
557
+ tokens, ids = condition.encode(self)
558
+ c_latents.append(tokens) # [batch_size, token_n, token_dim]
559
+ # Empty condition for unconditioned image
560
+ if image_guidance_scale != 1.0:
561
+ uc_latents.append(condition.encode(self, empty=True)[0])
562
+ c_ids.append(ids) # [token_n, id_dim(3)]
563
+ c_timesteps.append(torch.zeros([1], device=device))
564
+ c_projections.append(pooled_prompt_embeds)
565
+ c_guidances.append(torch.ones([1], device=device))
566
+ c_adapters.append(condition.adapter)
567
+ # This complement_condition will be combined with the original image.
568
+ # See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
569
+ if condition.is_complement:
570
+ complement_cond = (tokens, ids)
571
+
572
+ # Prepare timesteps
573
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
574
+ image_seq_len = latents.shape[1]
575
+ mu = calculate_shift(
576
+ image_seq_len,
577
+ self.scheduler.config.base_image_seq_len,
578
+ self.scheduler.config.max_image_seq_len,
579
+ self.scheduler.config.base_shift,
580
+ self.scheduler.config.max_shift,
581
+ )
582
+ timesteps, num_inference_steps = retrieve_timesteps(
583
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
584
+ )
585
+ num_warmup_steps = max(
586
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
587
+ )
588
+ self._num_timesteps = len(timesteps)
589
+
590
+ if kv_cache:
591
+ attn_counter = 0
592
+ for module in self.transformer.modules():
593
+ if isinstance(module, Attention):
594
+ setattr(module, "cache_idx", attn_counter)
595
+ attn_counter += 1
596
+ kv_cond = [[[], []] for _ in range(attn_counter)]
597
+ kv_uncond = [[[], []] for _ in range(attn_counter)]
598
+
599
+ def clear_cache():
600
+ for storage in [kv_cond, kv_uncond]:
601
+ for keys, values in storage:
602
+ keys.clear()
603
+ values.clear()
604
+
605
+ branch_n = len(conditions) + 2
606
+ group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
607
+ # Disable attention across different condition branches
608
+ group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
609
+ # Disable attention from the condition branches to the image and text branches
610
+ if kv_cache:
611
+ group_mask[2:, :2] = False
612
+
613
+ # Denoising loop
614
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
615
+ for i, t in enumerate(timesteps):
616
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
617
+ timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
618
+
619
+ # handle guidance
620
+ if self.transformer.config.guidance_embeds:
621
+ guidance = torch.tensor([guidance_scale], device=device)
622
+ guidance = guidance.expand(latents.shape[0])
623
+ else:
624
+ guidance, c_guidances = None, [None for _ in c_guidances]
625
+
626
+ if kv_cache:
627
+ mode = "write" if i == 0 else "read"
628
+ if mode == "write":
629
+ clear_cache()
630
+ use_cond = not (kv_cache) or mode == "write"
631
+
632
+ noise_pred = transformer_forward(
633
+ self.transformer,
634
+ image_features=[latents] + (c_latents if use_cond else []),
635
+ text_features=[prompt_embeds],
636
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
637
+ txt_ids=[text_ids],
638
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
639
+ pooled_projections=[pooled_prompt_embeds] * 2
640
+ + (c_projections if use_cond else []),
641
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
642
+ return_dict=False,
643
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
644
+ cache_mode=mode if kv_cache else None,
645
+ cache_storage=kv_cond if kv_cache else None,
646
+ to_cache=[False, False, *[True] * len(c_latents)],
647
+ group_mask=group_mask,
648
+ **transformer_kwargs,
649
+ )[0]
650
+
651
+ if image_guidance_scale != 1.0:
652
+ unc_pred = transformer_forward(
653
+ self.transformer,
654
+ image_features=[latents] + (uc_latents if use_cond else []),
655
+ text_features=[prompt_embeds],
656
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
657
+ txt_ids=[text_ids],
658
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
659
+ pooled_projections=[pooled_prompt_embeds] * 2
660
+ + (c_projections if use_cond else []),
661
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
662
+ return_dict=False,
663
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
664
+ cache_mode=mode if kv_cache else None,
665
+ cache_storage=kv_uncond if kv_cache else None,
666
+ to_cache=[False, False, *[True] * len(c_latents)],
667
+ **transformer_kwargs,
668
+ )[0]
669
+
670
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
671
+
672
+ # compute the previous noisy sample x_t -> x_t-1
673
+ latents_dtype = latents.dtype
674
+ latents = self.scheduler.step(noise_pred, t, latents)[0]
675
+
676
+ if latents.dtype != latents_dtype:
677
+ if torch.backends.mps.is_available():
678
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
679
+ latents = latents.to(latents_dtype)
680
+
681
+ if callback_on_step_end is not None:
682
+ callback_kwargs = {}
683
+ for k in callback_on_step_end_tensor_inputs:
684
+ callback_kwargs[k] = locals()[k]
685
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
686
+
687
+ latents = callback_outputs.pop("latents", latents)
688
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
689
+
690
+ # call the callback, if provided
691
+ if i == len(timesteps) - 1 or (
692
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
693
+ ):
694
+ progress_bar.update()
695
+
696
+ if latent_mask is not None:
697
+ # Combine the generated latents and the complement condition
698
+ assert complement_cond is not None
699
+ comp_latent, comp_ids = complement_cond
700
+ all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
701
+ shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
702
+ H, W = shape[1].item(), shape[2].item()
703
+ B, _, C = latents.shape
704
+ # Create an empty canvas
705
+ canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
706
+
707
+ # Stash the latents and the complement condition
708
+ def _stash(canvas, tokens, ids, H, W) -> None:
709
+ B, T, C = tokens.shape
710
+ ids = ids.to(torch.long)
711
+ flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
712
+ canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
713
+
714
+ _stash(canvas, latents, latent_image_ids, H, W)
715
+ _stash(canvas, comp_latent, comp_ids, H, W)
716
+ latents = canvas.view(B, H * W, C)
717
+
718
+ if output_type == "latent":
719
+ image = latents
720
+ else:
721
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
722
+ latents = (
723
+ latents / self.vae.config.scaling_factor
724
+ ) + self.vae.config.shift_factor
725
+ image = self.vae.decode(latents, return_dict=False)[0]
726
+ image = self.image_processor.postprocess(image, output_type=output_type)
727
+
728
+ # Offload all models
729
+ self.maybe_free_model_hooks()
730
+
731
+ if not return_dict:
732
+ return (image,)
733
+
734
+ return FluxPipelineOutput(images=image)
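
A hedged usage sketch for the helpers defined in flux_omini.py above (Condition, convert_to_condition, generate). The FLUX checkpoint id, the input image, and the adapter name "canny" are assumptions for illustration, not artifacts of this upload:

import torch
from PIL import Image
from diffusers import FluxPipeline
from omini.pipeline.flux_omini import Condition, convert_to_condition, generate

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

raw = Image.open("example.jpg").resize((512, 512))
condition = Condition(
    convert_to_condition("canny", raw),  # edge map fed to the condition branch
    adapter_setting="canny",             # name of a LoRA adapter assumed to be loaded on the transformer
    position_delta=(0, 0),
)
out = generate(
    pipe,
    prompt="a photo of a cat",
    conditions=[condition],
    height=512,
    width=512,
    num_inference_steps=28,
)
out.images[0].save("result.png")
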
omini/pipeline/flux_omini_ablate_qkv.py ADDED
@@ -0,0 +1,772 @@
1
+ """
2
+ This version is for an ablation study of the rotation applied to the Q/K/V projections.
3
+
4
+ The module-level variables `T_Q`, `T_K`, and `T_V` override the rotation matrix `T`
5
+ passed to `specify_lora` for the query, key, and value projections of the image branch.
6
+ """
7
+
8
+ import torch
9
+ from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
10
+
11
+ from diffusers.pipelines import FluxPipeline
12
+ from diffusers.pipelines.flux.pipeline_flux import (
13
+ FluxPipelineOutput,
14
+ FluxTransformer2DModel,
15
+ calculate_shift,
16
+ retrieve_timesteps,
17
+ np,
18
+ )
19
+ from diffusers.models.attention_processor import Attention, F
20
+ from diffusers.models.embeddings import apply_rotary_emb
21
+ from transformers import pipeline
22
+
23
+ from peft.tuners.tuners_utils import BaseTunerLayer
24
+ from accelerate.utils import is_torch_version
25
+
26
+ from contextlib import contextmanager
27
+
28
+ import cv2
29
+
30
+ from PIL import Image, ImageFilter
31
+
32
+ T_Q = None
33
+ T_K = None
34
+ T_V = None
35
+
36
+ def seed_everything(seed: int = 42):
37
+ torch.backends.cudnn.deterministic = True
38
+ torch.manual_seed(seed)
39
+ np.random.seed(seed)
40
+
41
+
42
+ def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
43
+ if hidden_states.dtype == torch.float16:
44
+ hidden_states = hidden_states.clip(-65504, 65504)
45
+ return hidden_states
46
+
47
+
48
+ def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
49
+ """
50
+ Encodes the images into tokens and ids for FLUX pipeline.
51
+ """
52
+ images = pipeline.image_processor.preprocess(images)
53
+ images = images.to(pipeline.device).to(pipeline.dtype)
54
+ images = pipeline.vae.encode(images).latent_dist.sample()
55
+ images = (
56
+ images - pipeline.vae.config.shift_factor
57
+ ) * pipeline.vae.config.scaling_factor
58
+ images_tokens = pipeline._pack_latents(images, *images.shape)
59
+ images_ids = pipeline._prepare_latent_image_ids(
60
+ images.shape[0],
61
+ images.shape[2],
62
+ images.shape[3],
63
+ pipeline.device,
64
+ pipeline.dtype,
65
+ )
66
+ if images_tokens.shape[1] != images_ids.shape[0]:
67
+ images_ids = pipeline._prepare_latent_image_ids(
68
+ images.shape[0],
69
+ images.shape[2] // 2,
70
+ images.shape[3] // 2,
71
+ pipeline.device,
72
+ pipeline.dtype,
73
+ )
74
+ return images_tokens, images_ids
75
+
76
+
77
+ depth_pipe = None
78
+
79
+
80
+ def convert_to_condition(
81
+ condition_type: str,
82
+ raw_img: Union[Image.Image, torch.Tensor],
83
+ blur_radius: Optional[int] = 5,
84
+ ) -> Union[Image.Image, torch.Tensor]:
85
+ if condition_type == "depth":
86
+ global depth_pipe
87
+ depth_pipe = depth_pipe or pipeline(
88
+ task="depth-estimation",
89
+ model="LiheYoung/depth-anything-small-hf",
90
+ device="cpu", # Use "cpu" to enable parallel processing
91
+ )
92
+ source_image = raw_img.convert("RGB")
93
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
94
+ return condition_img
95
+ elif condition_type == "canny":
96
+ img = np.array(raw_img)
97
+ edges = cv2.Canny(img, 100, 200)
98
+ edges = Image.fromarray(edges).convert("RGB")
99
+ return edges
100
+ elif condition_type == "coloring":
101
+ return raw_img.convert("L").convert("RGB")
102
+ elif condition_type == "deblurring":
103
+ condition_image = (
104
+ raw_img.convert("RGB")
105
+ .filter(ImageFilter.GaussianBlur(blur_radius))
106
+ .convert("RGB")
107
+ )
108
+ return condition_image
109
+ else:
110
+ print("Warning: Returning the raw image.")
111
+ return raw_img.convert("RGB")
112
+
113
+
114
+ class Condition(object):
115
+ def __init__(
116
+ self,
117
+ condition: Union[Image.Image, torch.Tensor],
118
+ adapter_setting: Union[str, dict],
119
+ position_delta=None,
120
+ position_scale=1.0,
121
+ latent_mask=None,
122
+ is_complement=False,
123
+ ) -> None:
124
+ self.condition = condition
125
+ self.adapter = adapter_setting
126
+ self.position_delta = position_delta
127
+ self.position_scale = position_scale
128
+ self.latent_mask = (
129
+ latent_mask.T.reshape(-1) if latent_mask is not None else None
130
+ )
131
+ self.is_complement = is_complement
132
+
133
+ def encode(
134
+ self, pipe: FluxPipeline, empty: bool = False
135
+ ) -> Tuple[torch.Tensor, torch.Tensor, int]:
136
+ condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
137
+ tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
138
+
139
+ if self.position_delta is not None:
140
+ ids[:, 1] += self.position_delta[0]
141
+ ids[:, 2] += self.position_delta[1]
142
+
143
+ if self.position_scale != 1.0:
144
+ scale_bias = (self.position_scale - 1.0) / 2
145
+ ids[:, 1:] *= self.position_scale
146
+ ids[:, 1:] += scale_bias
147
+
148
+ if self.latent_mask is not None:
149
+ tokens = tokens[:, self.latent_mask]
150
+ ids = ids[self.latent_mask]
151
+
152
+ return tokens, ids
153
+
154
+
155
+ @contextmanager
156
+ def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora, T=None):
157
+ # Filter valid lora modules
158
+ valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
159
+ # Save original scales
160
+ original_scales = [
161
+ {
162
+ adapter: module.scaling[adapter]
163
+ for adapter in module.active_adapters
164
+ if adapter in module.scaling
165
+ }
166
+ for module in valid_lora_modules
167
+ ]
168
+ # Enter context: adjust scaling
169
+ for module in valid_lora_modules:
170
+ for adapter in module.active_adapters:
171
+ if adapter in module.scaling:
172
+ module.scaling[adapter] = 1. if adapter == specified_lora else 0
173
+
174
+ if hasattr(module, 'rotation') and T is not None:
175
+ # alter T if specified
176
+ if adapter in module.rotation:
177
+ # print("FOR DEBUG:entering specify_lora context: setting T")
178
+ module.rotation[adapter].T = T
179
+
180
+ try:
181
+ yield
182
+ finally:
183
+ # Exit context: restore original scales
184
+ for module, scales in zip(valid_lora_modules, original_scales):
185
+ for adapter in module.active_adapters:
186
+ if adapter in module.scaling:
187
+ module.scaling[adapter] = scales[adapter]
188
+
189
+
190
+ def attn_forward(
191
+ attn: Attention,
192
+ hidden_states: List[torch.FloatTensor],
193
+ adapters: List[str],
194
+ hidden_states2: Optional[List[torch.FloatTensor]] = [],
195
+ position_embs: Optional[List[torch.Tensor]] = None,
196
+ group_mask: Optional[torch.Tensor] = None,
197
+ cache_mode: Optional[str] = None,
198
+ # to determine whether to cache the keys and values for this branch
199
+ to_cache: Optional[List[torch.Tensor]] = None,
200
+ cache_storage: Optional[List[torch.Tensor]] = None,
201
+ **kwargs: dict,
202
+ ) -> torch.FloatTensor:
203
+ bs, _, _ = hidden_states[0].shape
204
+ h2_n = len(hidden_states2)
205
+
206
+ queries, keys, values = [], [], []
207
+
208
+ # Prepare query, key, value for each encoder hidden state (text branch)
209
+ for i, hidden_state in enumerate(hidden_states2):
210
+ query = attn.add_q_proj(hidden_state)
211
+ key = attn.add_k_proj(hidden_state)
212
+ value = attn.add_v_proj(hidden_state)
213
+
214
+ head_dim = key.shape[-1] // attn.heads
215
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
216
+
217
+ query, key, value = map(reshape_fn, (query, key, value))
218
+ query, key = attn.norm_added_q(query), attn.norm_added_k(key)
219
+
220
+ queries.append(query)
221
+ keys.append(key)
222
+ values.append(value)
223
+
224
+
225
+ ## THIS IS THE MODIFIED PART TO ABLATE QKV ROTATION T ##
226
+ # Prepare query, key, value for each hidden state (image branch)
227
+ for i, hidden_state in enumerate(hidden_states):
228
+ with specify_lora((attn.to_q,), adapters[i + h2_n], T=T_Q):
229
+ query = attn.to_q(hidden_state)
230
+
231
+ with specify_lora((attn.to_k,), adapters[i + h2_n], T=T_K):
232
+ key = attn.to_k(hidden_state)
233
+
234
+ with specify_lora((attn.to_v,), adapters[i + h2_n], T=T_V):
235
+ value = attn.to_v(hidden_state)
236
+
237
+ head_dim = key.shape[-1] // attn.heads
238
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
239
+
240
+ query, key, value = map(reshape_fn, (query, key, value))
241
+ query, key = attn.norm_q(query), attn.norm_k(key)
242
+
243
+ queries.append(query)
244
+ keys.append(key)
245
+ values.append(value)
246
+
247
+ # Apply rotary embedding
248
+ if position_embs is not None:
249
+ queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
250
+ keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
251
+
252
+ if cache_mode == "write":
253
+ for i, (k, v) in enumerate(zip(keys, values)):
254
+ if to_cache[i]:
255
+ cache_storage[attn.cache_idx][0].append(k)
256
+ cache_storage[attn.cache_idx][1].append(v)
257
+
258
+ attn_outputs = []
259
+ for i, query in enumerate(queries):
260
+ keys_, values_ = [], []
261
+ # Add keys and values from other branches
262
+ for j, (k, v) in enumerate(zip(keys, values)):
263
+ if (group_mask is not None) and not (group_mask[i][j].item()):
264
+ continue
265
+ keys_.append(k)
266
+ values_.append(v)
267
+ if cache_mode == "read":
268
+ keys_.extend(cache_storage[attn.cache_idx][0])
269
+ values_.extend(cache_storage[attn.cache_idx][1])
270
+ # Add keys and values from cache TODO
271
+ # Attention computation
272
+ attn_output = F.scaled_dot_product_attention(
273
+ query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
274
+ ).to(query.dtype)
275
+ attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
276
+ attn_outputs.append(attn_output)
277
+
278
+ # Reshape attention output to match the original hidden states
279
+ h_out, h2_out = [], []
280
+
281
+ for i, hidden_state in enumerate(hidden_states2):
282
+ h2_out.append(attn.to_add_out(attn_outputs[i]))
283
+
284
+ for i, hidden_state in enumerate(hidden_states):
285
+ h = attn_outputs[i + h2_n]
286
+ if getattr(attn, "to_out", None) is not None:
287
+ with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
288
+ h = attn.to_out[0](h)
289
+ h_out.append(h)
290
+
291
+ return (h_out, h2_out) if h2_n else h_out
292
+
293
+
294
+ def block_forward(
295
+ self,
296
+ image_hidden_states: List[torch.FloatTensor],
297
+ text_hidden_states: List[torch.FloatTensor],
298
+ tembs: List[torch.FloatTensor],
299
+ adapters: List[str],
300
+ position_embs=None,
301
+ attn_forward=attn_forward,
302
+ **kwargs: dict,
303
+ ):
304
+ txt_n = len(text_hidden_states)
305
+
306
+ img_variables, txt_variables = [], []
307
+
308
+ for i, text_h in enumerate(text_hidden_states):
309
+ txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
310
+
311
+ for i, image_h in enumerate(image_hidden_states):
312
+ with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
313
+ img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
314
+
315
+ # Attention.
316
+ img_attn_output, txt_attn_output = attn_forward(
317
+ self.attn,
318
+ hidden_states=[each[0] for each in img_variables],
319
+ hidden_states2=[each[0] for each in txt_variables],
320
+ position_embs=position_embs,
321
+ adapters=adapters,
322
+ **kwargs,
323
+ )
324
+
325
+ text_out = []
326
+ for i in range(len(text_hidden_states)):
327
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
328
+ text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
329
+ norm_h = (
330
+ self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
331
+ )
332
+ text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
333
+ text_out.append(clip_hidden_states(text_h))
334
+
335
+ image_out = []
336
+ for i in range(len(image_hidden_states)):
337
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
338
+ image_h = (
339
+ image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
340
+ ).to(image_hidden_states[i].dtype)
341
+ norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
342
+ with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
343
+ image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
344
+ image_out.append(clip_hidden_states(image_h))
345
+ return image_out, text_out
346
+
347
+
348
+ def single_block_forward(
349
+ self,
350
+ hidden_states: List[torch.FloatTensor],
351
+ tembs: List[torch.FloatTensor],
352
+ adapters: List[str],
353
+ position_embs=None,
354
+ attn_forward=attn_forward,
355
+ **kwargs: dict,
356
+ ):
357
+ mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
358
+
359
+ hidden_state_norm = []
360
+ for i, hidden_state in enumerate(hidden_states):
361
+ # [NOTE]!: This function's output is slightly DIFFERENT from the original
362
+ # FLUX version. In the original implementation, the gates were computed using
363
+ # the combined hidden states from both the image and text branches. Here, each
364
+ # branch computes its gate using only its own hidden state.
365
+ with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
366
+ h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
367
+ mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
368
+ hidden_state_norm.append(h_norm)
369
+
370
+ attn_outputs = attn_forward(
371
+ self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
372
+ )
373
+
374
+ h_out = []
375
+ for i in range(len(hidden_states)):
376
+ with specify_lora((self.proj_out,), adapters[i]):
377
+ h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
378
+ h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
379
+ h_out.append(clip_hidden_states(h))
380
+
381
+ return h_out
382
+
383
+
384
+ def transformer_forward(
385
+ transformer: FluxTransformer2DModel,
386
+ image_features: List[torch.Tensor],
387
+ text_features: List[torch.Tensor] = None,
388
+ img_ids: List[torch.Tensor] = None,
389
+ txt_ids: List[torch.Tensor] = None,
390
+ pooled_projections: List[torch.Tensor] = None,
391
+ timesteps: List[torch.LongTensor] = None,
392
+ guidances: List[torch.Tensor] = None,
393
+ adapters: List[str] = None,
394
+ # Assign the function to be used for the forward pass
395
+ single_block_forward=single_block_forward,
396
+ block_forward=block_forward,
397
+ attn_forward=attn_forward,
398
+ **kwargs: dict,
399
+ ):
400
+ self = transformer
401
+ txt_n = len(text_features) if text_features is not None else 0
402
+
403
+ adapters = adapters or [None] * (txt_n + len(image_features))
404
+ assert len(adapters) == len(timesteps)
405
+
406
+ # Preprocess the image_features
407
+ image_hidden_states = []
408
+ for i, image_feature in enumerate(image_features):
409
+ with specify_lora((self.x_embedder,), adapters[i + txt_n]):
410
+ image_hidden_states.append(self.x_embedder(image_feature))
411
+
412
+ # Preprocess the text_features
413
+ text_hidden_states = []
414
+ for text_feature in text_features:
415
+ text_hidden_states.append(self.context_embedder(text_feature))
416
+
417
+ # Prepare embeddings of (timestep, guidance, pooled_projections)
418
+ assert len(timesteps) == len(image_features) + len(text_features)
419
+
420
+ def get_temb(timestep, guidance, pooled_projection):
421
+ timestep = timestep.to(image_hidden_states[0].dtype) * 1000
422
+ if guidance is not None:
423
+ guidance = guidance.to(image_hidden_states[0].dtype) * 1000
424
+ return self.time_text_embed(timestep, guidance, pooled_projection)
425
+ else:
426
+ return self.time_text_embed(timestep, pooled_projection)
427
+
428
+ tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
429
+
430
+ # Prepare position embeddings for each token
431
+ position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
432
+
433
+ # Prepare the gradient checkpointing kwargs
434
+ gckpt_kwargs: Dict[str, Any] = (
435
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
436
+ )
437
+
438
+ # dual branch blocks
439
+ for block in self.transformer_blocks:
440
+ block_kwargs = {
441
+ "self": block,
442
+ "image_hidden_states": image_hidden_states,
443
+ "text_hidden_states": text_hidden_states,
444
+ "tembs": tembs,
445
+ "position_embs": position_embs,
446
+ "adapters": adapters,
447
+ "attn_forward": attn_forward,
448
+ **kwargs,
449
+ }
450
+ if self.training and self.gradient_checkpointing:
451
+ image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
452
+ block_forward, **block_kwargs, **gckpt_kwargs
453
+ )
454
+ else:
455
+ image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
456
+
457
+ # combine image and text hidden states then pass through the single transformer blocks
458
+ all_hidden_states = [*text_hidden_states, *image_hidden_states]
459
+ for block in self.single_transformer_blocks:
460
+ block_kwargs = {
461
+ "self": block,
462
+ "hidden_states": all_hidden_states,
463
+ "tembs": tembs,
464
+ "position_embs": position_embs,
465
+ "adapters": adapters,
466
+ "attn_forward": attn_forward,
467
+ **kwargs,
468
+ }
469
+ if self.training and self.gradient_checkpointing:
470
+ all_hidden_states = torch.utils.checkpoint.checkpoint(
471
+ single_block_forward, **block_kwargs, **gckpt_kwargs
472
+ )
473
+ else:
474
+ all_hidden_states = single_block_forward(**block_kwargs)
475
+
476
+ image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
477
+ output = self.proj_out(image_hidden_states)
478
+
479
+ return (output,)
480
+
481
+
482
+ @torch.no_grad()
483
+ def generate(
484
+ pipeline: FluxPipeline,
485
+ prompt: Union[str, List[str]] = None,
486
+ prompt_2: Optional[Union[str, List[str]]] = None,
487
+ height: Optional[int] = 512,
488
+ width: Optional[int] = 512,
489
+ num_inference_steps: int = 28,
490
+ timesteps: List[int] = None,
491
+ guidance_scale: float = 3.5,
492
+ num_images_per_prompt: Optional[int] = 1,
493
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
494
+ latents: Optional[torch.FloatTensor] = None,
495
+ prompt_embeds: Optional[torch.FloatTensor] = None,
496
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
497
+ output_type: Optional[str] = "pil",
498
+ return_dict: bool = True,
499
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
500
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
501
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
502
+ max_sequence_length: int = 512,
503
+ # Condition Parameters (Optional)
504
+ main_adapter: Optional[List[str]] = None,
505
+ conditions: List[Condition] = [],
506
+ image_guidance_scale: float = 1.0,
507
+ transformer_kwargs: Optional[Dict[str, Any]] = {},
508
+ kv_cache=False,
509
+ latent_mask=None,
510
+ global_T_Q=None,
511
+ global_T_K=None,
512
+ global_T_V=None,
513
+ **params: dict,
514
+ ):
515
+
516
+ # Set global T_Q, T_K, T_V if provided
517
+ if global_T_Q is not None:
518
+ global T_Q
519
+ T_Q = global_T_Q
520
+ if global_T_K is not None:
521
+ global T_K
522
+ T_K = global_T_K
523
+ if global_T_V is not None:
524
+ global T_V
525
+ T_V = global_T_V
526
+
527
+ self = pipeline
528
+
529
+ height = height or self.default_sample_size * self.vae_scale_factor
530
+ width = width or self.default_sample_size * self.vae_scale_factor
531
+
532
+ # Check inputs. Raise error if not correct
533
+ self.check_inputs(
534
+ prompt,
535
+ prompt_2,
536
+ height,
537
+ width,
538
+ prompt_embeds=prompt_embeds,
539
+ pooled_prompt_embeds=pooled_prompt_embeds,
540
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
541
+ max_sequence_length=max_sequence_length,
542
+ )
543
+
544
+ self._guidance_scale = guidance_scale
545
+ self._joint_attention_kwargs = joint_attention_kwargs
546
+
547
+ # Define call parameters
548
+ if prompt is not None and isinstance(prompt, str):
549
+ batch_size = 1
550
+ elif prompt is not None and isinstance(prompt, list):
551
+ batch_size = len(prompt)
552
+ else:
553
+ batch_size = prompt_embeds.shape[0]
554
+
555
+ device = self._execution_device
556
+
557
+ # Prepare prompt embeddings
558
+ (
559
+ prompt_embeds,
560
+ pooled_prompt_embeds,
561
+ text_ids,
562
+ ) = self.encode_prompt(
563
+ prompt=prompt,
564
+ prompt_2=prompt_2,
565
+ prompt_embeds=prompt_embeds,
566
+ pooled_prompt_embeds=pooled_prompt_embeds,
567
+ device=device,
568
+ num_images_per_prompt=num_images_per_prompt,
569
+ max_sequence_length=max_sequence_length,
570
+ )
571
+
572
+ # Prepare latent variables
573
+ num_channels_latents = self.transformer.config.in_channels // 4
574
+ latents, latent_image_ids = self.prepare_latents(
575
+ batch_size * num_images_per_prompt,
576
+ num_channels_latents,
577
+ height,
578
+ width,
579
+ prompt_embeds.dtype,
580
+ device,
581
+ generator,
582
+ latents,
583
+ )
584
+
585
+ if latent_mask is not None:
586
+ latent_mask = latent_mask.T.reshape(-1)
587
+ latents = latents[:, latent_mask]
588
+ latent_image_ids = latent_image_ids[latent_mask]
589
+
590
+ # Prepare conditions
591
+ c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
592
+ c_projections, c_guidances, c_adapters = ([], [], [])
593
+ complement_cond = None
594
+ for condition in conditions:
595
+ tokens, ids = condition.encode(self)
596
+ c_latents.append(tokens) # [batch_size, token_n, token_dim]
597
+ # Empty condition for unconditioned image
598
+ if image_guidance_scale != 1.0:
599
+ uc_latents.append(condition.encode(self, empty=True)[0])
600
+ c_ids.append(ids) # [token_n, id_dim(3)]
601
+ c_timesteps.append(torch.zeros([1], device=device))
602
+ c_projections.append(pooled_prompt_embeds)
603
+ c_guidances.append(torch.ones([1], device=device))
604
+ c_adapters.append(condition.adapter)
605
+ # This complement_condition will be combined with the original image.
606
+ # See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
607
+ if condition.is_complement:
608
+ complement_cond = (tokens, ids)
609
+
610
+ # Prepare timesteps
611
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
612
+ image_seq_len = latents.shape[1]
613
+ mu = calculate_shift(
614
+ image_seq_len,
615
+ self.scheduler.config.base_image_seq_len,
616
+ self.scheduler.config.max_image_seq_len,
617
+ self.scheduler.config.base_shift,
618
+ self.scheduler.config.max_shift,
619
+ )
620
+ timesteps, num_inference_steps = retrieve_timesteps(
621
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
622
+ )
623
+ num_warmup_steps = max(
624
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
625
+ )
626
+ self._num_timesteps = len(timesteps)
627
+
628
+ if kv_cache:
629
+ attn_counter = 0
630
+ for module in self.transformer.modules():
631
+ if isinstance(module, Attention):
632
+ setattr(module, "cache_idx", attn_counter)
633
+ attn_counter += 1
634
+ kv_cond = [[[], []] for _ in range(attn_counter)]
635
+ kv_uncond = [[[], []] for _ in range(attn_counter)]
636
+
637
+ def clear_cache():
638
+ for storage in [kv_cond, kv_uncond]:
639
+ for keys, values in storage:
640
+ keys.clear()
641
+ values.clear()
642
+
643
+ branch_n = len(conditions) + 2
644
+ group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
645
+ # Disable the attention across different condition branches
646
+ group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
647
+ # Disable the attention from condition branches to image branch and text branch
648
+ if kv_cache:
649
+ group_mask[2:, :2] = False
650
+
651
+ # Denoising loop
652
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
653
+ for i, t in enumerate(timesteps):
654
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
655
+ timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
656
+
657
+ # handle guidance
658
+ if self.transformer.config.guidance_embeds:
659
+ guidance = torch.tensor([guidance_scale], device=device)
660
+ guidance = guidance.expand(latents.shape[0])
661
+ else:
662
+ guidance, c_guidances = None, [None for _ in c_guidances]
663
+
664
+ if kv_cache:
665
+ mode = "write" if i == 0 else "read"
666
+ if mode == "write":
667
+ clear_cache()
668
+ use_cond = not (kv_cache) or mode == "write"
669
+
670
+ noise_pred = transformer_forward(
671
+ self.transformer,
672
+ image_features=[latents] + (c_latents if use_cond else []),
673
+ text_features=[prompt_embeds],
674
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
675
+ txt_ids=[text_ids],
676
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
677
+ pooled_projections=[pooled_prompt_embeds] * 2
678
+ + (c_projections if use_cond else []),
679
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
680
+ return_dict=False,
681
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
682
+ cache_mode=mode if kv_cache else None,
683
+ cache_storage=kv_cond if kv_cache else None,
684
+ to_cache=[False, False, *[True] * len(c_latents)],
685
+ group_mask=group_mask,
686
+ **transformer_kwargs,
687
+ )[0]
688
+
689
+ if image_guidance_scale != 1.0:
690
+ unc_pred = transformer_forward(
691
+ self.transformer,
692
+ image_features=[latents] + (uc_latents if use_cond else []),
693
+ text_features=[prompt_embeds],
694
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
695
+ txt_ids=[text_ids],
696
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
697
+ pooled_projections=[pooled_prompt_embeds] * 2
698
+ + (c_projections if use_cond else []),
699
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
700
+ return_dict=False,
701
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
702
+ cache_mode=mode if kv_cache else None,
703
+ cache_storage=kv_uncond if kv_cache else None,
704
+ to_cache=[False, False, *[True] * len(c_latents)],
705
+ **transformer_kwargs,
706
+ )[0]
707
+
708
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
709
+
710
+ # compute the previous noisy sample x_t -> x_t-1
711
+ latents_dtype = latents.dtype
712
+ latents = self.scheduler.step(noise_pred, t, latents)[0]
713
+
714
+ if latents.dtype != latents_dtype:
715
+ if torch.backends.mps.is_available():
716
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
717
+ latents = latents.to(latents_dtype)
718
+
719
+ if callback_on_step_end is not None:
720
+ callback_kwargs = {}
721
+ for k in callback_on_step_end_tensor_inputs:
722
+ callback_kwargs[k] = locals()[k]
723
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
724
+
725
+ latents = callback_outputs.pop("latents", latents)
726
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
727
+
728
+ # call the callback, if provided
729
+ if i == len(timesteps) - 1 or (
730
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
731
+ ):
732
+ progress_bar.update()
733
+
734
+ if latent_mask is not None:
735
+ # Combine the generated latents and the complement condition
736
+ assert complement_cond is not None
737
+ comp_latent, comp_ids = complement_cond
738
+ all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
739
+ shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
740
+ H, W = shape[1].item(), shape[2].item()
741
+ B, _, C = latents.shape
742
+ # Create an empty canvas
743
+ canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
744
+
745
+ # Stash the latents and the complement condition
746
+ def _stash(canvas, tokens, ids, H, W) -> None:
747
+ B, T, C = tokens.shape
748
+ ids = ids.to(torch.long)
749
+ flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
750
+ canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
751
+
752
+ _stash(canvas, latents, latent_image_ids, H, W)
753
+ _stash(canvas, comp_latent, comp_ids, H, W)
754
+ latents = canvas.view(B, H * W, C)
755
+
756
+ if output_type == "latent":
757
+ image = latents
758
+ else:
759
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
760
+ latents = (
761
+ latents / self.vae.config.scaling_factor
762
+ ) + self.vae.config.shift_factor
763
+ image = self.vae.decode(latents, return_dict=False)[0]
764
+ image = self.image_processor.postprocess(image, output_type=output_type)
765
+
766
+ # Offload all models
767
+ self.maybe_free_model_hooks()
768
+
769
+ if not return_dict:
770
+ return (image,)
771
+
772
+ return FluxPipelineOutput(images=image)
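
A minimal usage sketch for the QKV-rotation ablation `generate` above, assuming the FLUX pipeline and the trained rotation/LoRA adapters have already been attached to `pipe.transformer` elsewhere in the repo; the checkpoint id, the adapter name "subject", and the image path are illustrative placeholders. Setting `global_T_Q` / `global_T_K` / `global_T_V` to `0.0` collapses the corresponding rotation to the identity, since `T` scales the `Y` factor of the Cayley parametrization.

```python
# Hedged sketch: drive the QKV-rotation ablation `generate` defined above.
# The checkpoint id, adapter name "subject", and image path are assumptions.
import torch
from PIL import Image
from diffusers import FluxPipeline

from omini.pipeline.flux_omini_ablate_qkv import Condition, generate

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
# ... attach the trained rotation / LoRA adapters to pipe.transformer here ...

condition = Condition(
    Image.open("subject.png").resize((512, 512)),
    adapter_setting="subject",  # assumed adapter name
)

image = generate(
    pipe,
    prompt="A photo of the subject on a wooden desk",
    conditions=[condition],
    main_adapter="subject",  # assumed adapter name
    num_inference_steps=28,
    # Ablate only the query-side rotation; keep key/value rotations active.
    global_T_Q=0.0,
    global_T_K=1.0,
    global_T_V=1.0,
).images[0]
image.save("ablate_q.png")
```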
omini/pipeline/flux_omini_ablate_scale.py ADDED
@@ -0,0 +1,748 @@
1
+ """
2
+ This version is for an ablation study on the effect of scaling in LoRA adapters.
3
+
4
+ The `generate` function is modified to accept a `global_scale` parameter;
5
+ the `SCALE` variable is set globally at the start of generation.
6
+ """
7
+
8
+ import torch
9
+ from typing import List, Union, Optional, Dict, Any, Callable, Type, Tuple
10
+
11
+ from diffusers.pipelines import FluxPipeline
12
+ from diffusers.pipelines.flux.pipeline_flux import (
13
+ FluxPipelineOutput,
14
+ FluxTransformer2DModel,
15
+ calculate_shift,
16
+ retrieve_timesteps,
17
+ np,
18
+ )
19
+ from diffusers.models.attention_processor import Attention, F
20
+ from diffusers.models.embeddings import apply_rotary_emb
21
+ from transformers import pipeline
22
+
23
+ from peft.tuners.tuners_utils import BaseTunerLayer
24
+ from accelerate.utils import is_torch_version
25
+
26
+ from contextlib import contextmanager
27
+
28
+ import cv2
29
+
30
+ from PIL import Image, ImageFilter
31
+
32
+ SCALE = 1.0
33
+
34
+ def seed_everything(seed: int = 42):
35
+ torch.backends.cudnn.deterministic = True
36
+ torch.manual_seed(seed)
37
+ np.random.seed(seed)
38
+
39
+
40
+ def clip_hidden_states(hidden_states: torch.FloatTensor) -> torch.FloatTensor:
41
+ if hidden_states.dtype == torch.float16:
42
+ hidden_states = hidden_states.clip(-65504, 65504)
43
+ return hidden_states
44
+
45
+
46
+ def encode_images(pipeline: FluxPipeline, images: torch.Tensor):
47
+ """
48
+ Encodes the images into tokens and ids for FLUX pipeline.
49
+ """
50
+ images = pipeline.image_processor.preprocess(images)
51
+ images = images.to(pipeline.device).to(pipeline.dtype)
52
+ images = pipeline.vae.encode(images).latent_dist.sample()
53
+ images = (
54
+ images - pipeline.vae.config.shift_factor
55
+ ) * pipeline.vae.config.scaling_factor
56
+ images_tokens = pipeline._pack_latents(images, *images.shape)
57
+ images_ids = pipeline._prepare_latent_image_ids(
58
+ images.shape[0],
59
+ images.shape[2],
60
+ images.shape[3],
61
+ pipeline.device,
62
+ pipeline.dtype,
63
+ )
64
+ if images_tokens.shape[1] != images_ids.shape[0]:
65
+ images_ids = pipeline._prepare_latent_image_ids(
66
+ images.shape[0],
67
+ images.shape[2] // 2,
68
+ images.shape[3] // 2,
69
+ pipeline.device,
70
+ pipeline.dtype,
71
+ )
72
+ return images_tokens, images_ids
73
+
74
+
75
+ depth_pipe = None
76
+
77
+
78
+ def convert_to_condition(
79
+ condition_type: str,
80
+ raw_img: Union[Image.Image, torch.Tensor],
81
+ blur_radius: Optional[int] = 5,
82
+ ) -> Union[Image.Image, torch.Tensor]:
83
+ if condition_type == "depth":
84
+ global depth_pipe
85
+ depth_pipe = depth_pipe or pipeline(
86
+ task="depth-estimation",
87
+ model="LiheYoung/depth-anything-small-hf",
88
+ device="cpu", # Use "cpu" to enable parallel processing
89
+ )
90
+ source_image = raw_img.convert("RGB")
91
+ condition_img = depth_pipe(source_image)["depth"].convert("RGB")
92
+ return condition_img
93
+ elif condition_type == "canny":
94
+ img = np.array(raw_img)
95
+ edges = cv2.Canny(img, 100, 200)
96
+ edges = Image.fromarray(edges).convert("RGB")
97
+ return edges
98
+ elif condition_type == "coloring":
99
+ return raw_img.convert("L").convert("RGB")
100
+ elif condition_type == "deblurring":
101
+ condition_image = (
102
+ raw_img.convert("RGB")
103
+ .filter(ImageFilter.GaussianBlur(blur_radius))
104
+ .convert("RGB")
105
+ )
106
+ return condition_image
107
+ else:
108
+ print("Warning: Returning the raw image.")
109
+ return raw_img.convert("RGB")
110
+
111
+
112
+ class Condition(object):
113
+ def __init__(
114
+ self,
115
+ condition: Union[Image.Image, torch.Tensor],
116
+ adapter_setting: Union[str, dict],
117
+ position_delta=None,
118
+ position_scale=1.0,
119
+ latent_mask=None,
120
+ is_complement=False,
121
+ ) -> None:
122
+ self.condition = condition
123
+ self.adapter = adapter_setting
124
+ self.position_delta = position_delta
125
+ self.position_scale = position_scale
126
+ self.latent_mask = (
127
+ latent_mask.T.reshape(-1) if latent_mask is not None else None
128
+ )
129
+ self.is_complement = is_complement
130
+
131
+ def encode(
132
+ self, pipe: FluxPipeline, empty: bool = False
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ condition_empty = Image.new("RGB", self.condition.size, (0, 0, 0))
135
+ tokens, ids = encode_images(pipe, condition_empty if empty else self.condition)
136
+
137
+ if self.position_delta is not None:
138
+ ids[:, 1] += self.position_delta[0]
139
+ ids[:, 2] += self.position_delta[1]
140
+
141
+ if self.position_scale != 1.0:
142
+ scale_bias = (self.position_scale - 1.0) / 2
143
+ ids[:, 1:] *= self.position_scale
144
+ ids[:, 1:] += scale_bias
145
+
146
+ if self.latent_mask is not None:
147
+ tokens = tokens[:, self.latent_mask]
148
+ ids = ids[self.latent_mask]
149
+
150
+ return tokens, ids
151
+
152
+
153
+ @contextmanager
154
+ def specify_lora(lora_modules: List[BaseTunerLayer], specified_lora):
155
+ # Filter valid lora modules
156
+ valid_lora_modules = [m for m in lora_modules if isinstance(m, BaseTunerLayer)]
157
+ # Save original scales
158
+ original_scales = [
159
+ {
160
+ adapter: module.scaling[adapter]
161
+ for adapter in module.active_adapters
162
+ if adapter in module.scaling
163
+ }
164
+ for module in valid_lora_modules
165
+ ]
166
+ # Enter context: adjust scaling
167
+ for module in valid_lora_modules:
168
+ for adapter in module.active_adapters:
169
+ if adapter in module.scaling:
170
+ module.scaling[adapter] = SCALE if adapter == specified_lora else 0
171
+ try:
172
+ yield
173
+ finally:
174
+ # Exit context: restore original scales
175
+ for module, scales in zip(valid_lora_modules, original_scales):
176
+ for adapter in module.active_adapters:
177
+ if adapter in module.scaling:
178
+ module.scaling[adapter] = scales[adapter]
179
+
180
+
181
+ def attn_forward(
182
+ attn: Attention,
183
+ hidden_states: List[torch.FloatTensor],
184
+ adapters: List[str],
185
+ hidden_states2: Optional[List[torch.FloatTensor]] = [],
186
+ position_embs: Optional[List[torch.Tensor]] = None,
187
+ group_mask: Optional[torch.Tensor] = None,
188
+ cache_mode: Optional[str] = None,
189
+ # to determine whether to cache the keys and values for this branch
190
+ to_cache: Optional[List[torch.Tensor]] = None,
191
+ cache_storage: Optional[List[torch.Tensor]] = None,
192
+ **kwargs: dict,
193
+ ) -> torch.FloatTensor:
194
+ bs, _, _ = hidden_states[0].shape
195
+ h2_n = len(hidden_states2)
196
+
197
+ queries, keys, values = [], [], []
198
+
199
+ # Prepare query, key, value for each encoder hidden state (text branch)
200
+ for i, hidden_state in enumerate(hidden_states2):
201
+ query = attn.add_q_proj(hidden_state)
202
+ key = attn.add_k_proj(hidden_state)
203
+ value = attn.add_v_proj(hidden_state)
204
+
205
+ head_dim = key.shape[-1] // attn.heads
206
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
207
+
208
+ query, key, value = map(reshape_fn, (query, key, value))
209
+ query, key = attn.norm_added_q(query), attn.norm_added_k(key)
210
+
211
+ queries.append(query)
212
+ keys.append(key)
213
+ values.append(value)
214
+
215
+ # Prepare query, key, value for each hidden state (image branch)
216
+ for i, hidden_state in enumerate(hidden_states):
217
+ with specify_lora((attn.to_q, attn.to_k, attn.to_v), adapters[i + h2_n]):
218
+ query = attn.to_q(hidden_state)
219
+ key = attn.to_k(hidden_state)
220
+ value = attn.to_v(hidden_state)
221
+
222
+ head_dim = key.shape[-1] // attn.heads
223
+ reshape_fn = lambda x: x.view(bs, -1, attn.heads, head_dim).transpose(1, 2)
224
+
225
+ query, key, value = map(reshape_fn, (query, key, value))
226
+ query, key = attn.norm_q(query), attn.norm_k(key)
227
+
228
+ queries.append(query)
229
+ keys.append(key)
230
+ values.append(value)
231
+
232
+ # Apply rotary embedding
233
+ if position_embs is not None:
234
+ queries = [apply_rotary_emb(q, position_embs[i]) for i, q in enumerate(queries)]
235
+ keys = [apply_rotary_emb(k, position_embs[i]) for i, k in enumerate(keys)]
236
+
237
+ if cache_mode == "write":
238
+ for i, (k, v) in enumerate(zip(keys, values)):
239
+ if to_cache[i]:
240
+ cache_storage[attn.cache_idx][0].append(k)
241
+ cache_storage[attn.cache_idx][1].append(v)
242
+
243
+ attn_outputs = []
244
+ for i, query in enumerate(queries):
245
+ keys_, values_ = [], []
246
+ # Add keys and values from other branches
247
+ for j, (k, v) in enumerate(zip(keys, values)):
248
+ if (group_mask is not None) and not (group_mask[i][j].item()):
249
+ continue
250
+ keys_.append(k)
251
+ values_.append(v)
252
+ if cache_mode == "read":
253
+ keys_.extend(cache_storage[attn.cache_idx][0])
254
+ values_.extend(cache_storage[attn.cache_idx][1])
255
+ # Add keys and values from cache TODO
256
+ # Attention computation
257
+ attn_output = F.scaled_dot_product_attention(
258
+ query, torch.cat(keys_, dim=2), torch.cat(values_, dim=2)
259
+ ).to(query.dtype)
260
+ attn_output = attn_output.transpose(1, 2).reshape(bs, -1, attn.heads * head_dim)
261
+ attn_outputs.append(attn_output)
262
+
263
+ # Reshape attention output to match the original hidden states
264
+ h_out, h2_out = [], []
265
+
266
+ for i, hidden_state in enumerate(hidden_states2):
267
+ h2_out.append(attn.to_add_out(attn_outputs[i]))
268
+
269
+ for i, hidden_state in enumerate(hidden_states):
270
+ h = attn_outputs[i + h2_n]
271
+ if getattr(attn, "to_out", None) is not None:
272
+ with specify_lora((attn.to_out[0],), adapters[i + h2_n]):
273
+ h = attn.to_out[0](h)
274
+ h_out.append(h)
275
+
276
+ return (h_out, h2_out) if h2_n else h_out
277
+
278
+
279
+ def block_forward(
280
+ self,
281
+ image_hidden_states: List[torch.FloatTensor],
282
+ text_hidden_states: List[torch.FloatTensor],
283
+ tembs: List[torch.FloatTensor],
284
+ adapters: List[str],
285
+ position_embs=None,
286
+ attn_forward=attn_forward,
287
+ **kwargs: dict,
288
+ ):
289
+ txt_n = len(text_hidden_states)
290
+
291
+ img_variables, txt_variables = [], []
292
+
293
+ for i, text_h in enumerate(text_hidden_states):
294
+ txt_variables.append(self.norm1_context(text_h, emb=tembs[i]))
295
+
296
+ for i, image_h in enumerate(image_hidden_states):
297
+ with specify_lora((self.norm1.linear,), adapters[i + txt_n]):
298
+ img_variables.append(self.norm1(image_h, emb=tembs[i + txt_n]))
299
+
300
+ # Attention.
301
+ img_attn_output, txt_attn_output = attn_forward(
302
+ self.attn,
303
+ hidden_states=[each[0] for each in img_variables],
304
+ hidden_states2=[each[0] for each in txt_variables],
305
+ position_embs=position_embs,
306
+ adapters=adapters,
307
+ **kwargs,
308
+ )
309
+
310
+ text_out = []
311
+ for i in range(len(text_hidden_states)):
312
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = txt_variables[i]
313
+ text_h = text_hidden_states[i] + txt_attn_output[i] * gate_msa.unsqueeze(1)
314
+ norm_h = (
315
+ self.norm2_context(text_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
316
+ )
317
+ text_h = self.ff_context(norm_h) * gate_mlp.unsqueeze(1) + text_h
318
+ text_out.append(clip_hidden_states(text_h))
319
+
320
+ image_out = []
321
+ for i in range(len(image_hidden_states)):
322
+ _, gate_msa, shift_mlp, scale_mlp, gate_mlp = img_variables[i]
323
+ image_h = (
324
+ image_hidden_states[i] + img_attn_output[i] * gate_msa.unsqueeze(1)
325
+ ).to(image_hidden_states[i].dtype)
326
+ norm_h = self.norm2(image_h) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
327
+ with specify_lora((self.ff.net[2],), adapters[i + txt_n]):
328
+ image_h = image_h + self.ff(norm_h) * gate_mlp.unsqueeze(1)
329
+ image_out.append(clip_hidden_states(image_h))
330
+ return image_out, text_out
331
+
332
+
333
+ def single_block_forward(
334
+ self,
335
+ hidden_states: List[torch.FloatTensor],
336
+ tembs: List[torch.FloatTensor],
337
+ adapters: List[str],
338
+ position_embs=None,
339
+ attn_forward=attn_forward,
340
+ **kwargs: dict,
341
+ ):
342
+ mlp_hidden_states, gates = [[None for _ in hidden_states] for _ in range(2)]
343
+
344
+ hidden_state_norm = []
345
+ for i, hidden_state in enumerate(hidden_states):
346
+ # [NOTE]!: This function's output is slightly DIFFERENT from the original
347
+ # FLUX version. In the original implementation, the gates were computed using
348
+ # the combined hidden states from both the image and text branches. Here, each
349
+ # branch computes its gate using only its own hidden state.
350
+ with specify_lora((self.norm.linear, self.proj_mlp), adapters[i]):
351
+ h_norm, gates[i] = self.norm(hidden_state, emb=tembs[i])
352
+ mlp_hidden_states[i] = self.act_mlp(self.proj_mlp(h_norm))
353
+ hidden_state_norm.append(h_norm)
354
+
355
+ attn_outputs = attn_forward(
356
+ self.attn, hidden_state_norm, adapters, position_embs=position_embs, **kwargs
357
+ )
358
+
359
+ h_out = []
360
+ for i in range(len(hidden_states)):
361
+ with specify_lora((self.proj_out,), adapters[i]):
362
+ h = torch.cat([attn_outputs[i], mlp_hidden_states[i]], dim=2)
363
+ h = gates[i].unsqueeze(1) * self.proj_out(h) + hidden_states[i]
364
+ h_out.append(clip_hidden_states(h))
365
+
366
+ return h_out
367
+
368
+
369
+ def transformer_forward(
370
+ transformer: FluxTransformer2DModel,
371
+ image_features: List[torch.Tensor],
372
+ text_features: List[torch.Tensor] = None,
373
+ img_ids: List[torch.Tensor] = None,
374
+ txt_ids: List[torch.Tensor] = None,
375
+ pooled_projections: List[torch.Tensor] = None,
376
+ timesteps: List[torch.LongTensor] = None,
377
+ guidances: List[torch.Tensor] = None,
378
+ adapters: List[str] = None,
379
+ # Assign the function to be used for the forward pass
380
+ single_block_forward=single_block_forward,
381
+ block_forward=block_forward,
382
+ attn_forward=attn_forward,
383
+ **kwargs: dict,
384
+ ):
385
+ self = transformer
386
+ txt_n = len(text_features) if text_features is not None else 0
387
+
388
+ adapters = adapters or [None] * (txt_n + len(image_features))
389
+ assert len(adapters) == len(timesteps)
390
+
391
+ # Preprocess the image_features
392
+ image_hidden_states = []
393
+ for i, image_feature in enumerate(image_features):
394
+ with specify_lora((self.x_embedder,), adapters[i + txt_n]):
395
+ image_hidden_states.append(self.x_embedder(image_feature))
396
+
397
+ # Preprocess the text_features
398
+ text_hidden_states = []
399
+ for text_feature in text_features:
400
+ text_hidden_states.append(self.context_embedder(text_feature))
401
+
402
+ # Prepare embeddings of (timestep, guidance, pooled_projections)
403
+ assert len(timesteps) == len(image_features) + len(text_features)
404
+
405
+ def get_temb(timestep, guidance, pooled_projection):
406
+ timestep = timestep.to(image_hidden_states[0].dtype) * 1000
407
+ if guidance is not None:
408
+ guidance = guidance.to(image_hidden_states[0].dtype) * 1000
409
+ return self.time_text_embed(timestep, guidance, pooled_projection)
410
+ else:
411
+ return self.time_text_embed(timestep, pooled_projection)
412
+
413
+ tembs = [get_temb(*each) for each in zip(timesteps, guidances, pooled_projections)]
414
+
415
+ # Prepare position embeddings for each token
416
+ position_embs = [self.pos_embed(each) for each in (*txt_ids, *img_ids)]
417
+
418
+ # Prepare the gradient checkpointing kwargs
419
+ gckpt_kwargs: Dict[str, Any] = (
420
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
421
+ )
422
+
423
+ # dual branch blocks
424
+ for block in self.transformer_blocks:
425
+ block_kwargs = {
426
+ "self": block,
427
+ "image_hidden_states": image_hidden_states,
428
+ "text_hidden_states": text_hidden_states,
429
+ "tembs": tembs,
430
+ "position_embs": position_embs,
431
+ "adapters": adapters,
432
+ "attn_forward": attn_forward,
433
+ **kwargs,
434
+ }
435
+ if self.training and self.gradient_checkpointing:
436
+ image_hidden_states, text_hidden_states = torch.utils.checkpoint.checkpoint(
437
+ block_forward, **block_kwargs, **gckpt_kwargs
438
+ )
439
+ else:
440
+ image_hidden_states, text_hidden_states = block_forward(**block_kwargs)
441
+
442
+ # combine image and text hidden states then pass through the single transformer blocks
443
+ all_hidden_states = [*text_hidden_states, *image_hidden_states]
444
+ for block in self.single_transformer_blocks:
445
+ block_kwargs = {
446
+ "self": block,
447
+ "hidden_states": all_hidden_states,
448
+ "tembs": tembs,
449
+ "position_embs": position_embs,
450
+ "adapters": adapters,
451
+ "attn_forward": attn_forward,
452
+ **kwargs,
453
+ }
454
+ if self.training and self.gradient_checkpointing:
455
+ all_hidden_states = torch.utils.checkpoint.checkpoint(
456
+ single_block_forward, **block_kwargs, **gckpt_kwargs
457
+ )
458
+ else:
459
+ all_hidden_states = single_block_forward(**block_kwargs)
460
+
461
+ image_hidden_states = self.norm_out(all_hidden_states[txt_n], tembs[txt_n])
462
+ output = self.proj_out(image_hidden_states)
463
+
464
+ return (output,)
465
+
466
+
467
+ @torch.no_grad()
468
+ def generate(
469
+ pipeline: FluxPipeline,
470
+ prompt: Union[str, List[str]] = None,
471
+ prompt_2: Optional[Union[str, List[str]]] = None,
472
+ height: Optional[int] = 512,
473
+ width: Optional[int] = 512,
474
+ num_inference_steps: int = 28,
475
+ timesteps: List[int] = None,
476
+ guidance_scale: float = 3.5,
477
+ num_images_per_prompt: Optional[int] = 1,
478
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
479
+ latents: Optional[torch.FloatTensor] = None,
480
+ prompt_embeds: Optional[torch.FloatTensor] = None,
481
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
482
+ output_type: Optional[str] = "pil",
483
+ return_dict: bool = True,
484
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
485
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
486
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
487
+ max_sequence_length: int = 512,
488
+ # Condition Parameters (Optional)
489
+ main_adapter: Optional[List[str]] = None,
490
+ conditions: List[Condition] = [],
491
+ image_guidance_scale: float = 1.0,
492
+ transformer_kwargs: Optional[Dict[str, Any]] = {},
493
+ kv_cache=False,
494
+ latent_mask=None,
495
+ global_scale=None,
496
+ **params: dict,
497
+ ):
498
+
499
+ if global_scale is not None:
500
+ global SCALE
501
+ SCALE = global_scale
502
+
503
+ self = pipeline
504
+
505
+ height = height or self.default_sample_size * self.vae_scale_factor
506
+ width = width or self.default_sample_size * self.vae_scale_factor
507
+
508
+ # Check inputs. Raise error if not correct
509
+ self.check_inputs(
510
+ prompt,
511
+ prompt_2,
512
+ height,
513
+ width,
514
+ prompt_embeds=prompt_embeds,
515
+ pooled_prompt_embeds=pooled_prompt_embeds,
516
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
517
+ max_sequence_length=max_sequence_length,
518
+ )
519
+
520
+ self._guidance_scale = guidance_scale
521
+ self._joint_attention_kwargs = joint_attention_kwargs
522
+
523
+ # Define call parameters
524
+ if prompt is not None and isinstance(prompt, str):
525
+ batch_size = 1
526
+ elif prompt is not None and isinstance(prompt, list):
527
+ batch_size = len(prompt)
528
+ else:
529
+ batch_size = prompt_embeds.shape[0]
530
+
531
+ device = self._execution_device
532
+
533
+ # Prepare prompt embeddings
534
+ (
535
+ prompt_embeds,
536
+ pooled_prompt_embeds,
537
+ text_ids,
538
+ ) = self.encode_prompt(
539
+ prompt=prompt,
540
+ prompt_2=prompt_2,
541
+ prompt_embeds=prompt_embeds,
542
+ pooled_prompt_embeds=pooled_prompt_embeds,
543
+ device=device,
544
+ num_images_per_prompt=num_images_per_prompt,
545
+ max_sequence_length=max_sequence_length,
546
+ )
547
+
548
+ # Prepare latent variables
549
+ num_channels_latents = self.transformer.config.in_channels // 4
550
+ latents, latent_image_ids = self.prepare_latents(
551
+ batch_size * num_images_per_prompt,
552
+ num_channels_latents,
553
+ height,
554
+ width,
555
+ prompt_embeds.dtype,
556
+ device,
557
+ generator,
558
+ latents,
559
+ )
560
+
561
+ if latent_mask is not None:
562
+ latent_mask = latent_mask.T.reshape(-1)
563
+ latents = latents[:, latent_mask]
564
+ latent_image_ids = latent_image_ids[latent_mask]
565
+
566
+ # Prepare conditions
567
+ c_latents, uc_latents, c_ids, c_timesteps = ([], [], [], [])
568
+ c_projections, c_guidances, c_adapters = ([], [], [])
569
+ complement_cond = None
570
+ for condition in conditions:
571
+ tokens, ids = condition.encode(self)
572
+ c_latents.append(tokens) # [batch_size, token_n, token_dim]
573
+ # Empty condition for unconditioned image
574
+ if image_guidance_scale != 1.0:
575
+ uc_latents.append(condition.encode(self, empty=True)[0])
576
+ c_ids.append(ids) # [token_n, id_dim(3)]
577
+ c_timesteps.append(torch.zeros([1], device=device))
578
+ c_projections.append(pooled_prompt_embeds)
579
+ c_guidances.append(torch.ones([1], device=device))
580
+ c_adapters.append(condition.adapter)
581
+ # This complement_condition will be combined with the original image.
582
+ # See the token integration of OminiControl2 [https://arxiv.org/abs/2503.08280]
583
+ if condition.is_complement:
584
+ complement_cond = (tokens, ids)
585
+
586
+ # Prepare timesteps
587
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
588
+ image_seq_len = latents.shape[1]
589
+ mu = calculate_shift(
590
+ image_seq_len,
591
+ self.scheduler.config.base_image_seq_len,
592
+ self.scheduler.config.max_image_seq_len,
593
+ self.scheduler.config.base_shift,
594
+ self.scheduler.config.max_shift,
595
+ )
596
+ timesteps, num_inference_steps = retrieve_timesteps(
597
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
598
+ )
599
+ num_warmup_steps = max(
600
+ len(timesteps) - num_inference_steps * self.scheduler.order, 0
601
+ )
602
+ self._num_timesteps = len(timesteps)
603
+
604
+ if kv_cache:
605
+ attn_counter = 0
606
+ for module in self.transformer.modules():
607
+ if isinstance(module, Attention):
608
+ setattr(module, "cache_idx", attn_counter)
609
+ attn_counter += 1
610
+ kv_cond = [[[], []] for _ in range(attn_counter)]
611
+ kv_uncond = [[[], []] for _ in range(attn_counter)]
612
+
613
+ def clear_cache():
614
+ for storage in [kv_cond, kv_uncond]:
615
+ for kesy, values in storage:
616
+ kesy.clear()
617
+ values.clear()
618
+
619
+ branch_n = len(conditions) + 2
620
+ group_mask = torch.ones([branch_n, branch_n], dtype=torch.bool)
621
+ # Disable the attention cross different condition branches
622
+ group_mask[2:, 2:] = torch.diag(torch.tensor([1] * len(conditions)))
623
+ # Disable the attention from condition branches to image branch and text branch
624
+ if kv_cache:
625
+ group_mask[2:, :2] = False
626
+
627
+ # Denoising loop
628
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
629
+ for i, t in enumerate(timesteps):
630
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
631
+ timestep = t.expand(latents.shape[0]).to(latents.dtype) / 1000
632
+
633
+ # handle guidance
634
+ if self.transformer.config.guidance_embeds:
635
+ guidance = torch.tensor([guidance_scale], device=device)
636
+ guidance = guidance.expand(latents.shape[0])
637
+ else:
638
+ guidance, c_guidances = None, [None for _ in c_guidances]
639
+
640
+ if kv_cache:
641
+ mode = "write" if i == 0 else "read"
642
+ if mode == "write":
643
+ clear_cache()
644
+ use_cond = not (kv_cache) or mode == "write"
645
+
646
+ noise_pred = transformer_forward(
647
+ self.transformer,
648
+ image_features=[latents] + (c_latents if use_cond else []),
649
+ text_features=[prompt_embeds],
650
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
651
+ txt_ids=[text_ids],
652
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
653
+ pooled_projections=[pooled_prompt_embeds] * 2
654
+ + (c_projections if use_cond else []),
655
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
656
+ return_dict=False,
657
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
658
+ cache_mode=mode if kv_cache else None,
659
+ cache_storage=kv_cond if kv_cache else None,
660
+ to_cache=[False, False, *[True] * len(c_latents)],
661
+ group_mask=group_mask,
662
+ **transformer_kwargs,
663
+ )[0]
664
+
665
+ if image_guidance_scale != 1.0:
666
+ unc_pred = transformer_forward(
667
+ self.transformer,
668
+ image_features=[latents] + (uc_latents if use_cond else []),
669
+ text_features=[prompt_embeds],
670
+ img_ids=[latent_image_ids] + (c_ids if use_cond else []),
671
+ txt_ids=[text_ids],
672
+ timesteps=[timestep, timestep] + (c_timesteps if use_cond else []),
673
+ pooled_projections=[pooled_prompt_embeds] * 2
674
+ + (c_projections if use_cond else []),
675
+ guidances=[guidance] * 2 + (c_guidances if use_cond else []),
676
+ return_dict=False,
677
+ adapters=[main_adapter] * 2 + (c_adapters if use_cond else []),
678
+ cache_mode=mode if kv_cache else None,
679
+ cache_storage=kv_uncond if kv_cache else None,
680
+ to_cache=[False, False, *[True] * len(c_latents)],
681
+ **transformer_kwargs,
682
+ )[0]
683
+
684
+ noise_pred = unc_pred + image_guidance_scale * (noise_pred - unc_pred)
685
+
686
+ # compute the previous noisy sample x_t -> x_t-1
687
+ latents_dtype = latents.dtype
688
+ latents = self.scheduler.step(noise_pred, t, latents)[0]
689
+
690
+ if latents.dtype != latents_dtype:
691
+ if torch.backends.mps.is_available():
692
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
693
+ latents = latents.to(latents_dtype)
694
+
695
+ if callback_on_step_end is not None:
696
+ callback_kwargs = {}
697
+ for k in callback_on_step_end_tensor_inputs:
698
+ callback_kwargs[k] = locals()[k]
699
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
700
+
701
+ latents = callback_outputs.pop("latents", latents)
702
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
703
+
704
+ # call the callback, if provided
705
+ if i == len(timesteps) - 1 or (
706
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
707
+ ):
708
+ progress_bar.update()
709
+
710
+ if latent_mask is not None:
711
+ # Combine the generated latents and the complement condition
712
+ assert complement_cond is not None
713
+ comp_latent, comp_ids = complement_cond
714
+ all_ids = torch.cat([latent_image_ids, comp_ids], dim=0) # (Ta+Tc,3)
715
+ shape = (all_ids.max(dim=0).values + 1).to(torch.long) # (3,)
716
+ H, W = shape[1].item(), shape[2].item()
717
+ B, _, C = latents.shape
718
+ # Create a empty canvas
719
+ canvas = latents.new_zeros(B, H * W, C) # (B,H*W,C)
720
+
721
+ # Stash the latents and the complement condition
722
+ def _stash(canvas, tokens, ids, H, W) -> None:
723
+ B, T, C = tokens.shape
724
+ ids = ids.to(torch.long)
725
+ flat_idx = (ids[:, 1] * W + ids[:, 2]).to(torch.long)
726
+ canvas.view(B, -1, C).index_copy_(1, flat_idx, tokens)
727
+
728
+ _stash(canvas, latents, latent_image_ids, H, W)
729
+ _stash(canvas, comp_latent, comp_ids, H, W)
730
+ latents = canvas.view(B, H * W, C)
731
+
732
+ if output_type == "latent":
733
+ image = latents
734
+ else:
735
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
736
+ latents = (
737
+ latents / self.vae.config.scaling_factor
738
+ ) + self.vae.config.shift_factor
739
+ image = self.vae.decode(latents, return_dict=False)[0]
740
+ image = self.image_processor.postprocess(image, output_type=output_type)
741
+
742
+ # Offload all models
743
+ self.maybe_free_model_hooks()
744
+
745
+ if not return_dict:
746
+ return (image,)
747
+
748
+ return FluxPipelineOutput(images=image)
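
For the scale ablation above, a small sweep over `global_scale` is the natural driver: `SCALE` replaces the fixed factor that `specify_lora` applies to the active adapter, so `global_scale=0.0` switches the adapter off and intermediate values interpolate its contribution. A minimal sketch, reusing the `pipe` and `condition` objects from the previous example (the two modules' `Condition` classes are interchangeable here, since `generate` only calls `condition.encode` and reads `condition.adapter`); the adapter name is again an assumption.

```python
# Hedged sketch: sweep the global adapter scaling used by specify_lora above.
from omini.pipeline.flux_omini_ablate_scale import generate as generate_scale

for scale in (0.0, 0.25, 0.5, 0.75, 1.0):
    image = generate_scale(
        pipe,
        prompt="A photo of the subject on a wooden desk",
        conditions=[condition],
        main_adapter="subject",  # assumed adapter name
        num_inference_steps=28,
        global_scale=scale,
    ).images[0]
    image.save(f"ablate_scale_{scale:.2f}.png")
```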
omini/rotation/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .rotation_config import RotationConfig
2
+ from .layer import RotationLayer
3
+ from .model import RotationTuner
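
A quick numerical check of the Cayley-parametrized `Rotation` module defined in `omini/rotation/layer.py` (added below). This is a sketch under a few assumptions: the package is importable from this path, a single rotation block is used (summing several blocks breaks exact orthogonality), and `forward` applies the full `(I - YX^T)^{-1}` factor as in `get_delta_weight`. `U` and `V` are re-randomized only to make the check non-trivial, since their default initialization is (near) zero.

```python
# Hedged sanity check for the Cayley rotation layer (single rotation block).
import torch
from omini.rotation.layer import Rotation

torch.manual_seed(0)
rot = Rotation(r=2, dim=16, T=1.0, num_rotations=1)
with torch.no_grad():
    rot.U.normal_(0.0, 0.1)  # default init is 0.002 * randn
    rot.V.normal_(0.0, 0.1)  # default init is exactly zero

R = torch.eye(16) + rot.get_delta_weight()  # (dim, dim)
# The Cayley transform of the skew-symmetric A = X^T Y is orthogonal.
print(torch.allclose(R.T @ R, torch.eye(16), atol=1e-4))  # expected: True

# forward(x) should agree with applying R to each row of x.
x = torch.randn(4, 16)
print(torch.allclose(rot(x), x @ R.T, atol=1e-4))  # expected: True
```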
omini/rotation/layer.py ADDED
@@ -0,0 +1,313 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, Set
4
+
5
+ from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
6
+
7
+ def inverse_2x2(matrices):
8
+
9
+ # Extract matrix elements
10
+ # matrices[..., 0, 0] corresponds to 'a' in [[a, b], [c, d]]
11
+ a = matrices[..., 0, 0]
12
+ b = matrices[..., 0, 1]
13
+ c = matrices[..., 1, 0]
14
+ d = matrices[..., 1, 1]
15
+
16
+ # Compute determinant
17
+ det = a * d - b * c
18
+
19
+ # Compute inverse using the formula:
20
+ # inv = (1/det) * [[d, -b], [-c, a]]
21
+ inv_det = 1.0 / det
22
+
23
+ # Create output tensor
24
+ inv_matrices = torch.empty_like(matrices)
25
+ inv_matrices[..., 0, 0] = d * inv_det
26
+ inv_matrices[..., 0, 1] = -b * inv_det
27
+ inv_matrices[..., 1, 0] = -c * inv_det
28
+ inv_matrices[..., 1, 1] = a * inv_det
29
+
30
+ return inv_matrices
31
+
32
+ class Rotation(nn.Module):
33
+ """
34
+ Rotation layer based on Cayley transformation for parameter-efficient fine-tuning.
35
+
36
+ This layer implements orthogonal fine-tuning through Cayley transformation:
37
+ h(x) = (I - A)^{-1} (I + A) x
38
+
39
+ where A = XY^T with X = [U; -V] and Y = [V; U]
40
+ """
41
+
42
+ def __init__(self, r, dim, T=1.0, num_rotations=4):
43
+ super().__init__()
44
+ self.r = r
45
+ self.T = T
46
+ self.U = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.002, requires_grad=True)
47
+ self.V = nn.Parameter(torch.randn(num_rotations, r, dim) * 0.0, requires_grad=True)
48
+ self.num_rotations = num_rotations
49
+
50
+
51
+ def forward(self, x):
52
+ """
53
+ Apply Cayley transformation to input x.
54
+
55
+ A = XY^T where X = [U; -V], Y = [V; U]
56
+ Cayley transformation: h(x) = (I - A)^{-1} (I + A) x
57
+
58
+ Uses Woodbury identity for efficient computation:
59
+ (I - XY^T)^{-1} = I + X (I - Y^T X)^{-1} Y^T
60
+
61
+ Args:
62
+ x: Input tensor of shape (..., dim)
63
+
64
+ Returns:
65
+ Transformed tensor of shape (..., dim)
66
+ """
67
+ x_dtype = x.dtype
68
+ X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
69
+ Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
70
+
71
+ Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
72
+ I_2r = torch.eye(2 * self.r, device=x.device, dtype=x.dtype).repeat(self.num_rotations, 1, 1)
73
+ I_minus_YX = I_2r - Y_T_X
74
+
75
+ if self.r == 1:
76
+ I_minus_YX_inv = inverse_2x2(I_minus_YX)
77
+ else:
78
+ # make it float32
79
+ I_minus_YX = I_minus_YX.to(torch.float32)
80
+ I_minus_YX_inv = torch.linalg.inv(I_minus_YX) # Shape: (num_rotations, 2r, 2r)
81
+ I_minus_YX_inv = I_minus_YX_inv.to(x_dtype)
82
+
83
+ Yx = torch.einsum("...d,nrd->...nr", x, Y) # Shape: (batch*seq_len, num_rotations, 2r)
84
+ I_minus_YX_inv_Yx = torch.einsum("nRr,...nr->...nR", I_minus_YX_inv, Yx)  # Shape: (batch*seq_len, num_rotations, 2r)
85
+
86
+ second_term = torch.einsum("...nr,nrd->...nd", I_minus_YX_inv_Yx, X) # Shape: (batch*seq_len, num_rotations, dim)
87
+ second_term = second_term.sum(dim=-2) # Sum over rotations
88
+
89
+ output = x + 2 * second_term # Shape: (batch*seq_len, dim)
90
+
91
+ return output
92
+
93
+ def get_delta_weight(self):
94
+ """
95
+ Compute the delta weight matrix induced by the rotation layer.
96
+
97
+ Returns:
98
+ Delta weight matrix of shape (dim, dim)
99
+ """
100
+ X = torch.cat([self.U, -self.V], dim=1) # Shape: (num_rotations, 2r, dim)
101
+ Y = torch.cat([self.V, self.U], dim=1) * self.T # Shape: (num_rotations, 2r, dim)
102
+
103
+ Y_T_X = torch.matmul(Y, X.transpose(1, 2)) # Shape: (num_rotations, 2r, 2r)
104
+ I_2r = torch.eye(2 * self.r, device=X.device, dtype=X.dtype).repeat(self.num_rotations, 1, 1)
105
+ I_minus_YX = I_2r - Y_T_X
106
+
107
+ if self.r == 1:
108
+ I_minus_YX_inv = inverse_2x2(I_minus_YX)
109
+ I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
110
+ else:
111
+ I_minus_YX_inv_Y = torch.linalg.solve(I_minus_YX.to(torch.float32), Y.to(torch.float32)) # Shape: (num_rotations, 2r, dim)
112
+ I_minus_YX_inv_Y = I_minus_YX_inv_Y.to(X.dtype)
113
+
114
+ # I_minus_YX_float = I_minus_YX.float()
115
+ # I_minus_YX_inv = torch.linalg.inv(I_minus_YX_float) # Shape: (num_rotations, 2r, 2r)
116
+ # I_minus_YX_inv = I_minus_YX_inv.to(X.dtype)
117
+
118
+
119
+ # I_minus_YX_inv_Y = torch.einsum("nRr,nrd->nRd", I_minus_YX_inv, Y) # Shape: (num_rotations, 2r, dim)
120
+ second_term = torch.einsum("nrd,nrD->ndD", X, I_minus_YX_inv_Y) # Shape: (num_rotations, dim, dim)
121
+ second_term = second_term.sum(dim=0)
122
+ total_delta_weight = 2 * second_term
123
+ return total_delta_weight
124
+
125
+
126
+ class RotationLayer(BaseTunerLayer):
127
+ """
128
+ Adapter-like wrapper that attaches Rotation modules to a base linear layer.
129
+ """
130
+
131
+ adapter_layer_names: tuple[str, ...] = ("rotation",)
132
+ other_param_names: tuple[str, ...] = ("r", "T", "num_rotations", "scaling")
133
+
134
+ def __init__(self, base_layer: nn.Module, **kwargs):
135
+ # Let BaseTunerLayer do its init (it usually subclasses nn.Module)
136
+ super().__init__()
137
+ # store base layer and adapter containers
138
+ self.base_layer = base_layer
139
+ self.rotation = nn.ModuleDict() # mapping adapter_name -> Rotation module
140
+ self.scaling={} # default scaling per adapter
141
+ self._adapter_config = {} # store r, T, num_rotations per adapter
142
+
143
+ # flags (exposed in a simple way)
144
+ self._disable_adapters = False
145
+ self.merged_adapters: list[str] = []
146
+ self._cast_input_dtype_enabled = True
147
+ self.kwargs = kwargs
148
+
149
+ if isinstance(base_layer, nn.Linear):
150
+ self.in_features = base_layer.in_features
151
+ self.out_features = base_layer.out_features
152
+ else:
153
+ raise NotImplementedError("RotationLayer only supports nn.Linear base layers for now.")
154
+
155
+ @property
156
+ def _available_adapters(self) -> set[str]:
157
+ return set(self.rotation.keys())
158
+
159
+ @property
160
+ def disable_adapters(self) -> bool:
161
+ return self._disable_adapters
162
+
163
+ @property
164
+ def merged(self) -> bool:
165
+ return bool(self.merged_adapters)
166
+
167
+ @property
168
+ def active_adapters(self) -> list[str]:
169
+ # If some external mechanism sets active adapters, prefer it; else use all added adapters.
170
+ return getattr(self, "_active_adapters", list(self.rotation.keys()))
171
+
172
+ def get_base_layer(self) -> nn.Module:
173
+ return self.base_layer
174
+
175
+ def _cast_input_dtype(self, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
176
+ if not self._cast_input_dtype_enabled:
177
+ return x
178
+ return x.to(dtype)
179
+
180
+ def update_layer(
181
+ self,
182
+ adapter_name: str,
183
+ r: int,
184
+ T: float,
185
+ num_rotations: int,
186
+ **kwargs,
187
+ ):
188
+ """
189
+ Add / update a rotation adapter for this layer.
190
+ """
191
+
192
+ if r <= 0:
193
+ raise ValueError(f"r must be positive, got {r}")
194
+ if num_rotations <= 0:
195
+ raise ValueError(f"num_rotations must be positive, got {num_rotations}")
196
+
197
+ rot = Rotation(r=r, dim=self.in_features, T=T, num_rotations=num_rotations)
198
+ self.rotation[adapter_name] = rot
199
+ self.scaling[adapter_name] = 1.0
200
+ self._adapter_config[adapter_name] = {"r": r, "T": T, "num_rotations": num_rotations}
201
+
202
+ # (optional) helper to set currently active adapters externally
203
+ def set_active_adapters(self, adapters: Optional[list[str]]):
204
+ if adapters is None:
205
+ if hasattr(self, "_active_adapters"):
206
+ delattr(self, "_active_adapters")
207
+ else:
208
+ self._active_adapters = adapters
209
+
210
+
211
+ class Linear(nn.Module, RotationLayer):
212
+ """
213
+ A linear layer with an integrated rotation layer for parameter-efficient fine-tuning.
214
+ """
215
+
216
+ def __init__(self,
217
+ base_layer: nn.Linear,
218
+ adapter_name: str,
219
+ r: int,
220
+ T: float,
221
+ num_rotations: int,
222
+ **kwargs):
223
+
224
+ super().__init__()
225
+ RotationLayer.__init__(self, base_layer=base_layer, **kwargs)
226
+
227
+ self._active_adapter = adapter_name
228
+
229
+ self.update_layer(
230
+ adapter_name=adapter_name,
231
+ r=r,
232
+ T=T,
233
+ num_rotations=num_rotations,
234
+ **kwargs,
235
+ )
236
+
237
+ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None):
238
+ """
239
+ Merge the adapter effect into the base layer weights:
240
+ W_merged = W @ R
241
+ where R = I + delta (delta returned by get_delta_weight()).
242
+ """
243
+ adapter_names = check_adapters_to_merge(self, adapter_names)
244
+
245
+ if not adapter_names:
246
+ return
247
+
248
+ base_layer = self.get_base_layer()
249
+ orig_dtype = base_layer.weight.dtype
250
+ # base_layer.weight shape: (out_features, in_features)
251
+ W = base_layer.weight.data # (out, in)
252
+
253
+ for active_adapter in adapter_names:
254
+ if active_adapter not in self._available_adapters:
255
+ continue
256
+ delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
257
+ R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R # (in, in)
258
+ # merged W = W @ R
259
+ merged_W = W.to(R.dtype) @ R
260
+ if safe_merge and not torch.isfinite(merged_W).all():
261
+ raise ValueError("Merging resulted in non-finite weights. Aborting merge.")
262
+
263
+ base_layer.weight.data = merged_W.contiguous().to(orig_dtype)
+ W = base_layer.weight.data # keep W in sync so merging several adapters composes (matches the LIFO unmerge below)
264
+ # mark merged (so unmerge can restore by inverse)
265
+ self.merged_adapters.append(active_adapter)
266
+
267
+
268
+ def unmerge(self):
269
+ """
270
+ Reverse merges in LIFO order (pop merged adapters and invert R).
271
+ """
272
+ base_layer = self.get_base_layer()
273
+ orig_dtype = base_layer.weight.dtype
274
+
275
+ while self.merged_adapters:
276
+ active_adapter = self.merged_adapters.pop()
277
+ if active_adapter not in self._available_adapters:
278
+ continue
279
+ delta_R = self.rotation[active_adapter].get_delta_weight() # (in, in)
280
+ R = torch.eye(delta_R.size(0), device=delta_R.device, dtype=delta_R.dtype) + delta_R
281
+ R_inv = torch.linalg.inv(R)
282
+ merged_W = base_layer.weight.data.to(R.dtype)
283
+ unmerged_W = merged_W @ R_inv
284
+ base_layer.weight.data = unmerged_W.contiguous().to(orig_dtype)
285
+
286
+
287
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
288
+ x_dtype = x.dtype
289
+ base_layer = self.get_base_layer()
290
+
291
+ if self.disable_adapters:
292
+ # if merged, unmerge to ensure base_layer produces original behavior
293
+ if self.merged:
294
+ self.unmerge()
295
+ return base_layer(x, *args, **kwargs).to(x_dtype)
296
+
297
+ if self.merged:
298
+ # if merged into base layer, just forward
299
+ return base_layer(x, *args, **kwargs).to(x_dtype)
300
+
301
+ # otherwise apply active adapters (transform inputs) then call base layer
302
+ for active_adapter in self.active_adapters:
303
+ if active_adapter not in self.rotation:
304
+ continue
305
+ rotation = self.rotation[active_adapter]
306
+ x = self._cast_input_dtype(x, rotation.U.dtype)
307
+ x = rotation(x)
308
+
309
+ return base_layer(x, *args, **kwargs).to(x_dtype)
310
+
311
+ def __repr__(self):
312
+ return f"rotation.{super().__repr__()}"
313
+
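Before the dedicated test file below, here is a minimal sketch of registering a second adapter on the same wrapped layer and restricting which adapters the forward pass applies (assumptions: the `Rotation` forward accepts `(batch, seq, in_features)` inputs, as the tests also assume, and both adapters keep the default scaling of 1.0):

```python
import torch
import torch.nn as nn
from omini.rotation.layer import Linear

base = nn.Linear(128, 128)
layer = Linear(base_layer=base, adapter_name="taskA", r=2, T=1.0, num_rotations=2)

# A second adapter can be added to the same base layer ...
layer.update_layer(adapter_name="taskB", r=4, T=1.0, num_rotations=2)

# ... and the adapters used in forward() can be restricted explicitly.
layer.set_active_adapters(["taskB"])

x = torch.randn(2, 8, 128)
with torch.no_grad():
    y = layer(x)
print(y.shape, layer.active_adapters)   # torch.Size([2, 8, 128]) ['taskB']
```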
omini/rotation/layer_test.py ADDED
@@ -0,0 +1,296 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from omini.rotation.layer import Linear, Rotation
4
+
5
+ def test_rotation_merge():
6
+ """
7
+ Test that merging rotation adapter produces the same output as the unmerged version.
8
+ """
9
+ print("="*60)
10
+ print("Testing Rotation Layer Merge")
11
+ print("="*60)
12
+
13
+ # Set random seed for reproducibility
14
+ torch.manual_seed(42)
15
+
16
+ # Configuration
17
+ in_features = 512
18
+ out_features = 1024
19
+ r = 4
20
+ num_rotations = 4
21
+ T = 1.0
22
+ batch_size = 8
23
+ seq_len = 16
24
+
25
+ # Create base linear layer
26
+ base_layer = nn.Linear(in_features, out_features, bias=True)
27
+
28
+ # Create rotation layer
29
+ rotation_layer = Linear(
30
+ base_layer=base_layer,
31
+ adapter_name="default",
32
+ r=r,
33
+ T=T,
34
+ num_rotations=num_rotations
35
+ )
36
+
37
+ # Create random input
38
+ x = torch.randn(batch_size, seq_len, in_features)
39
+
40
+ # Test 1: Forward pass before merge
41
+ print("\n" + "-"*60)
42
+ print("Test 1: Computing output BEFORE merge")
43
+ print("-"*60)
44
+ rotation_layer.eval()
45
+ with torch.no_grad():
46
+ output_before = rotation_layer(x)
47
+
48
+ print(f"Output shape: {output_before.shape}")
49
+ print(f"Output mean: {output_before.mean().item():.6f}")
50
+ print(f"Output std: {output_before.std().item():.6f}")
51
+ print(f"Output min: {output_before.min().item():.6f}")
52
+ print(f"Output max: {output_before.max().item():.6f}")
53
+
54
+ # Save original weight for verification
55
+ original_weight = base_layer.weight.data.clone()
56
+
57
+ # Test 2: Merge adapter
58
+ print("\n" + "-"*60)
59
+ print("Test 2: Merging adapter")
60
+ print("-"*60)
61
+ rotation_layer.merge(safe_merge=True, adapter_names=["default"])
62
+ print(f"✓ Adapter merged successfully")
63
+ print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
64
+
65
+ # Check that weights have changed
66
+ weight_diff = (base_layer.weight.data - original_weight).abs().max().item()
67
+ print(f"Max weight change: {weight_diff:.6e}")
68
+
69
+ # Test 3: Forward pass after merge
70
+ print("\n" + "-"*60)
71
+ print("Test 3: Computing output AFTER merge")
72
+ print("-"*60)
73
+ with torch.no_grad():
74
+ output_after = rotation_layer(x)
75
+
76
+ print(f"Output shape: {output_after.shape}")
77
+ print(f"Output mean: {output_after.mean().item():.6f}")
78
+ print(f"Output std: {output_after.std().item():.6f}")
79
+ print(f"Output min: {output_after.min().item():.6f}")
80
+ print(f"Output max: {output_after.max().item():.6f}")
81
+
82
+ # Test 4: Compare outputs
83
+ print("\n" + "-"*60)
84
+ print("Test 4: Comparing outputs")
85
+ print("-"*60)
86
+
87
+ # Compute differences
88
+ abs_diff = (output_after - output_before).abs()
89
+ rel_diff = abs_diff / (output_before.abs() + 1e-8)
90
+
91
+ max_abs_diff = abs_diff.max().item()
92
+ mean_abs_diff = abs_diff.mean().item()
93
+ max_rel_diff = rel_diff.max().item()
94
+ mean_rel_diff = rel_diff.mean().item()
95
+
96
+ print(f"Max absolute difference: {max_abs_diff:.6e}")
97
+ print(f"Mean absolute difference: {mean_abs_diff:.6e}")
98
+ print(f"Max relative difference: {max_rel_diff:.6e}")
99
+ print(f"Mean relative difference: {mean_rel_diff:.6e}")
100
+
101
+ # Check if outputs are close
102
+ atol = 1e-4 # Absolute tolerance
103
+ rtol = 1e-3 # Relative tolerance
104
+
105
+ are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
106
+
107
+ if are_close:
108
+ print(f"\n✅ PASS: Outputs are identical (within atol={atol}, rtol={rtol})")
109
+ else:
110
+ print(f"\n❌ FAIL: Outputs differ significantly")
111
+ print(f" Expected: atol < {atol}, rtol < {rtol}")
112
+ print(f" Got: max_abs_diff = {max_abs_diff:.6e}, max_rel_diff = {max_rel_diff:.6e}")
113
+
114
+ # Test 5: Unmerge and verify
115
+ print("\n" + "-"*60)
116
+ print("Test 5: Testing unmerge")
117
+ print("-"*60)
118
+ rotation_layer.unmerge()
119
+ print(f"✓ Adapter unmerged")
120
+ print(f"✓ Merged adapters: {rotation_layer.merged_adapters}")
121
+
122
+ with torch.no_grad():
123
+ output_unmerged = rotation_layer(x)
124
+
125
+ unmerge_diff = (output_unmerged - output_before).abs().max().item()
126
+ print(f"Max difference after unmerge: {unmerge_diff:.6e}")
127
+
128
+ unmerge_close = torch.allclose(output_before, output_unmerged, atol=atol, rtol=rtol)
129
+ if unmerge_close:
130
+ print(f"✅ PASS: Unmerge restored original behavior")
131
+ else:
132
+ print(f"❌ FAIL: Unmerge did not restore original behavior")
133
+
134
+ # Test 6: Verify weight restoration
135
+ weight_restored_diff = (base_layer.weight.data - original_weight).abs().max().item()
136
+ print(f"Max weight difference after unmerge: {weight_restored_diff:.6e}")
137
+
138
+ weight_restored = torch.allclose(base_layer.weight.data, original_weight, atol=1e-5)
139
+ if weight_restored:
140
+ print(f"✅ PASS: Original weights restored")
141
+ else:
142
+ print(f"❌ FAIL: Original weights not fully restored")
143
+
144
+ print("\n" + "="*60)
145
+ print("Test Summary")
146
+ print("="*60)
147
+ return are_close and unmerge_close and weight_restored
148
+
149
+
150
+ def test_multiple_merges():
151
+ """
152
+ Test merging and unmerging multiple times.
153
+ """
154
+ print("\n" + "="*60)
155
+ print("Testing Multiple Merge/Unmerge Cycles")
156
+ print("="*60)
157
+
158
+ torch.manual_seed(42)
159
+
160
+ in_features = 256
161
+ out_features = 512
162
+ r = 4
163
+ num_rotations = 4
164
+
165
+ base_layer = nn.Linear(in_features, out_features, bias=True)
166
+ rotation_layer = Linear(
167
+ base_layer=base_layer,
168
+ adapter_name="default",
169
+ r=r,
170
+ T=1.0,
171
+ num_rotations=num_rotations
172
+ )
173
+
174
+ x = torch.randn(4, 8, in_features)
175
+ rotation_layer.eval()
176
+
177
+ # Get original output
178
+ with torch.no_grad():
179
+ original_output = rotation_layer(x)
180
+
181
+ # Test multiple cycles
182
+ all_passed = True
183
+ for cycle in range(3):
184
+ print(f"\nCycle {cycle + 1}:")
185
+
186
+ # Merge
187
+ rotation_layer.merge(safe_merge=True)
188
+ with torch.no_grad():
189
+ merged_output = rotation_layer(x)
190
+
191
+ merge_close = torch.allclose(original_output, merged_output, atol=1e-4, rtol=1e-3)
192
+ print(f" Merge: {'✅ PASS' if merge_close else '❌ FAIL'}")
193
+
194
+ # Unmerge
195
+ rotation_layer.unmerge()
196
+ with torch.no_grad():
197
+ unmerged_output = rotation_layer(x)
198
+
199
+ unmerge_close = torch.allclose(original_output, unmerged_output, atol=1e-4, rtol=1e-3)
200
+ print(f" Unmerge: {'✅ PASS' if unmerge_close else '❌ FAIL'}")
201
+
202
+ all_passed = all_passed and merge_close and unmerge_close
203
+
204
+ return all_passed
205
+
206
+
207
+ def test_with_different_dtypes():
208
+ """
209
+ Test merging with different data types.
210
+ """
211
+ print("\n" + "="*60)
212
+ print("Testing Different Data Types")
213
+ print("="*60)
214
+
215
+ torch.manual_seed(42)
216
+
217
+ dtypes = [torch.float32, torch.float16, torch.bfloat16]
218
+ all_passed = True
219
+
220
+ for dtype in dtypes:
221
+ print(f"\nTesting with dtype: {dtype}")
222
+
223
+ in_features = 256
224
+ out_features = 512
225
+ r = 4
226
+ num_rotations = 4
227
+
228
+ base_layer = nn.Linear(in_features, out_features, bias=True)
229
+ base_layer = base_layer.to(dtype)
230
+
231
+ rotation_layer = Linear(
232
+ base_layer=base_layer,
233
+ adapter_name="default",
234
+ r=r,
235
+ T=1.0,
236
+ num_rotations=num_rotations
237
+ )
238
+ rotation_layer = rotation_layer.to(dtype)
239
+
240
+ x = torch.randn(4, 8, in_features, dtype=dtype)
241
+ rotation_layer.eval()
242
+
243
+ with torch.no_grad():
244
+ output_before = rotation_layer(x)
245
+ rotation_layer.merge(safe_merge=True)
246
+ output_after = rotation_layer(x)
247
+
248
+ # Adjust tolerances based on dtype
249
+ if dtype == torch.float32:
250
+ atol, rtol = 1e-5, 1e-4
251
+ elif dtype == torch.float16:
252
+ atol, rtol = 1e-2, 1e-2
253
+ else: # bfloat16
254
+ atol, rtol = 1e-2, 1e-2
255
+
256
+ are_close = torch.allclose(output_before, output_after, atol=atol, rtol=rtol)
257
+
258
+ if are_close:
259
+ print(f" ✅ PASS")
260
+ else:
261
+ max_diff = (output_after - output_before).abs().max().item()
262
+ print(f" ❌ FAIL (max diff: {max_diff:.6e})")
263
+
264
+ all_passed = all_passed and are_close
265
+
266
+ return all_passed
267
+
268
+
269
+ if __name__ == "__main__":
270
+ print("\n" + "="*60)
271
+ print("ROTATION LAYER MERGE TEST SUITE")
272
+ print("="*60)
273
+
274
+ results = {}
275
+
276
+ # Run all tests
277
+ results["basic_merge"] = test_rotation_merge()
278
+ results["multiple_cycles"] = test_multiple_merges()
279
+ results["different_dtypes"] = test_with_different_dtypes()
280
+
281
+ # Print summary
282
+ print("\n" + "="*60)
283
+ print("FINAL SUMMARY")
284
+ print("="*60)
285
+
286
+ for test_name, passed in results.items():
287
+ status = "✅ PASS" if passed else "❌ FAIL"
288
+ print(f"{test_name}: {status}")
289
+
290
+ all_passed = all(results.values())
291
+ print("\n" + "="*60)
292
+ if all_passed:
293
+ print("🎉 ALL TESTS PASSED!")
294
+ else:
295
+ print("⚠️ SOME TESTS FAILED")
296
+ print("="*60)
omini/rotation/model.py ADDED
@@ -0,0 +1,390 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from enum import Enum
6
+ from dataclasses import asdict
7
+ from tqdm import tqdm
8
+
9
+
10
+ from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists, onload_layer
11
+
12
+ from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules
13
+
14
+ from .layer import RotationLayer, Linear
15
+
16
+ TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
17
+
18
+ class RotationTuner(BaseTuner):
19
+
20
+ prefix: str = "rotation_"
21
+ tuner_layer_class = RotationLayer
22
+ target_module_mapping = TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING
23
+
24
+
25
+ @staticmethod
26
+ def _check_target_module_exists(rotation_config, key: str) -> bool:
27
+ return check_target_module_exists(rotation_config, key)
28
+
29
+ def _create_and_replace(
30
+ self,
31
+ rotation_config,
32
+ adapter_name: str,
33
+ target: nn.Module,
34
+ target_name: str,
35
+ parent: nn.Module,
36
+ current_key: str,
37
+ **optional_kwargs,
38
+ ) -> None:
39
+ """
40
+ Create and replace a target module with a rotation-augmented version.
41
+
42
+ This method is called when an existing module is already a RotationLayer
43
+ and needs to have a new adapter added to it.
44
+
45
+ Args:
46
+ rotation_config: Configuration for the rotation adapter
47
+ adapter_name: Name of the adapter to add
48
+ target: The target module to augment
49
+ target_name: Name of the target module
50
+ parent: Parent module containing the target
51
+ current_key: Full key path to the current module
52
+ **optional_kwargs: Additional optional arguments
53
+
54
+ Raises:
55
+ ValueError: If current_key is not provided
56
+ """
57
+
58
+ if current_key is None:
59
+ raise ValueError("current_key must be provided to create Rotation layer")
60
+
61
+ # Check if target is already a RotationLayer
62
+ if isinstance(target, RotationLayer):
63
+ target.update_layer(
64
+ adapter_name=adapter_name,
65
+ r=rotation_config.r,
66
+ T=rotation_config.T,
67
+ num_rotations=rotation_config.num_rotations,
68
+ )
69
+ else:
70
+ # Create new rotation layer
71
+ new_module = self._create_new_module(
72
+ rotation_config=rotation_config,
73
+ adapter_name=adapter_name,
74
+ target=target,
75
+ **optional_kwargs,
76
+ )
77
+ if new_module is not None:
78
+ self._replace_module(parent, target_name, new_module, target)
79
+
80
+ def _replace_module(self, parent, child_name, new_module, child):
81
+
82
+ setattr(parent, child_name, new_module)
83
+
84
+ # child layer wraps the original module, unpack it
85
+ if hasattr(child, "base_layer"):
86
+ child = child.base_layer
87
+
88
+ meta = torch.device("meta")
89
+ # dispatch to correct device
90
+ for name, module in new_module.named_modules():
91
+ if (self.prefix in name) or ("ranknum" in name):
92
+ if hasattr(child, "qweight"):
93
+ weight = child.qweight
94
+ elif hasattr(child, "W_q"):
95
+ weight = child.W_q
96
+ elif hasattr(child, "weight"):
97
+ weight = child.weight
98
+ elif getattr(child, "in_proj_weight", None) is not None: # MHA
99
+ weight = child.in_proj_weight
100
+ else:
101
+ weight = next(child.parameters())
102
+ if not any(p.device == meta for p in module.parameters()):
103
+ module.to(weight.device)
104
+
105
+ def _mark_only_adapters_as_trainable(self, model):
106
+
107
+ # First, freeze all parameters
108
+ for n, p in model.named_parameters():
109
+ if self.prefix not in n:
110
+ p.requires_grad = False
111
+ else:
112
+ p.requires_grad = True
113
+
114
+ # Handle bias parameters based on config
115
+ for active_adapter in self.active_adapters:
116
+ bias_config = self.peft_config[active_adapter].bias
117
+
118
+ if bias_config == "none":
119
+ continue
120
+ elif bias_config == "all":
121
+ # Enable all bias parameters
122
+ for n, p in model.named_parameters():
123
+ if "bias" in n:
124
+ p.requires_grad = True
125
+ elif bias_config == "rotation_only":
126
+ # Enable only bias in rotation layers
127
+ for name, m in model.named_modules():
128
+ if isinstance(m, RotationLayer):
129
+ if hasattr(m, "bias") and m.bias is not None:
130
+ m.bias.requires_grad = True
131
+ else:
132
+ raise NotImplementedError(
133
+ f"Requested bias configuration '{bias_config}' is not implemented. "
134
+ f"Supported values: 'none', 'all', 'rotation_only'"
135
+ )
136
+
137
+ @staticmethod
138
+ def _create_new_module(
139
+ rotation_config,
140
+ adapter_name: str,
141
+ target: nn.Module,
142
+ **kwargs,
143
+ ) -> Optional[nn.Module]:
144
+ """
145
+ Create a new rotation-augmented module.
146
+
147
+ Args:
148
+ rotation_config: Configuration for the rotation adapter
149
+ adapter_name: Name of the adapter
150
+ target: Base module to augment
151
+ **kwargs: Additional arguments
152
+
153
+ Returns:
154
+ New RotationLayer module wrapping the target, or None if unsupported
155
+ """
156
+ if isinstance(target, nn.Linear):
157
+ return Linear(
158
+ base_layer=target,
159
+ adapter_name=adapter_name,
160
+ r=rotation_config.r,
161
+ T=rotation_config.T,
162
+ num_rotations=rotation_config.num_rotations,
163
+ **kwargs,
164
+ )
165
+ else:
166
+ # Unsupported layer type
167
+ print(
168
+ f"Rotation layer does not support {type(target).__name__} yet. "
169
+ f"Skipping this module."
170
+ )
171
+ return None
172
+
173
+
174
+ def __getattr__(self, name: str):
175
+ """Forward missing attributes to the wrapped module."""
176
+ try:
177
+ return super().__getattr__(name) # defer to nn.Module's logic
178
+ except AttributeError:
179
+ if name == "model": # see #1892: prevent infinite recursion if class is not initialized
180
+ raise
181
+ return getattr(self.model, name)
182
+
183
+ def get_peft_config_as_dict(self, inference: bool = False):
184
+ config_dict = {}
185
+ for key, value in self.peft_config.items():
186
+ config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
187
+ if inference:
188
+ config["inference_mode"] = True
189
+ config_dict[key] = config
190
+ return config_dict
191
+
192
+
193
+ def _set_adapter_layers(self, enabled=True):
194
+ for module in self.model.modules():
195
+ if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
196
+ module.enable_adapters(enabled)
197
+
198
+ def enable_adapter_layers(self) -> None:
199
+ """Enable all adapters.
200
+
201
+ Call this if you have previously disabled all adapters and want to re-enable them.
202
+ """
203
+ self._set_adapter_layers(enabled=True)
204
+
205
+ def disable_adapter_layers(self):
206
+ for active_adapter in self.active_adapters:
207
+ val = self.peft_config[active_adapter].bias
208
+ if val != "none":
209
+ msg = (
210
+ f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
211
+ "output as the base model would without adaption."
212
+ )
213
+ print(msg)
214
+ self._set_adapter_layers(enabled=False)
215
+
216
+ def set_adapter(self, adapter_name):
217
+ """Set the active adapter(s).
218
+
219
+ Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
220
+ not desired, use the following code.
221
+
222
+ ```py
223
+ >>> for name, param in model_peft.named_parameters():
224
+ ... if ...: # some check on name (ex. if 'rotation' in name)
225
+ ... param.requires_grad = False
226
+ ```
227
+
228
+ Args:
229
+ adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
230
+ """
231
+ for module in self.model.modules():
232
+ if isinstance(module, RotationLayer):
233
+ if module.merged:
234
+ print("Adapter cannot be set when the model is merged. Unmerging the model first.")
235
+ module.unmerge()
236
+ module.set_adapter(adapter_name)
237
+ self.active_adapter = adapter_name
238
+
239
+ def merge_adapter(self, adapter_names: Optional[list[str]] = None) -> None:
240
+ """
241
+ Merge adapter weights into the base model weights.
242
+
243
+ This can speed up inference by eliminating the need for runtime
244
+ rotation computations.
245
+
246
+ Args:
247
+ adapter_names: List of adapter names to merge. If None, merges all
248
+ active adapters.
249
+ """
250
+ for module in self.model.modules():
251
+ if isinstance(module, RotationLayer):
252
+ module.merge(safe_merge=False, adapter_names=adapter_names)
253
+
254
+
255
+ def unmerge_adapter(self) -> None:
256
+ """
257
+ Unmerge adapter weights from the base model weights.
258
+
259
+ This reverses the merge operation, restoring dynamic adapter behavior.
260
+ """
261
+ for module in self.model.modules():
262
+ if isinstance(module, RotationLayer):
263
+ module.unmerge()
264
+
265
+ @staticmethod
266
+ def _prepare_adapter_config(peft_config, model_config):
267
+
268
+ if peft_config.target_modules is None:
269
+ if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING:
270
+ raise ValueError("Please specify `target_modules` in `peft_config`")
271
+ peft_config.target_modules = set(
272
+ TRANSFORMERS_MODELS_TO_ROTATION_TARGET_MODULES_MAPPING[model_config["model_type"]]
273
+ )
274
+
275
+ return peft_config
276
+
277
+
278
+ def _check_new_adapter_config(self, config) -> None:
279
+ """
280
+ Check the validity of a new adapter configuration.
281
+
282
+ Args:
283
+ config: Configuration to validate
284
+
285
+ Raises:
286
+ ValueError: If configuration is invalid
287
+ """
288
+ # Validate rank
289
+ if config.r <= 0:
290
+ raise ValueError(f"r must be positive, got {config.r}")
291
+
292
+ # Validate num_rotations
293
+ if config.num_rotations <= 0:
294
+ raise ValueError(
295
+ f"num_rotations must be positive, got {config.num_rotations}"
296
+ )
297
+
298
+
299
+ # Validate bias configuration
300
+ valid_bias_configs = ["none", "all", "rotation_only"]
301
+ if hasattr(config, "bias") and config.bias not in valid_bias_configs:
302
+ raise ValueError(
303
+ f"Invalid bias configuration '{config.bias}'. "
304
+ f"Must be one of {valid_bias_configs}"
305
+ )
306
+
307
+
308
+ def _unload_and_optionally_merge(
309
+ self,
310
+ merge=True,
311
+ progressbar: bool = False,
312
+ safe_merge: bool = False,
313
+ adapter_names: Optional[list[str]] = None,
314
+ ):
315
+ if merge:
316
+ self._check_merge_allowed()
317
+
318
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
319
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
320
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
321
+ try:
322
+ parent, target, target_name = _get_submodules(self.model, key)
323
+ except AttributeError:
324
+ continue
325
+ with onload_layer(target):
326
+ if hasattr(target, "unload_and_optionally_merge_module"):
327
+ # if layers have special unloading method, like MultiheadAttention, use that
328
+ unloaded_module = target.unload_and_optionally_merge_module(
329
+ merge=merge, safe_merge=safe_merge, adapter_names=adapter_names
330
+ )
331
+ self._replace_module(parent, target_name, unloaded_module, target)
332
+ elif hasattr(target, "base_layer"):
333
+ if merge:
334
+ target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
335
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
336
+
337
+ return self.model
338
+
339
+ def delete_adapter(self, adapter_name: str) -> None:
340
+ """
341
+ Deletes an existing adapter.
342
+
343
+ Args:
344
+ adapter_name (str): Name of the adapter to be deleted.
345
+ """
346
+ if adapter_name not in list(self.peft_config.keys()):
347
+ raise ValueError(f"Adapter {adapter_name} does not exist")
348
+ del self.peft_config[adapter_name]
349
+
350
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
351
+ new_adapter = None
352
+ for key in key_list:
353
+ _, target, _ = _get_submodules(self.model, key)
354
+ if isinstance(target, RotationLayer):
355
+ target.delete_adapter(adapter_name)
356
+ if new_adapter is None:
357
+ new_adapter = target.active_adapters[:]
358
+
359
+ self.active_adapter = new_adapter or []
360
+ self._delete_auxiliary_adapter(adapter_name, new_active_adapters=new_adapter)
361
+
362
+ def merge_and_unload(
363
+ self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
364
+ ) -> torch.nn.Module:
365
+ r"""
366
+ This method merges the rotation layers into the base model. This is needed if someone wants to use the base model as
367
+ a standalone model.
368
+
369
+ Args:
370
+ progressbar (`bool`):
371
+ whether to show a progressbar indicating the unload and merge process
372
+ safe_merge (`bool`):
373
+ whether to activate the safe merging check to check if there is any potential Nan in the adapter
374
+ weights
375
+ adapter_names (`List[str]`, *optional*):
376
+ The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
377
+ to `None`.
378
+
379
+ """
380
+ return self._unload_and_optionally_merge(
381
+ progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
382
+ )
383
+
384
+ def unload(self) -> torch.nn.Module:
385
+ """
386
+ Gets back the base model by removing all the rotation modules without merging. This gives back the original base
387
+ model.
388
+ """
389
+ return self._unload_and_optionally_merge(merge=False)
390
+
omini/rotation/rotation_config.py ADDED
@@ -0,0 +1,81 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional
3
+ from peft.config import PeftConfig
4
+
5
+
6
+ @dataclass
7
+ class RotationConfig(PeftConfig):
8
+ """
9
+ Configuration class for Rotation-based Parameter-Efficient Fine-Tuning.
10
+
11
+ This configuration stores all parameters needed to apply the Rotation method
12
+ (based on Cayley transformation) to a model's linear layers.
13
+
14
+ Args:
15
+ r (`int`):
16
+ The rank parameter for the low-rank approximation in rotation matrices.
17
+ T (`float`, *optional*, defaults to 1.0):
18
+ Temperature parameter for the transformation.
19
+ num_rotations (`int`, *optional*, defaults to 4):
20
+ Number of rotation matrices to use in parallel.
21
+ target_modules (`Union[List[str], str]`):
22
+ Module names to apply rotation to (e.g., ["q_proj", "v_proj"]).
23
+ target_modules_to_skip (`Union[List[str], str]`, *optional*):
24
+ Module names to skip when applying rotation.
25
+ modules_to_save (`Union[List[str], str]`, *optional*):
26
+ Modules to save in addition to rotation parameters.
27
+ layers_to_transform (`Union[List[int], int]`, *optional*):
28
+ Layers to transform. If None, all layers matching target_modules are transformed.
29
31
+ """
32
+
33
+ peft_type: str = field(default="ROTATION", init=False)
34
+ target_modules: Optional[List[str]] = field(
35
+ default=None,
36
+ metadata={
37
+ "help": "List of module names to apply rotation to (e.g., ['q_proj', 'v_proj', 'linear'])"
38
+ },
39
+ )
40
+ target_modules_to_skip: Optional[List[str]] = field(
41
+ default=None,
42
+ metadata={"help": "List of module names to skip when applying rotation"},
43
+ )
44
+ modules_to_save: Optional[List[str]] = field(
45
+ default=None,
46
+ metadata={"help": "List of modules to save in addition to rotation parameters"},
47
+ )
48
+ r: int = field(
49
+ default=8,
50
+ metadata={"help": "Rank parameter for low-rank approximation"},
51
+ )
52
+ T: float = field(
53
+ default=1.0,
54
+ metadata={"help": "Temperature parameter for Cayley transformation"},
55
+ )
56
+ num_rotations: int = field(
57
+ default=4,
58
+ metadata={"help": "Number of rotation matrices to use in parallel"},
59
+ )
60
+
61
+ bias: str = field(
62
+ default="none",
63
+ metadata={
64
+ "help": "Bias training configuration. Options: 'none', 'all', 'rotation_only'"
65
+ }
66
+ )
67
+ layers_to_transform: Optional[List[int]] = field(
68
+ default=None,
69
+ metadata={"help": "Layers to transform. If None, all matching layers are transformed"},
70
+ )
71
+
72
+ def __post_init__(self):
73
+ self.peft_type = "ROTATION"
74
+ self.target_modules = (
75
+ set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
76
+ )
77
+ self.target_modules_to_skip = (
78
+ set(self.target_modules_to_skip)
79
+ if isinstance(self.target_modules_to_skip, list)
80
+ else self.target_modules_to_skip
81
+ )
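Putting the tuner and this configuration together, a minimal end-to-end sketch on a toy model might look as follows (untested; it assumes the installed `peft` version's `BaseTuner` accepts the `(model, config, adapter_name)` constructor used here, and the module names are purely illustrative):

```python
import torch.nn as nn
from omini.rotation.model import RotationTuner
from omini.rotation.rotation_config import RotationConfig

# Toy model whose sub-module names match target_modules.
model = nn.ModuleDict({"q_proj": nn.Linear(64, 64), "v_proj": nn.Linear(64, 64)})

config = RotationConfig(r=4, T=1.0, num_rotations=2, target_modules=["q_proj", "v_proj"])
tuner = RotationTuner(model, config, adapter_name="default")

# Inspect what the tuner left trainable after wrapping the targeted layers.
print([n for n, p in tuner.model.named_parameters() if p.requires_grad])

# Rotations can be folded into the wrapped weights for inference and restored later ...
tuner.merge_adapter()
tuner.unmerge_adapter()

# ... or merged permanently, returning the plain base model without adapter wrappers.
base_model = tuner.merge_and_unload(safe_merge=True)
```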
omini/train_flux/train_custom.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import os
3
+ import random
4
+ from torch.utils.data import DataLoader, Dataset
5
+
6
+ from PIL import Image
7
+
8
+ from datasets import load_dataset
9
+
10
+ from .trainer import OminiModel, get_config, train
11
+ from ..pipeline.flux_omini import Condition, generate
12
+
13
+
14
+ class CustomDataset(Dataset):
15
+ def __getitem__(self, idx):
16
+ # TODO: Implement the logic to load your custom dataset
17
+ raise NotImplementedError("Custom dataset loading not implemented")
18
+
19
+
20
+ @torch.no_grad()
21
+ def test_function(model, save_path, file_name):
22
+ # TODO: Implement the logic to generate a sample using the model
23
+ raise NotImplementedError("Sample generation not implemented")
24
+
25
+
26
+ def main():
27
+ # Initialize
28
+ config = get_config()
29
+ training_config = config["train"]
30
+ torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
31
+
32
+ # Initialize custom dataset
33
+ dataset = CustomDataset()
34
+
35
+ # Initialize model
36
+ trainable_model = OminiModel(
37
+ flux_pipe_id=config["flux_path"],
38
+ lora_config=training_config["lora_config"],
39
+ device=f"cuda",
40
+ dtype=getattr(torch, config["dtype"]),
41
+ optimizer_config=training_config["optimizer"],
42
+ model_config=config.get("model", {}),
43
+ gradient_checkpointing=training_config.get("gradient_checkpointing", False),
44
+ )
45
+
46
+ train(dataset, trainable_model, config, test_function)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
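The two `NotImplementedError` stubs above are the only parts a user needs to supply. A hypothetical sketch of such a dataset (the class and attribute names `FolderCaptionDataset`, `image_paths`, `prompts` are made up here; the returned `"image"`/`"description"` keys simply mirror `ImageMultiConditionDataset` in `train_multi_condition.py`, while the exact keys required depend on `OminiModel`'s training step in the trainer module, which is not part of this commit):

```python
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class FolderCaptionDataset(Dataset):
    """Hypothetical stand-in for CustomDataset: images on disk plus caption strings."""

    def __init__(self, image_paths, prompts, target_size=(512, 512)):
        self.image_paths = image_paths
        self.prompts = prompts
        self.target_size = target_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB").resize(self.target_size)
        return {
            "image": self.to_tensor(image),
            "description": self.prompts[idx],
        }
```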
omini/train_flux/train_multi_condition.py ADDED
@@ -0,0 +1,160 @@
1
+ import torch
2
+ import os
3
+ import random
4
+
5
+ from PIL import Image, ImageDraw
6
+
7
+ from datasets import load_dataset
8
+
9
+ from .trainer import OminiModel, get_config, train
10
+ from ..pipeline.flux_omini import Condition, convert_to_condition, generate
11
+ from .train_spatial_alignment import ImageConditionDataset
12
+
13
+
14
+ class ImageMultiConditionDataset(ImageConditionDataset):
15
+ def __getitem__(self, idx):
16
+ image = self.base_dataset[idx]["jpg"]
17
+ image = image.resize(self.target_size).convert("RGB")
18
+ description = self.base_dataset[idx]["json"]["prompt"]
19
+
20
+ condition_size = self.condition_size
21
+ position_scale = self.position_scale
22
+
23
+ condition_imgs, position_deltas = [], []
24
+ for c_type in self.condition_type:
25
+ condition_img, position_delta = self.__get_condition__(image, c_type)
26
+ condition_imgs.append(condition_img.convert("RGB"))
27
+ position_deltas.append(position_delta)
28
+
29
+ # Randomly drop text or image (for training)
30
+ drop_text = random.random() < self.drop_text_prob
31
+ drop_image = random.random() < self.drop_image_prob
32
+
33
+ if drop_text:
34
+ description = ""
35
+ if drop_image:
36
+ condition_imgs = [
37
+ Image.new("RGB", condition_size)
38
+ for _ in range(len(self.condition_type))
39
+ ]
40
+
41
+ return_dict = {
42
+ "image": self.to_tensor(image),
43
+ "description": description,
44
+ **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
45
+ }
46
+
47
+ for i, c_type in enumerate(self.condition_type):
48
+ return_dict[f"condition_{i}"] = self.to_tensor(condition_imgs[i])
49
+ return_dict[f"condition_type_{i}"] = self.condition_type[i]
50
+ return_dict[f"position_delta_{i}"] = position_deltas[i]
51
+ return_dict[f"position_scale_{i}"] = position_scale
52
+
53
+ return return_dict
54
+
55
+
56
+ @torch.no_grad()
57
+ def test_function(model, save_path, file_name):
58
+ condition_size = model.training_config["dataset"]["condition_size"]
59
+ target_size = model.training_config["dataset"]["target_size"]
60
+
61
+ position_delta = model.training_config["dataset"].get("position_delta", [0, 0])
62
+ position_scale = model.training_config["dataset"].get("position_scale", 1.0)
63
+
64
+ condition_type = model.training_config["condition_type"]
65
+ test_list = []
66
+
67
+ condition_list = []
68
+ for i, c_type in enumerate(condition_type):
69
+ if c_type in ["canny", "coloring", "deblurring", "depth"]:
70
+ image = Image.open("assets/vase_hq.jpg")
71
+ image = image.resize(condition_size)
72
+ condition_img = convert_to_condition(c_type, image, 5)
73
+ elif c_type == "fill":
74
+ condition_img = image.resize(condition_size).convert("RGB")
75
+ w, h = image.size
76
+ x1, x2 = sorted([random.randint(0, w), random.randint(0, w)])
77
+ y1, y2 = sorted([random.randint(0, h), random.randint(0, h)])
78
+ mask = Image.new("L", image.size, 0)
79
+ draw = ImageDraw.Draw(mask)
80
+ draw.rectangle([x1, y1, x2, y2], fill=255)
81
+ if random.random() > 0.5:
82
+ mask = Image.eval(mask, lambda a: 255 - a)
83
+ condition_img = Image.composite(
84
+ image, Image.new("RGB", image.size, (0, 0, 0)), mask
85
+ )
86
+ else:
87
+ raise NotImplementedError
88
+ condition = Condition(
89
+ condition_img,
90
+ model.adapter_names[i + 2],
91
+ position_delta,
92
+ position_scale,
93
+ )
94
+ condition_list.append(condition)
95
+ test_list.append((condition_list, "A beautiful vase on a table."))
96
+ os.makedirs(save_path, exist_ok=True)
97
+ for i, (condition, prompt) in enumerate(test_list):
98
+ generator = torch.Generator(device=model.device)
99
+ generator.manual_seed(42)
100
+
101
+ res = generate(
102
+ model.flux_pipe,
103
+ prompt=prompt,
104
+ conditions=condition_list,
105
+ height=target_size[0],
106
+ width=target_size[1],
107
+ generator=generator,
108
+ model_config=model.model_config,
109
+ kv_cache=model.model_config.get("independent_condition", False),
110
+ )
111
+ file_path = os.path.join(
112
+ save_path, f"{file_name}_{'|'.join(condition_type)}_{i}.jpg"
113
+ )
114
+ res.images[0].save(file_path)
115
+
116
+
117
+ def main():
118
+ # Initialize
119
+ config = get_config()
120
+ training_config = config["train"]
121
+ torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
122
+
123
+ # Initialize dataset
124
+ dataset = load_dataset(
125
+ "webdataset",
126
+ data_files={"train": training_config["dataset"]["urls"]},
127
+ split="train",
128
+ cache_dir="cache/t2i2m",
129
+ num_proc=32,
130
+ )
131
+ dataset = ImageMultiConditionDataset(
132
+ dataset,
133
+ condition_size=training_config["dataset"]["condition_size"],
134
+ target_size=training_config["dataset"]["target_size"],
135
+ condition_type=training_config["condition_type"],
136
+ drop_text_prob=training_config["dataset"]["drop_text_prob"],
137
+ drop_image_prob=training_config["dataset"]["drop_image_prob"],
138
+ position_scale=training_config["dataset"].get("position_scale", 1.0),
139
+ )
140
+
141
+ cond_n = len(training_config["condition_type"])
142
+
143
+ # Initialize model
144
+ trainable_model = OminiModel(
145
+ flux_pipe_id=config["flux_path"],
146
+ lora_config=training_config["lora_config"],
147
+ device=f"cuda",
148
+ dtype=getattr(torch, config["dtype"]),
149
+ optimizer_config=training_config["optimizer"],
150
+ model_config=config.get("model", {}),
151
+ gradient_checkpointing=training_config.get("gradient_checkpointing", False),
152
+ adapter_names=[None, None, *["default"] * cond_n],
153
+ # In this setting, all the conditions are using the same LoRA adapter
154
+ )
155
+
156
+ train(dataset, trainable_model, config, test_function)
157
+
158
+
159
+ if __name__ == "__main__":
160
+ main()
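For orientation, the configuration keys that `main()` and `test_function()` above actually read can be collected from the code; the sketch below lays them out as a plain Python dict (all values are illustrative placeholders, `position_delta` is optional, and `get_config()` itself lives in the trainer module, which is not part of this commit):

```python
config_sketch = {
    "flux_path": "path-or-hub-id-of-the-FLUX-pipeline",
    "dtype": "bfloat16",                        # resolved via getattr(torch, config["dtype"])
    "model": {"independent_condition": False},
    "train": {
        "condition_type": ["canny", "depth"],   # one condition slot per entry
        "lora_config": {},                      # forwarded to OminiModel (contents not shown here)
        "optimizer": {},                        # forwarded to OminiModel
        "gradient_checkpointing": False,
        "dataset": {
            "urls": "webdataset-shard-urls",
            "condition_size": (512, 512),
            "target_size": (512, 512),
            "drop_text_prob": 0.1,
            "drop_image_prob": 0.1,
            "position_scale": 1.0,
        },
    },
}
```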