lv12 committed
Commit 0c56232
1 Parent(s): 66ad261

full set multi loss ESCI triplets

1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "word_embedding_dimension": 768,
+   "pooling_mode_cls_token": false,
+   "pooling_mode_mean_tokens": true,
+   "pooling_mode_max_tokens": false,
+   "pooling_mode_mean_sqrt_len_tokens": false,
+   "pooling_mode_weightedmean_tokens": false,
+   "pooling_mode_lasttoken": false,
+   "include_prompt": true
+ }
README.md ADDED
@@ -0,0 +1,950 @@
+ ---
+ language: []
+ library_name: sentence-transformers
+ tags:
+ - sentence-transformers
+ - sentence-similarity
+ - feature-extraction
+ - generated_from_trainer
+ - dataset_size:1182198
+ - loss:CachedMultipleNegativesRankingLoss
+ - loss:AnglELoss
+ base_model: nomic-ai/nomic-embed-text-v1.5
+ datasets: []
+ metrics:
+ - cosine_accuracy
+ - dot_accuracy
+ - manhattan_accuracy
+ - euclidean_accuracy
+ - max_accuracy
+ - pearson_cosine
+ - spearman_cosine
+ - pearson_manhattan
+ - spearman_manhattan
+ - pearson_euclidean
+ - spearman_euclidean
+ - pearson_dot
+ - spearman_dot
+ - pearson_max
+ - spearman_max
+ widget:
+ - source_sentence: dog instrument toy
+   sentences:
+   - VATOS 25-in-1 Mars Rover Building Kit Outer Space Explorer Educational Construction Toy for Kids 556 Pieces Solar Powered STEM Science Building Blocks Set, VATOS, White
+   - Prefer Green 7 PCS Portion Control Containers Kit (with COMPLETE GUIDE & 21 DAY DAILY TRACKER & 21 DAY MEAL PLANNER & RECIPES PDFs),Label-Coded,Multi-Color-Coded System,Perfect Size for Lose Weight, Prefer Green, 7 PCS
+   - Coolibar UPF 50+ Men's Women's Gannett UV Gloves - Sun Protective (Medium- Light Blue), Coolibar, Light Blue
+ - source_sentence: flame decal stickers
+   sentences:
+   - Tribal Flames Splash Pair - Vinyl Decal Sticker - 12" x 5" - Blue Flames, Sticker Pimp, Blue Flames
+   - PC Gaming Headset Headphone Hook Holder Hanger Mount, Headphones Stand with Adjustable & Rotating Arm Clamp , Under Desk Design , Universal Fit , Built in Cable Clip Organizer EURPMASK, EURPMASK Choose the color of europe, Black
+   - Quick Charge 3.0 Wall Charger, 4-Pack 18W QC 3.0 USB Charger Adapter Fast Charging Block Compatible Wireless Charger Compatible with Samsung Galaxy S10 S9 S8 Plus S7 S6 Edge Note 9, LG, Kindle, Tablet, HONOT, Black
+ - source_sentence: 'search_query: softies women''s ultra soft marshmallow hooded lounger'
+   sentences:
+   - 'search_document: Red-A Placemats for Dining Table Set of 6 Heat-Resistant Wipeable Table Mats for Kitchen Table Decoration Waterproof Vinyl Placemats Easy to Clean,Black w/Brown, Red-A, Black'
+   - 'search_document: Softies Women''s Ultra Soft Marshmallow Hooded Lounger, Platinum, L/XL, Softies, Platinum'
+   - 'search_document: Ekouaer Women''s Sleepwear Robe with Pockets Plus Size Maxi Lounger Zipper Short Sleeve Bathrobe Housecoat (Black,L), Ekouaer, Black'
+ - source_sentence: 'search_query: wine glasses without stem'
+   sentences:
+   - 'search_document: STAUBER Best Bulb Changer with PowerLatch Extension Pole (Large Suction, 4 Feet), STAUBER, Large Suction'
+   - 'search_document: Hand Blown Italian Style Crystal Burgundy Wine Glasses - Lead-Free Premium Crystal Clear Glass - Set of 2 - 21 Ounce - Gift-Box for any Occasion, JBHO, Burgundy'
+   - 'search_document: MyGift Modern Copper Stemless Wine Glasses, Set of 4, MyGift, Copper'
+ - source_sentence: 'search_query: weighted blanket without glass beads'
+   sentences:
+   - 'search_document: Eigso Women Men Spike Punk Rock Black Leather Cuff Rivet Bracelet Bangle Adjustable Snap Button, Eigso, Black'
+   - 'search_document: Quility Weighted Blanket with Soft Cover - 20 lbs Full/Queen Size Heavy Blanket for Adults - Heating & Cooling, Machine Washable - (60" X 80") (Navy), Quility, Navy Cover + Grey Cotton Blanket'
+   - 'search_document: Bedsure Queen Weighted Blanket 15 Pounds - Adult Weighted Blanket 60x80 - Soft Heavy Blanket with Breathable TPE Insert No Glass Beads, Bedsure, Navy'
+ pipeline_tag: sentence-similarity
+ model-index:
+ - name: SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
+   results:
+   - task:
+       type: triplet
+       name: Triplet
+     dataset:
+       name: Unknown
+       type: unknown
+     metrics:
+     - type: cosine_accuracy
+       value: 0.7236
+       name: Cosine Accuracy
+     - type: dot_accuracy
+       value: 0.282
+       name: Dot Accuracy
+     - type: manhattan_accuracy
+       value: 0.7231
+       name: Manhattan Accuracy
+     - type: euclidean_accuracy
+       value: 0.7227
+       name: Euclidean Accuracy
+     - type: max_accuracy
+       value: 0.7236
+       name: Max Accuracy
+   - task:
+       type: semantic-similarity
+       name: Semantic Similarity
+     dataset:
+       name: Unknown
+       type: unknown
+     metrics:
+     - type: pearson_cosine
+       value: 0.4912162846043421
+       name: Pearson Cosine
+     - type: spearman_cosine
+       value: 0.4658522123059972
+       name: Spearman Cosine
+     - type: pearson_manhattan
+       value: 0.4599741171303018
+       name: Pearson Manhattan
+     - type: spearman_manhattan
+       value: 0.4428141949345816
+       name: Spearman Manhattan
+     - type: pearson_euclidean
+       value: 0.46194545823984606
+       name: Pearson Euclidean
+     - type: spearman_euclidean
+       value: 0.44478471500226807
+       name: Spearman Euclidean
+     - type: pearson_dot
+       value: 0.45451995456560107
+       name: Pearson Dot
+     - type: spearman_dot
+       value: 0.43844636325741904
+       name: Spearman Dot
+     - type: pearson_max
+       value: 0.4912162846043421
+       name: Pearson Max
+     - type: spearman_max
+       value: 0.4658522123059972
+       name: Spearman Max
+ ---
+
+ # SentenceTransformer based on nomic-ai/nomic-embed-text-v1.5
+
+ This is a [sentence-transformers](https://www.SBERT.net) model finetuned from [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) on the triplets and pairs datasets. It maps sentences and paragraphs to a 768-dimensional dense vector space and can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
+
+ ## Model Details
+
+ ### Model Description
+ - **Model Type:** Sentence Transformer
+ - **Base model:** [nomic-ai/nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) <!-- at revision b0753ae76394dd36bcfb912a46018088bca48be0 -->
+ - **Maximum Sequence Length:** 8192 tokens
+ - **Output Dimensionality:** 768 dimensions
+ - **Similarity Function:** Cosine Similarity
+ - **Training Datasets:**
+   - triplets
+   - pairs
+ <!-- - **Language:** Unknown -->
+ <!-- - **License:** Unknown -->
+
+ ### Model Sources
+
+ - **Documentation:** [Sentence Transformers Documentation](https://sbert.net)
+ - **Repository:** [Sentence Transformers on GitHub](https://github.com/UKPLab/sentence-transformers)
+ - **Hugging Face:** [Sentence Transformers on Hugging Face](https://huggingface.co/models?library=sentence-transformers)
+
+ ### Full Model Architecture
+
+ ```
+ SentenceTransformer(
+   (0): Transformer({'max_seq_length': 8192, 'do_lower_case': False}) with Transformer model: NomicBertModel
+   (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
+ )
+ ```
+
+ ## Usage
+
+ ### Direct Usage (Sentence Transformers)
+
+ First install the Sentence Transformers library:
+
+ ```bash
+ pip install -U sentence-transformers
+ ```
+
+ Then you can load this model and run inference.
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # Download from the 🤗 Hub
+ model = SentenceTransformer("lv12/esci-nomic-embed-text-v1_5_4")
+ # Run inference
+ sentences = [
+     'search_query: weighted blanket without glass beads',
+     'search_document: Bedsure Queen Weighted Blanket 15 Pounds - Adult Weighted Blanket 60x80 - Soft Heavy Blanket with Breathable TPE Insert No Glass Beads, Bedsure, Navy',
+     'search_document: Quility Weighted Blanket with Soft Cover - 20 lbs Full/Queen Size Heavy Blanket for Adults - Heating & Cooling, Machine Washable - (60" X 80") (Navy), Quility, Navy Cover + Grey Cotton Blanket',
+ ]
+ embeddings = model.encode(sentences)
+ print(embeddings.shape)
+ # [3, 768]
+
+ # Get the similarity scores for the embeddings
+ similarities = model.similarity(embeddings, embeddings)
+ print(similarities.shape)
+ # [3, 3]
+ ```
+
209
+ <!--
210
+ ### Direct Usage (Transformers)
211
+
212
+ <details><summary>Click to see the direct usage in Transformers</summary>
213
+
214
+ </details>
215
+ -->
216
+
217
+ <!--
218
+ ### Downstream Usage (Sentence Transformers)
219
+
220
+ You can finetune this model on your own dataset.
221
+
222
+ <details><summary>Click to expand</summary>
223
+
224
+ </details>
225
+ -->
226
+
227
+ <!--
228
+ ### Out-of-Scope Use
229
+
230
+ *List how the model may foreseeably be misused and address what users ought not to do with the model.*
231
+ -->
232
+
233
+ ## Evaluation
234
+
235
+ ### Metrics
236
+
237
+ #### Triplet
238
+
239
+ * Evaluated with [<code>TripletEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.TripletEvaluator)
240
+
241
+ | Metric | Value |
242
+ |:--------------------|:-----------|
243
+ | **cosine_accuracy** | **0.7236** |
244
+ | dot_accuracy | 0.282 |
245
+ | manhattan_accuracy | 0.7231 |
246
+ | euclidean_accuracy | 0.7227 |
247
+ | max_accuracy | 0.7236 |
248
+
249
+ #### Semantic Similarity
250
+
251
+ * Evaluated with [<code>EmbeddingSimilarityEvaluator</code>](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.EmbeddingSimilarityEvaluator)
252
+
253
+ | Metric | Value |
254
+ |:--------------------|:-----------|
255
+ | pearson_cosine | 0.4912 |
256
+ | **spearman_cosine** | **0.4659** |
257
+ | pearson_manhattan | 0.46 |
258
+ | spearman_manhattan | 0.4428 |
259
+ | pearson_euclidean | 0.4619 |
260
+ | spearman_euclidean | 0.4448 |
261
+ | pearson_dot | 0.4545 |
262
+ | spearman_dot | 0.4384 |
263
+ | pearson_max | 0.4912 |
264
+ | spearman_max | 0.4659 |
265
+
266
+ <!--
267
+ ## Bias, Risks and Limitations
268
+
269
+ *What are the known or foreseeable issues stemming from this model? You could also flag here known failure cases or weaknesses of the model.*
270
+ -->
271
+
272
+ <!--
273
+ ### Recommendations
274
+
275
+ *What are recommendations with respect to the foreseeable issues? For example, filtering explicit content.*
276
+ -->
277
+
278
+ ## Training Details
279
+
280
+ ### Training Datasets
281
+
282
+ #### triplets
283
+
284
+ * Dataset: triplets
285
+ * Size: 684,084 training samples
286
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
287
+ * Approximate statistics based on the first 1000 samples:
288
+ | | anchor | positive | negative |
289
+ |:--------|:---------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|
290
+ | type | string | string | string |
291
+ | details | <ul><li>min: 7 tokens</li><li>mean: 11.1 tokens</li><li>max: 22 tokens</li></ul> | <ul><li>min: 17 tokens</li><li>mean: 42.75 tokens</li><li>max: 95 tokens</li></ul> | <ul><li>min: 15 tokens</li><li>mean: 43.8 tokens</li><li>max: 127 tokens</li></ul> |
292
+ * Samples:
293
+ | anchor | positive | negative |
294
+ |:----------------------------------------------------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
295
+ | <code>search_query: tarps heavy duty waterproof 8x10</code> | <code>search_document: 8' x 10' Super Heavy Duty 16 Mil Brown Poly Tarp Cover - Thick Waterproof, UV Resistant, Rip and Tear Proof Tarpaulin with Grommets and Reinforced Edges - by Xpose Safety, Xpose Safety, Brown</code> | <code>search_document: Grillkid 6'X8' 4.5 Mil Thick General Purpose Waterproof Poly Tarp, Grillkid, All Purpose</code> |
296
+ | <code>search_query: wireless keyboard without number pad</code> | <code>search_document: Macally 2.4G Small Wireless Keyboard - Ergonomic & Comfortable Computer Keyboard - Compact Keyboard for Laptop or Windows PC Desktop, Tablet, Smart TV - Plug & Play Mini Keyboard with 12 Hot Keys, Macally, Black</code> | <code>search_document: Wireless Keyboard - iClever GKA22S Rechargeable Keyboard with Number Pad, Full-Size Stainless Steel Ultra Slim Keyboard, 2.4G Stable Connection Wireless Keyboard for iMac, Mackbook, PC, Laptop, iClever, Silver</code> |
297
+ | <code>search_query: geometry earrings</code> | <code>search_document: Simple Stud Earrings for Women, Geometric Minimalist Stud Earring Set Tiny Circle Triangle Square Bar Stud Earrings Mini Cartilage Tragus Earrings, choice of all, B:Circle Sliver</code> | <code>search_document: BONALUNA Bohemian Wood And Marble Effect Oblong Shaped Drop Statement Earrings (VIVID TURQUOISE), BONALUNA, VIVID TURQUOISE</code> |
298
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
299
+ ```json
300
+ {
301
+ "scale": 20.0,
302
+ "similarity_fct": "cos_sim"
303
+ }
304
+ ```
305
+
306
+ #### pairs
307
+
308
+ * Dataset: pairs
309
+ * Size: 498,114 training samples
310
+ * Columns: <code>sentence1</code>, <code>sentence2</code>, and <code>score</code>
311
+ * Approximate statistics based on the first 1000 samples:
312
+ | | sentence1 | sentence2 | score |
313
+ |:--------|:---------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------|:---------------------------------------------------------------|
314
+ | type | string | string | float |
315
+ | details | <ul><li>min: 3 tokens</li><li>mean: 6.73 tokens</li><li>max: 33 tokens</li></ul> | <ul><li>min: 10 tokens</li><li>mean: 40.14 tokens</li><li>max: 98 tokens</li></ul> | <ul><li>min: 0.0</li><li>mean: 0.81</li><li>max: 1.0</li></ul> |
316
+ * Samples:
317
+ | sentence1 | sentence2 | score |
318
+ |:-------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------|
319
+ | <code>I would choose a medium weight waterproof fabric, hip length jacket or longer, long sleeves, zip front, with a hood and deep pockets with zips</code> | <code>ZSHOW Men's Winter Hooded Packable Down Jacket(Blue, XX-Large), ZSHOW, Blue</code> | <code>1.0</code> |
320
+ | <code>sequin dance costume girls</code> | <code>Yeahdor Big Girls' Lyrical Latin Ballet Dance Costumes Dresses Halter Sequins Irregular Tutu Skirted Leotard Dancewear Pink 12-14, Yeahdor, Pink</code> | <code>1.0</code> |
321
+ | <code>paint easel bulk</code> | <code>Artecho Artist Easel Display Easel Stand, 2 Pack Metal Tripod Stand Easel for Painting, Hold Canvas from 21" to 66", Floor and Tabletop Displaying, Painting with Portable Bag, Artecho, Black</code> | <code>1.0</code> |
322
+ * Loss: [<code>AnglELoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#angleloss) with these parameters:
323
+ ```json
324
+ {
325
+ "scale": 20.0,
326
+ "similarity_fct": "pairwise_angle_sim"
327
+ }
328
+ ```
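+
+ The two datasets are trained jointly with per-dataset losses. A minimal sketch of wiring this up with the sentence-transformers v3 trainer (the dataset variables are assumptions; `triplet_loss` is from the sketch above):
+
+ ```python
+ from sentence_transformers import SentenceTransformerTrainer, losses
+
+ # AnglE optimizes pairwise similarity in angle space; scale=20.0 and
+ # pairwise_angle_sim match the parameters above.
+ pair_loss = losses.AnglELoss(model, scale=20.0)
+
+ trainer = SentenceTransformerTrainer(
+     model=model,
+     train_dataset={"triplets": triplets_dataset, "pairs": pairs_dataset},
+     eval_dataset={"triplets": triplets_eval, "pairs": pairs_eval},
+     loss={"triplets": triplet_loss, "pairs": pair_loss},
+ )
+ trainer.train()
+ ```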
+
+ ### Evaluation Datasets
+
+ #### triplets
+
+ * Dataset: triplets
+ * Size: 10,000 evaluation samples
+ * Columns: <code>anchor</code>, <code>positive</code>, and <code>negative</code>
+ * Approximate statistics based on the first 1000 samples:
+   |         | anchor | positive | negative |
+   |:--------|:-------|:---------|:---------|
+   | type    | string | string   | string   |
+   | details | <ul><li>min: 7 tokens</li><li>mean: 11.13 tokens</li><li>max: 23 tokens</li></ul> | <ul><li>min: 15 tokens</li><li>mean: 43.11 tokens</li><li>max: 107 tokens</li></ul> | <ul><li>min: 15 tokens</li><li>mean: 43.56 tokens</li><li>max: 99 tokens</li></ul> |
+ * Samples:
+   | anchor | positive | negative |
+   |:-------|:---------|:---------|
+   | <code>search_query: hitch fifth wheel</code> | <code>search_document: ENIXWILL 5th Wheel Trailer Hitch Lifting Device Bracket Pin Fit for Hitch Companion and Patriot Series Hitch, ENIXWILL, Black</code> | <code>search_document: ECOTRIC Fifth 5th Wheel Trailer Hitch Mount Rails and Installation Kits for Full-Size Trucks, ECOTRIC, black</code> |
+   | <code>search_query: dek pro</code> | <code>search_document: Cubiker Computer Desk 47 inch Home Office Writing Study Desk, Modern Simple Style Laptop Table with Storage Bag, Brown, Cubiker, Brown</code> | <code>search_document: FEZIBO Dual Motor L Shaped Electric Standing Desk, 48 Inches Stand Up Corner Desk, Home Office Sit Stand Desk with Rustic Brown Top and Black Frame, FEZIBO, Rustic Brown</code> |
+   | <code>search_query: 1 year baby mouth without teeth cleaner</code> | <code>search_document: Baby Toothbrush,Infant Toothbrush,Baby Tongue Cleaner,Infant Toothbrush,Baby Tongue Cleaner Newborn,Toothbrush Tongue Cleaner Dental Care for 0-36 Month Baby,36 Pcs + Free 4 Pcs, Babycolor, Blue</code> | <code>search_document: Slotic Baby Toothbrush for 0-2 Years, Safe and Sturdy, Toddler Oral Care Teether Brush, Extra Soft Bristle for Baby Teeth and Infant Gums, Dentist Recommended (4-Pack), Slotic, 4 Pack</code> |
+ * Loss: [<code>CachedMultipleNegativesRankingLoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#cachedmultiplenegativesrankingloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "cos_sim"
+   }
+   ```
+
+ #### pairs
+
+ * Dataset: pairs
+ * Size: 10,000 evaluation samples
+ * Columns: <code>sentence1</code>, <code>sentence2</code>, and <code>score</code>
+ * Approximate statistics based on the first 1000 samples:
+   |         | sentence1 | sentence2 | score |
+   |:--------|:----------|:----------|:------|
+   | type    | string    | string    | float |
+   | details | <ul><li>min: 3 tokens</li><li>mean: 6.8 tokens</li><li>max: 34 tokens</li></ul> | <ul><li>min: 9 tokens</li><li>mean: 39.7 tokens</li><li>max: 101 tokens</li></ul> | <ul><li>min: 0.0</li><li>mean: 0.77</li><li>max: 1.0</li></ul> |
+ * Samples:
+   | sentence1 | sentence2 | score |
+   |:----------|:----------|:------|
+   | <code>outdoor ceiling fans without light</code> | <code>44" Plaza Industrial Indoor Outdoor Ceiling Fan with Remote Control Oil Rubbed Bronze Damp Rated for Patio Porch - Casa Vieja, Casa Vieja, No Light Kit - Bronze</code> | <code>1.0</code> |
+   | <code>bathroom cabinet</code> | <code>Homfa Bathroom Floor Cabinet Free Standing with Single Door Multifunctional Bathroom Storage Organizer Toiletries(Ivory White), Homfa, White</code> | <code>1.0</code> |
+   | <code>fitbit charge 3</code> | <code>TreasureMax Compatible with Fitbit Charge 2 Bands for Women/Men,Silicone Fadeless Pattern Printed Replacement Floral Bands for Fitbit Charge 2 HR Wristbands, TreasureMax, Paw 2</code> | <code>0.4</code> |
+ * Loss: [<code>AnglELoss</code>](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#angleloss) with these parameters:
+   ```json
+   {
+       "scale": 20.0,
+       "similarity_fct": "pairwise_angle_sim"
+   }
+   ```
+
+ ### Training Hyperparameters
+ #### Non-Default Hyperparameters
+
+ - `per_device_train_batch_size`: 16
+ - `per_device_eval_batch_size`: 4
+ - `gradient_accumulation_steps`: 2
+ - `learning_rate`: 1e-06
+ - `lr_scheduler_type`: cosine_with_restarts
+ - `lr_scheduler_kwargs`: {'num_cycles': 1}
+ - `warmup_ratio`: 0.01
+ - `dataloader_drop_last`: True
+ - `dataloader_num_workers`: 4
+ - `dataloader_prefetch_factor`: 4
+ - `load_best_model_at_end`: True
+ - `gradient_checkpointing`: True
+ - `batch_sampler`: no_duplicates
+
397
+ #### All Hyperparameters
398
+ <details><summary>Click to expand</summary>
399
+
400
+ - `overwrite_output_dir`: False
401
+ - `do_predict`: False
402
+ - `prediction_loss_only`: True
403
+ - `per_device_train_batch_size`: 16
404
+ - `per_device_eval_batch_size`: 4
405
+ - `per_gpu_train_batch_size`: None
406
+ - `per_gpu_eval_batch_size`: None
407
+ - `gradient_accumulation_steps`: 2
408
+ - `eval_accumulation_steps`: None
409
+ - `learning_rate`: 1e-06
410
+ - `weight_decay`: 0.0
411
+ - `adam_beta1`: 0.9
412
+ - `adam_beta2`: 0.999
413
+ - `adam_epsilon`: 1e-08
414
+ - `max_grad_norm`: 1.0
415
+ - `num_train_epochs`: 3
416
+ - `max_steps`: -1
417
+ - `lr_scheduler_type`: cosine_with_restarts
418
+ - `lr_scheduler_kwargs`: {'num_cycles': 1}
419
+ - `warmup_ratio`: 0.01
420
+ - `warmup_steps`: 0
421
+ - `log_level`: passive
422
+ - `log_level_replica`: warning
423
+ - `log_on_each_node`: True
424
+ - `logging_nan_inf_filter`: True
425
+ - `save_safetensors`: True
426
+ - `save_on_each_node`: False
427
+ - `save_only_model`: False
428
+ - `no_cuda`: False
429
+ - `use_cpu`: False
430
+ - `use_mps_device`: False
431
+ - `seed`: 42
432
+ - `data_seed`: None
433
+ - `jit_mode_eval`: False
434
+ - `use_ipex`: False
435
+ - `bf16`: False
436
+ - `fp16`: False
437
+ - `fp16_opt_level`: O1
438
+ - `half_precision_backend`: auto
439
+ - `bf16_full_eval`: False
440
+ - `fp16_full_eval`: False
441
+ - `tf32`: None
442
+ - `local_rank`: 0
443
+ - `ddp_backend`: None
444
+ - `tpu_num_cores`: None
445
+ - `tpu_metrics_debug`: False
446
+ - `debug`: []
447
+ - `dataloader_drop_last`: True
448
+ - `dataloader_num_workers`: 4
449
+ - `dataloader_prefetch_factor`: 4
450
+ - `past_index`: -1
451
+ - `disable_tqdm`: False
452
+ - `remove_unused_columns`: True
453
+ - `label_names`: None
454
+ - `load_best_model_at_end`: True
455
+ - `ignore_data_skip`: False
456
+ - `fsdp`: []
457
+ - `fsdp_min_num_params`: 0
458
+ - `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
459
+ - `fsdp_transformer_layer_cls_to_wrap`: None
460
+ - `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True}
461
+ - `deepspeed`: None
462
+ - `label_smoothing_factor`: 0.0
463
+ - `optim`: adamw_torch
464
+ - `optim_args`: None
465
+ - `adafactor`: False
466
+ - `group_by_length`: False
467
+ - `length_column_name`: length
468
+ - `ddp_find_unused_parameters`: None
469
+ - `ddp_bucket_cap_mb`: None
470
+ - `ddp_broadcast_buffers`: False
471
+ - `dataloader_pin_memory`: True
472
+ - `dataloader_persistent_workers`: False
473
+ - `skip_memory_metrics`: True
474
+ - `use_legacy_prediction_loop`: False
475
+ - `push_to_hub`: False
476
+ - `resume_from_checkpoint`: None
477
+ - `hub_model_id`: None
478
+ - `hub_strategy`: every_save
479
+ - `hub_private_repo`: False
480
+ - `hub_always_push`: False
481
+ - `gradient_checkpointing`: True
482
+ - `gradient_checkpointing_kwargs`: None
483
+ - `include_inputs_for_metrics`: False
484
+ - `fp16_backend`: auto
485
+ - `push_to_hub_model_id`: None
486
+ - `push_to_hub_organization`: None
487
+ - `mp_parameters`:
488
+ - `auto_find_batch_size`: False
489
+ - `full_determinism`: False
490
+ - `torchdynamo`: None
491
+ - `ray_scope`: last
492
+ - `ddp_timeout`: 1800
493
+ - `torch_compile`: False
494
+ - `torch_compile_backend`: None
495
+ - `torch_compile_mode`: None
496
+ - `dispatch_batches`: None
497
+ - `split_batches`: None
498
+ - `include_tokens_per_second`: False
499
+ - `include_num_input_tokens_seen`: False
500
+ - `neftune_noise_alpha`: None
501
+ - `batch_sampler`: no_duplicates
502
+ - `multi_dataset_batch_sampler`: proportional
503
+
504
+ </details>
505
+
506
+ ### Training Logs
507
+ <details><summary>Click to expand</summary>
508
+
509
+ | Epoch | Step | Training Loss | pairs loss | triplets loss | cosine_accuracy | spearman_cosine |
510
+ |:------:|:-----:|:-------------:|:----------:|:-------------:|:---------------:|:---------------:|
511
+ | 0.0027 | 100 | 2.4909 | - | - | - | - |
512
+ | 0.0054 | 200 | 2.6666 | - | - | - | - |
513
+ | 0.0081 | 300 | 2.76 | - | - | - | - |
514
+ | 0.0108 | 400 | 2.6945 | - | - | - | - |
515
+ | 0.0135 | 500 | 2.9113 | - | - | - | - |
516
+ | 0.0162 | 600 | 2.3476 | - | - | - | - |
517
+ | 0.0189 | 700 | 2.2818 | - | - | - | - |
518
+ | 0.0217 | 800 | 2.4241 | - | - | - | - |
519
+ | 0.0244 | 900 | 2.5126 | - | - | - | - |
520
+ | 0.0271 | 1000 | 2.4106 | 4.7376 | 0.8087 | 0.6993 | 0.3844 |
521
+ | 0.0298 | 1100 | 2.2369 | - | - | - | - |
522
+ | 0.0325 | 1200 | 2.0614 | - | - | - | - |
523
+ | 0.0352 | 1300 | 2.2178 | - | - | - | - |
524
+ | 0.0379 | 1400 | 1.974 | - | - | - | - |
525
+ | 0.0406 | 1500 | 1.9364 | - | - | - | - |
526
+ | 0.0433 | 1600 | 2.0906 | - | - | - | - |
527
+ | 0.0460 | 1700 | 1.8783 | - | - | - | - |
528
+ | 0.0487 | 1800 | 2.1149 | - | - | - | - |
529
+ | 0.0514 | 1900 | 1.7162 | - | - | - | - |
530
+ | 0.0541 | 2000 | 1.6761 | 3.8862 | 0.7490 | 0.7097 | 0.4082 |
531
+ | 0.0568 | 2100 | 2.1701 | - | - | - | - |
532
+ | 0.0596 | 2200 | 2.1306 | - | - | - | - |
533
+ | 0.0623 | 2300 | 1.6543 | - | - | - | - |
534
+ | 0.0650 | 2400 | 1.8157 | - | - | - | - |
535
+ | 0.0677 | 2500 | 1.7779 | - | - | - | - |
536
+ | 0.0704 | 2600 | 1.9434 | - | - | - | - |
537
+ | 0.0731 | 2700 | 1.7776 | - | - | - | - |
538
+ | 0.0758 | 2800 | 1.8197 | - | - | - | - |
539
+ | 0.0785 | 2900 | 1.9886 | - | - | - | - |
540
+ | 0.0812 | 3000 | 2.0699 | 3.8031 | 0.7298 | 0.7147 | 0.4282 |
541
+ | 0.0839 | 3100 | 1.9496 | - | - | - | - |
542
+ | 0.0866 | 3200 | 1.8349 | - | - | - | - |
543
+ | 0.0893 | 3300 | 2.111 | - | - | - | - |
544
+ | 0.0920 | 3400 | 1.9956 | - | - | - | - |
545
+ | 0.0947 | 3500 | 2.0379 | - | - | - | - |
546
+ | 0.0974 | 3600 | 1.8975 | - | - | - | - |
547
+ | 0.1002 | 3700 | 1.8552 | - | - | - | - |
548
+ | 0.1029 | 3800 | 1.9566 | - | - | - | - |
549
+ | 0.1056 | 3900 | 2.011 | - | - | - | - |
550
+ | 0.1083 | 4000 | 2.1263 | 3.7799 | 0.7221 | 0.7176 | 0.4393 |
551
+ | 0.1110 | 4100 | 1.8217 | - | - | - | - |
552
+ | 0.1137 | 4200 | 1.8638 | - | - | - | - |
553
+ | 0.1164 | 4300 | 1.7699 | - | - | - | - |
554
+ | 0.1191 | 4400 | 1.8248 | - | - | - | - |
555
+ | 0.1218 | 4500 | 1.835 | - | - | - | - |
556
+ | 0.1245 | 4600 | 1.9294 | - | - | - | - |
557
+ | 0.1272 | 4700 | 1.9817 | - | - | - | - |
558
+ | 0.1299 | 4800 | 1.877 | - | - | - | - |
559
+ | 0.1326 | 4900 | 1.5824 | - | - | - | - |
560
+ | 0.1353 | 5000 | 1.7429 | 3.7728 | 0.7163 | 0.7196 | 0.4496 |
561
+ | 0.1380 | 5100 | 1.8552 | - | - | - | - |
562
+ | 0.1408 | 5200 | 1.6888 | - | - | - | - |
563
+ | 0.1435 | 5300 | 1.9409 | - | - | - | - |
564
+ | 0.1462 | 5400 | 1.9389 | - | - | - | - |
565
+ | 0.1489 | 5500 | 1.82 | - | - | - | - |
566
+ | 0.1516 | 5600 | 1.9763 | - | - | - | - |
567
+ | 0.1543 | 5700 | 1.8122 | - | - | - | - |
568
+ | 0.1570 | 5800 | 1.7204 | - | - | - | - |
569
+ | 0.1597 | 5900 | 1.6901 | - | - | - | - |
570
+ | 0.1624 | 6000 | 1.7785 | 3.7514 | 0.7124 | 0.7195 | 0.4516 |
571
+ | 0.1651 | 6100 | 1.8559 | - | - | - | - |
572
+ | 0.1678 | 6200 | 1.7646 | - | - | - | - |
573
+ | 0.1705 | 6300 | 1.9068 | - | - | - | - |
574
+ | 0.1732 | 6400 | 1.8848 | - | - | - | - |
575
+ | 0.1759 | 6500 | 1.9384 | - | - | - | - |
576
+ | 0.1787 | 6600 | 1.7692 | - | - | - | - |
577
+ | 0.1814 | 6700 | 1.7093 | - | - | - | - |
578
+ | 0.1841 | 6800 | 1.8759 | - | - | - | - |
579
+ | 0.1868 | 6900 | 1.7319 | - | - | - | - |
580
+ | 0.1895 | 7000 | 1.9428 | 3.7487 | 0.7076 | 0.7256 | 0.4496 |
581
+ | 0.1922 | 7100 | 1.5733 | - | - | - | - |
582
+ | 0.1949 | 7200 | 1.8487 | - | - | - | - |
583
+ | 0.1976 | 7300 | 1.8361 | - | - | - | - |
584
+ | 0.2003 | 7400 | 1.9911 | - | - | - | - |
585
+ | 0.2030 | 7500 | 1.784 | - | - | - | - |
586
+ | 0.2057 | 7600 | 1.8518 | - | - | - | - |
587
+ | 0.2084 | 7700 | 1.6232 | - | - | - | - |
588
+ | 0.2111 | 7800 | 1.6239 | - | - | - | - |
589
+ | 0.2138 | 7900 | 1.7589 | - | - | - | - |
590
+ | 0.2165 | 8000 | 1.8644 | 3.7387 | 0.7040 | 0.7241 | 0.4552 |
591
+ | 0.2193 | 8100 | 1.7903 | - | - | - | - |
592
+ | 0.2220 | 8200 | 1.7197 | - | - | - | - |
593
+ | 0.2247 | 8300 | 1.9099 | - | - | - | - |
594
+ | 0.2274 | 8400 | 1.6778 | - | - | - | - |
595
+ | 0.2301 | 8500 | 1.9249 | - | - | - | - |
596
+ | 0.2328 | 8600 | 1.8483 | - | - | - | - |
597
+ | 0.2355 | 8700 | 1.6849 | - | - | - | - |
598
+ | 0.2382 | 8800 | 1.8647 | - | - | - | - |
599
+ | 0.2409 | 8900 | 1.8826 | - | - | - | - |
600
+ | 0.2436 | 9000 | 1.7632 | 3.7403 | 0.7033 | 0.7225 | 0.4545 |
601
+ | 0.2463 | 9100 | 1.8142 | - | - | - | - |
602
+ | 0.2490 | 9200 | 1.7374 | - | - | - | - |
603
+ | 0.2517 | 9300 | 1.8646 | - | - | - | - |
604
+ | 0.2544 | 9400 | 1.7623 | - | - | - | - |
605
+ | 0.2571 | 9500 | 1.7802 | - | - | - | - |
606
+ | 0.2599 | 9600 | 1.843 | - | - | - | - |
607
+ | 0.2626 | 9700 | 1.9797 | - | - | - | - |
608
+ | 0.2653 | 9800 | 1.7748 | - | - | - | - |
609
+ | 0.2680 | 9900 | 1.7031 | - | - | - | - |
610
+ | 0.2707 | 10000 | 1.5536 | 3.7613 | 0.7016 | 0.7259 | 0.4548 |
611
+ | 0.2734 | 10100 | 1.7663 | - | - | - | - |
612
+ | 0.2761 | 10200 | 1.8218 | - | - | - | - |
613
+ | 0.2788 | 10300 | 1.6327 | - | - | - | - |
614
+ | 0.2815 | 10400 | 1.8802 | - | - | - | - |
615
+ | 0.2842 | 10500 | 1.6294 | - | - | - | - |
616
+ | 0.2869 | 10600 | 1.9001 | - | - | - | - |
617
+ | 0.2896 | 10700 | 1.7873 | - | - | - | - |
618
+ | 0.2923 | 10800 | 1.8121 | - | - | - | - |
619
+ | 0.2950 | 10900 | 2.0197 | - | - | - | - |
620
+ | 0.2978 | 11000 | 1.7006 | 3.7559 | 0.7004 | 0.727 | 0.4613 |
621
+ | 0.3005 | 11100 | 1.6404 | - | - | - | - |
622
+ | 0.3032 | 11200 | 1.9422 | - | - | - | - |
623
+ | 0.3059 | 11300 | 1.5917 | - | - | - | - |
624
+ | 0.3086 | 11400 | 1.7236 | - | - | - | - |
625
+ | 0.3113 | 11500 | 1.8977 | - | - | - | - |
626
+ | 0.3140 | 11600 | 1.7686 | - | - | - | - |
627
+ | 0.3167 | 11700 | 1.4493 | - | - | - | - |
628
+ | 0.3194 | 11800 | 1.7447 | - | - | - | - |
629
+ | 0.3221 | 11900 | 1.9412 | - | - | - | - |
630
+ | 0.3248 | 12000 | 1.8 | 3.7308 | 0.6997 | 0.7241 | 0.4618 |
631
+ | 0.3275 | 12100 | 1.8855 | - | - | - | - |
632
+ | 0.3302 | 12200 | 1.5133 | - | - | - | - |
633
+ | 0.3329 | 12300 | 1.7893 | - | - | - | - |
634
+ | 0.3356 | 12400 | 1.7861 | - | - | - | - |
635
+ | 0.3384 | 12500 | 1.7733 | - | - | - | - |
636
+ | 0.3411 | 12600 | 1.5877 | - | - | - | - |
637
+ | 0.3438 | 12700 | 2.03 | - | - | - | - |
638
+ | 0.3465 | 12800 | 1.7071 | - | - | - | - |
639
+ | 0.3492 | 12900 | 1.7848 | - | - | - | - |
640
+ | 0.3519 | 13000 | 1.7508 | 3.7326 | 0.7006 | 0.7247 | 0.4583 |
641
+ | 0.3546 | 13100 | 1.7667 | - | - | - | - |
642
+ | 0.3573 | 13200 | 1.6415 | - | - | - | - |
643
+ | 0.3600 | 13300 | 1.7501 | - | - | - | - |
644
+ | 0.3627 | 13400 | 1.8451 | - | - | - | - |
645
+ | 0.3654 | 13500 | 1.7146 | - | - | - | - |
646
+ | 0.3681 | 13600 | 1.6837 | - | - | - | - |
647
+ | 0.3708 | 13700 | 1.92 | - | - | - | - |
648
+ | 0.3735 | 13800 | 1.6925 | - | - | - | - |
649
+ | 0.3763 | 13900 | 1.7799 | - | - | - | - |
650
+ | 0.3790 | 14000 | 1.527 | 3.7260 | 0.6989 | 0.727 | 0.4510 |
651
+ | 0.3817 | 14100 | 1.7222 | - | - | - | - |
652
+ | 0.3844 | 14200 | 1.8278 | - | - | - | - |
653
+ | 0.3871 | 14300 | 1.7669 | - | - | - | - |
654
+ | 0.3898 | 14400 | 1.5856 | - | - | - | - |
655
+ | 0.3925 | 14500 | 1.8234 | - | - | - | - |
656
+ | 0.3952 | 14600 | 1.7151 | - | - | - | - |
657
+ | 0.3979 | 14700 | 1.6432 | - | - | - | - |
658
+ | 0.4006 | 14800 | 1.9005 | - | - | - | - |
659
+ | 0.4033 | 14900 | 1.6946 | - | - | - | - |
660
+ | 0.4060 | 15000 | 1.5543 | 3.7222 | 0.6969 | 0.7275 | 0.4634 |
661
+ | 0.4087 | 15100 | 1.6736 | - | - | - | - |
662
+ | 0.4114 | 15200 | 1.8898 | - | - | - | - |
663
+ | 0.4141 | 15300 | 1.7224 | - | - | - | - |
664
+ | 0.4169 | 15400 | 1.7909 | - | - | - | - |
665
+ | 0.4196 | 15500 | 1.6555 | - | - | - | - |
666
+ | 0.4223 | 15600 | 1.523 | - | - | - | - |
667
+ | 0.4250 | 15700 | 1.7539 | - | - | - | - |
668
+ | 0.4277 | 15800 | 1.5763 | - | - | - | - |
669
+ | 0.4304 | 15900 | 1.7247 | - | - | - | - |
670
+ | 0.4331 | 16000 | 1.876 | 3.7105 | 0.6977 | 0.7263 | 0.4636 |
671
+ | 0.4358 | 16100 | 1.772 | - | - | - | - |
672
+ | 0.4385 | 16200 | 1.6774 | - | - | - | - |
673
+ | 0.4412 | 16300 | 1.7602 | - | - | - | - |
674
+ | 0.4439 | 16400 | 1.705 | - | - | - | - |
675
+ | 0.4466 | 16500 | 1.7893 | - | - | - | - |
676
+ | 0.4493 | 16600 | 1.653 | - | - | - | - |
677
+ | 0.4520 | 16700 | 1.8326 | - | - | - | - |
678
+ | 0.4547 | 16800 | 1.5326 | - | - | - | - |
679
+ | 0.4575 | 16900 | 1.8251 | - | - | - | - |
680
+ | 0.4602 | 17000 | 1.766 | 3.7193 | 0.6973 | 0.7257 | 0.4655 |
681
+ | 0.4629 | 17100 | 1.7162 | - | - | - | - |
682
+ | 0.4656 | 17200 | 1.6969 | - | - | - | - |
683
+ | 0.4683 | 17300 | 1.5172 | - | - | - | - |
684
+ | 0.4710 | 17400 | 1.7102 | - | - | - | - |
685
+ | 0.4737 | 17500 | 1.8369 | - | - | - | - |
686
+ | 0.4764 | 17600 | 1.8069 | - | - | - | - |
687
+ | 0.4791 | 17700 | 1.6299 | - | - | - | - |
688
+ | 0.4818 | 17800 | 1.8474 | - | - | - | - |
689
+ | 0.4845 | 17900 | 1.5864 | - | - | - | - |
690
+ | 0.4872 | 18000 | 1.7455 | 3.7087 | 0.6986 | 0.7249 | 0.4626 |
691
+ | 0.4899 | 18100 | 1.8263 | - | - | - | - |
692
+ | 0.4926 | 18200 | 1.8548 | - | - | - | - |
693
+ | 0.4954 | 18300 | 1.6442 | - | - | - | - |
694
+ | 0.4981 | 18400 | 1.7467 | - | - | - | - |
695
+ | 0.5008 | 18500 | 1.6174 | - | - | - | - |
696
+ | 0.5035 | 18600 | 1.4465 | - | - | - | - |
697
+ | 0.5062 | 18700 | 1.8866 | - | - | - | - |
698
+ | 0.5089 | 18800 | 1.72 | - | - | - | - |
699
+ | 0.5116 | 18900 | 1.7466 | - | - | - | - |
700
+ | 0.5143 | 19000 | 1.9124 | 3.7247 | 0.6979 | 0.725 | 0.4602 |
701
+ | 0.5170 | 19100 | 1.5687 | - | - | - | - |
702
+ | 0.5197 | 19200 | 1.6391 | - | - | - | - |
703
+ | 0.5224 | 19300 | 1.8248 | - | - | - | - |
704
+ | 0.5251 | 19400 | 1.6231 | - | - | - | - |
705
+ | 0.5278 | 19500 | 1.6152 | - | - | - | - |
706
+ | 0.5305 | 19600 | 1.639 | - | - | - | - |
707
+ | 0.5332 | 19700 | 1.6098 | - | - | - | - |
708
+ | 0.5360 | 19800 | 1.6619 | - | - | - | - |
709
+ | 0.5387 | 19900 | 1.6997 | - | - | - | - |
710
+ | 0.5414 | 20000 | 1.718 | 3.7259 | 0.6989 | 0.7264 | 0.4660 |
711
+ | 0.5441 | 20100 | 1.634 | - | - | - | - |
712
+ | 0.5468 | 20200 | 1.7865 | - | - | - | - |
713
+ | 0.5495 | 20300 | 1.8573 | - | - | - | - |
714
+ | 0.5522 | 20400 | 1.5575 | - | - | - | - |
715
+ | 0.5549 | 20500 | 1.6594 | - | - | - | - |
716
+ | 0.5576 | 20600 | 1.8793 | - | - | - | - |
717
+ | 0.5603 | 20700 | 1.7643 | - | - | - | - |
718
+ | 0.5630 | 20800 | 1.538 | - | - | - | - |
719
+ | 0.5657 | 20900 | 1.8634 | - | - | - | - |
720
+ | 0.5684 | 21000 | 1.916 | 3.7223 | 0.6982 | 0.7258 | 0.4650 |
721
+ | 0.5711 | 21100 | 1.5947 | - | - | - | - |
722
+ | 0.5738 | 21200 | 1.5321 | - | - | - | - |
723
+ | 0.5766 | 21300 | 1.7004 | - | - | - | - |
724
+ | 0.5793 | 21400 | 1.6947 | - | - | - | - |
725
+ | 0.5820 | 21500 | 1.5228 | - | - | - | - |
726
+ | 0.5847 | 21600 | 1.7152 | - | - | - | - |
727
+ | 0.5874 | 21700 | 1.6883 | - | - | - | - |
728
+ | 0.5901 | 21800 | 1.6779 | - | - | - | - |
729
+ | 0.5928 | 21900 | 1.7323 | - | - | - | - |
730
+ | 0.5955 | 22000 | 1.9633 | 3.7266 | 0.6996 | 0.7288 | 0.4635 |
731
+ | 0.5982 | 22100 | 1.7498 | - | - | - | - |
732
+ | 0.6009 | 22200 | 1.7513 | - | - | - | - |
733
+ | 0.6036 | 22300 | 1.7078 | - | - | - | - |
734
+ | 0.6063 | 22400 | 1.6438 | - | - | - | - |
735
+ | 0.6090 | 22500 | 1.6743 | - | - | - | - |
736
+ | 0.6117 | 22600 | 1.6701 | - | - | - | - |
737
+ | 0.6145 | 22700 | 1.7871 | - | - | - | - |
738
+ | 0.6172 | 22800 | 1.6247 | - | - | - | - |
739
+ | 0.6199 | 22900 | 1.7817 | - | - | - | - |
740
+ | 0.6226 | 23000 | 1.6606 | 3.7321 | 0.6993 | 0.7286 | 0.4614 |
741
+ | 0.6253 | 23100 | 1.8987 | - | - | - | - |
742
+ | 0.6280 | 23200 | 1.6494 | - | - | - | - |
743
+ | 0.6307 | 23300 | 1.6776 | - | - | - | - |
744
+ | 0.6334 | 23400 | 1.75 | - | - | - | - |
745
+ | 0.6361 | 23500 | 1.5131 | - | - | - | - |
746
+ | 0.6388 | 23600 | 1.7946 | - | - | - | - |
747
+ | 0.6415 | 23700 | 1.665 | - | - | - | - |
748
+ | 0.6442 | 23800 | 1.6681 | - | - | - | - |
749
+ | 0.6469 | 23900 | 1.8255 | - | - | - | - |
750
+ | 0.6496 | 24000 | 1.6759 | 3.7227 | 0.7017 | 0.7281 | 0.4625 |
751
+ | 0.6523 | 24100 | 1.554 | - | - | - | - |
752
+ | 0.6551 | 24200 | 1.6435 | - | - | - | - |
753
+ | 0.6578 | 24300 | 1.8224 | - | - | - | - |
754
+ | 0.6605 | 24400 | 1.6186 | - | - | - | - |
755
+ | 0.6632 | 24500 | 1.7156 | - | - | - | - |
756
+ | 0.6659 | 24600 | 1.5247 | - | - | - | - |
757
+ | 0.6686 | 24700 | 1.6264 | - | - | - | - |
758
+ | 0.6713 | 24800 | 1.7673 | - | - | - | - |
759
+ | 0.6740 | 24900 | 1.8072 | - | - | - | - |
760
+ | 0.6767 | 25000 | 1.765 | 3.7407 | 0.7026 | 0.7283 | 0.4589 |
761
+ | 0.6794 | 25100 | 1.6422 | - | - | - | - |
762
+ | 0.6821 | 25200 | 1.7846 | - | - | - | - |
763
+ | 0.6848 | 25300 | 1.7366 | - | - | - | - |
764
+ | 0.6875 | 25400 | 1.7839 | - | - | - | - |
765
+ | 0.6902 | 25500 | 1.441 | - | - | - | - |
766
+ | 0.6930 | 25600 | 1.5533 | - | - | - | - |
767
+ | 0.6957 | 25700 | 1.6922 | - | - | - | - |
768
+ | 0.6984 | 25800 | 1.5544 | - | - | - | - |
769
+ | 0.7011 | 25900 | 1.456 | - | - | - | - |
770
+ | 0.7038 | 26000 | 1.6494 | 3.7274 | 0.7059 | 0.7268 | 0.4661 |
771
+ | 0.7065 | 26100 | 1.6963 | - | - | - | - |
772
+ | 0.7092 | 26200 | 1.7892 | - | - | - | - |
773
+ | 0.7119 | 26300 | 1.6669 | - | - | - | - |
774
+ | 0.7146 | 26400 | 1.6758 | - | - | - | - |
775
+ | 0.7173 | 26500 | 1.6322 | - | - | - | - |
776
+ | 0.7200 | 26600 | 1.5416 | - | - | - | - |
777
+ | 0.7227 | 26700 | 1.681 | - | - | - | - |
778
+ | 0.7254 | 26800 | 1.5159 | - | - | - | - |
779
+ | 0.7281 | 26900 | 1.715 | - | - | - | - |
780
+ | 0.7308 | 27000 | 1.6164 | 3.7456 | 0.7061 | 0.7257 | 0.4570 |
781
+ | 0.7336 | 27100 | 1.6784 | - | - | - | - |
782
+ | 0.7363 | 27200 | 1.5886 | - | - | - | - |
783
+ | 0.7390 | 27300 | 1.6736 | - | - | - | - |
784
+ | 0.7417 | 27400 | 1.5659 | - | - | - | - |
785
+ | 0.7444 | 27500 | 1.6552 | - | - | - | - |
786
+ | 0.7471 | 27600 | 1.5672 | - | - | - | - |
787
+ | 0.7498 | 27700 | 1.5873 | - | - | - | - |
788
+ | 0.7525 | 27800 | 1.6746 | - | - | - | - |
789
+ | 0.7552 | 27900 | 1.7503 | - | - | - | - |
790
+ | 0.7579 | 28000 | 1.7287 | 3.7390 | 0.7076 | 0.7244 | 0.4636 |
791
+ | 0.7606 | 28100 | 1.6216 | - | - | - | - |
792
+ | 0.7633 | 28200 | 1.6101 | - | - | - | - |
793
+ | 0.7660 | 28300 | 1.5651 | - | - | - | - |
794
+ | 0.7687 | 28400 | 1.5659 | - | - | - | - |
795
+ | 0.7714 | 28500 | 1.5248 | - | - | - | - |
796
+ | 0.7742 | 28600 | 1.3725 | - | - | - | - |
797
+ | 0.7769 | 28700 | 1.7881 | - | - | - | - |
798
+ | 0.7796 | 28800 | 1.739 | - | - | - | - |
799
+ | 0.7823 | 28900 | 1.6464 | - | - | - | - |
800
+ | 0.7850 | 29000 | 1.6841 | 3.7212 | 0.7073 | 0.7247 | 0.4626 |
801
+ | 0.7877 | 29100 | 1.6254 | - | - | - | - |
802
+ | 0.7904 | 29200 | 1.6728 | - | - | - | - |
803
+ | 0.7931 | 29300 | 1.5605 | - | - | - | - |
804
+ | 0.7958 | 29400 | 1.687 | - | - | - | - |
805
+ | 0.7985 | 29500 | 1.7799 | - | - | - | - |
806
+ | 0.8012 | 29600 | 1.6792 | - | - | - | - |
807
+ | 0.8039 | 29700 | 1.5241 | - | - | - | - |
808
+ | 0.8066 | 29800 | 1.6341 | - | - | - | - |
809
+ | 0.8093 | 29900 | 1.5571 | - | - | - | - |
810
+ | 0.8121 | 30000 | 1.5228 | 3.7397 | 0.7105 | 0.7234 | 0.4682 |
811
+ | 0.8148 | 30100 | 1.5988 | - | - | - | - |
812
+ | 0.8175 | 30200 | 1.4222 | - | - | - | - |
813
+ | 0.8202 | 30300 | 1.4629 | - | - | - | - |
814
+ | 0.8229 | 30400 | 1.6381 | - | - | - | - |
815
+ | 0.8256 | 30500 | 1.4585 | - | - | - | - |
816
+ | 0.8283 | 30600 | 1.6774 | - | - | - | - |
817
+ | 0.8310 | 30700 | 1.811 | - | - | - | - |
818
+ | 0.8337 | 30800 | 1.5872 | - | - | - | - |
819
+ | 0.8364 | 30900 | 1.4762 | - | - | - | - |
820
+ | 0.8391 | 31000 | 1.7079 | 3.7256 | 0.7128 | 0.7215 | 0.4645 |
821
+ | 0.8418 | 31100 | 1.4948 | - | - | - | - |
822
+ | 0.8445 | 31200 | 1.4556 | - | - | - | - |
823
+ | 0.8472 | 31300 | 1.5191 | - | - | - | - |
824
+ | 0.8499 | 31400 | 1.598 | - | - | - | - |
825
+ | 0.8527 | 31500 | 1.6586 | - | - | - | - |
826
+ | 0.8554 | 31600 | 1.6893 | - | - | - | - |
827
+ | 0.8581 | 31700 | 1.7764 | - | - | - | - |
828
+ | 0.8608 | 31800 | 1.3632 | - | - | - | - |
829
+ | 0.8635 | 31900 | 1.6681 | - | - | - | - |
830
+ | 0.8662 | 32000 | 1.6232 | 3.7358 | 0.7161 | 0.7232 | 0.4651 |
831
+ | 0.8689 | 32100 | 1.4556 | - | - | - | - |
832
+ | 0.8716 | 32200 | 1.8698 | - | - | - | - |
833
+ | 0.8743 | 32300 | 1.7566 | - | - | - | - |
834
+ | 0.8770 | 32400 | 1.6082 | - | - | - | - |
835
+ | 0.8797 | 32500 | 1.6465 | - | - | - | - |
836
+ | 0.8824 | 32600 | 1.5018 | - | - | - | - |
837
+ | 0.8851 | 32700 | 1.8482 | - | - | - | - |
838
+ | 0.8878 | 32800 | 1.5147 | - | - | - | - |
839
+ | 0.8905 | 32900 | 1.699 | - | - | - | - |
840
+ | 0.8933 | 33000 | 1.5738 | 3.7323 | 0.7176 | 0.7246 | 0.4657 |
841
+ | 0.8960 | 33100 | 1.635 | - | - | - | - |
842
+ | 0.8987 | 33200 | 1.7069 | - | - | - | - |
843
+ | 0.9014 | 33300 | 1.6272 | - | - | - | - |
844
+ | 0.9041 | 33400 | 1.7648 | - | - | - | - |
845
+ | 0.9068 | 33500 | 1.6683 | - | - | - | - |
846
+ | 0.9095 | 33600 | 1.4867 | - | - | - | - |
847
+ | 0.9122 | 33700 | 1.6677 | - | - | - | - |
848
+ | 0.9149 | 33800 | 1.5527 | - | - | - | - |
849
+ | 0.9176 | 33900 | 1.6804 | - | - | - | - |
850
+ | 0.9203 | 34000 | 1.425 | 3.7477 | 0.7172 | 0.7231 | 0.4596 |
851
+ | 0.9230 | 34100 | 1.771 | - | - | - | - |
852
+ | 0.9257 | 34200 | 1.5767 | - | - | - | - |
853
+ | 0.9284 | 34300 | 1.5424 | - | - | - | - |
854
+ | 0.9312 | 34400 | 1.5985 | - | - | - | - |
855
+ | 0.9339 | 34500 | 1.6763 | - | - | - | - |
856
+ | 0.9366 | 34600 | 1.6608 | - | - | - | - |
857
+ | 0.9393 | 34700 | 1.7736 | - | - | - | - |
858
+ | 0.9420 | 34800 | 1.8955 | - | - | - | - |
859
+ | 0.9447 | 34900 | 1.5688 | - | - | - | - |
860
+ | 0.9474 | 35000 | 1.6123 | 3.7410 | 0.7196 | 0.7226 | 0.4671 |
861
+ | 0.9501 | 35100 | 1.7264 | - | - | - | - |
862
+ | 0.9528 | 35200 | 1.5511 | - | - | - | - |
863
+ | 0.9555 | 35300 | 1.6409 | - | - | - | - |
864
+ | 0.9582 | 35400 | 1.47 | - | - | - | - |
865
+ | 0.9609 | 35500 | 1.8675 | - | - | - | - |
866
+ | 0.9636 | 35600 | 1.6868 | - | - | - | - |
867
+ | 0.9663 | 35700 | 1.744 | - | - | - | - |
868
+ | 0.9690 | 35800 | 1.6734 | - | - | - | - |
869
+ | 0.9718 | 35900 | 1.4154 | - | - | - | - |
870
+ | 0.9745 | 36000 | 1.4793 | 3.7393 | 0.7190 | 0.7223 | 0.4677 |
871
+ | 0.9772 | 36100 | 1.7126 | - | - | - | - |
872
+ | 0.9799 | 36200 | 1.7037 | - | - | - | - |
873
+ | 0.9826 | 36300 | 1.6306 | - | - | - | - |
874
+ | 0.9853 | 36400 | 1.7783 | - | - | - | - |
875
+ | 0.9880 | 36500 | 1.5751 | - | - | - | - |
876
+ | 0.9907 | 36600 | 1.6079 | - | - | - | - |
877
+ | 0.9934 | 36700 | 1.7162 | - | - | - | - |
878
+ | 0.9961 | 36800 | 1.447 | - | - | - | - |
879
+ | 0.9988 | 36900 | 1.6155 | - | - | - | - |
880
+ | 1.0015 | 37000 | 1.7294 | 3.7512 | 0.7177 | 0.7236 | 0.4659 |
881
+
882
+ </details>
883
+
884
+ ### Framework Versions
885
+ - Python: 3.10.12
886
+ - Sentence Transformers: 3.0.1
887
+ - Transformers: 4.38.2
888
+ - PyTorch: 2.1.2+cu121
889
+ - Accelerate: 0.27.2
890
+ - Datasets: 2.19.1
891
+ - Tokenizers: 0.15.2
892
+
893
+ ## Citation
894
+
895
+ ### BibTeX
896
+
897
+ #### Sentence Transformers
898
+ ```bibtex
899
+ @inproceedings{reimers-2019-sentence-bert,
900
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
901
+ author = "Reimers, Nils and Gurevych, Iryna",
902
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
903
+ month = "11",
904
+ year = "2019",
905
+ publisher = "Association for Computational Linguistics",
906
+ url = "https://arxiv.org/abs/1908.10084",
907
+ }
908
+ ```
909
+
910
+ #### CachedMultipleNegativesRankingLoss
911
+ ```bibtex
912
+ @misc{gao2021scaling,
913
+ title={Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup},
914
+ author={Luyu Gao and Yunyi Zhang and Jiawei Han and Jamie Callan},
915
+ year={2021},
916
+ eprint={2101.06983},
917
+ archivePrefix={arXiv},
918
+ primaryClass={cs.LG}
919
+ }
920
+ ```
921
+
922
+ #### AnglELoss
923
+ ```bibtex
924
+ @misc{li2023angleoptimized,
925
+ title={AnglE-optimized Text Embeddings},
926
+ author={Xianming Li and Jing Li},
927
+ year={2023},
928
+ eprint={2309.12871},
929
+ archivePrefix={arXiv},
930
+ primaryClass={cs.CL}
931
+ }
932
+ ```
933
+
934
+ <!--
935
+ ## Glossary
936
+
937
+ *Clearly define terms in order to be accessible across audiences.*
938
+ -->
939
+
940
+ <!--
941
+ ## Model Card Authors
942
+
943
+ *Lists the people who create the model card, providing recognition and accountability for the detailed work that goes into its construction.*
944
+ -->
945
+
946
+ <!--
947
+ ## Model Card Contact
948
+
949
+ *Provides a way for people who have updates to the Model Card, suggestions, or questions, to contact the Model Card authors.*
950
+ -->
config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "_name_or_path": "models/nomic-embed-text-esci/checkpoint-37000",
+   "activation_function": "swiglu",
+   "architectures": [
+     "NomicBertModel"
+   ],
+   "attn_pdrop": 0.0,
+   "auto_map": {
+     "AutoConfig": "configuration_hf_nomic_bert.NomicBertConfig",
+     "AutoModel": "modeling_hf_nomic_bert.NomicBertModel",
+     "AutoModelForMaskedLM": "nomic-ai/nomic-bert-2048--modeling_hf_nomic_bert.NomicBertForPreTraining"
+   },
+   "bos_token_id": null,
+   "causal": false,
+   "dense_seq_output": true,
+   "embd_pdrop": 0.0,
+   "eos_token_id": null,
+   "fused_bias_fc": true,
+   "fused_dropout_add_ln": true,
+   "initializer_range": 0.02,
+   "layer_norm_epsilon": 1e-12,
+   "max_trained_positions": 2048,
+   "mlp_fc1_bias": false,
+   "mlp_fc2_bias": false,
+   "model_type": "nomic_bert",
+   "n_embd": 768,
+   "n_head": 12,
+   "n_inner": 3072,
+   "n_layer": 12,
+   "n_positions": 8192,
+   "pad_vocab_size_multiple": 64,
+   "parallel_block": false,
+   "parallel_block_tied_norm": false,
+   "prenorm": false,
+   "qkv_proj_bias": false,
+   "reorder_and_upcast_attn": false,
+   "resid_pdrop": 0.0,
+   "rotary_emb_base": 1000,
+   "rotary_emb_fraction": 1.0,
+   "rotary_emb_interleaved": false,
+   "rotary_emb_scale_base": null,
+   "rotary_scaling_factor": null,
+   "scale_attn_by_inverse_layer_idx": false,
+   "scale_attn_weights": true,
+   "summary_activation": null,
+   "summary_first_dropout": 0.0,
+   "summary_proj_to_labels": true,
+   "summary_type": "cls_index",
+   "summary_use_proj": true,
+   "torch_dtype": "float32",
+   "transformers_version": "4.38.2",
+   "type_vocab_size": 2,
+   "use_cache": true,
+   "use_flash_attn": true,
+   "use_rms_norm": false,
+   "use_xentropy": true,
+   "vocab_size": 30528
+ }
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "__version__": {
+     "sentence_transformers": "3.0.1",
+     "transformers": "4.38.2",
+     "pytorch": "2.1.2+cu121"
+   },
+   "prompts": {},
+   "default_prompt_name": null,
+   "similarity_fn_name": null
+ }
configuration_hf_nomic_bert.py ADDED
@@ -0,0 +1,56 @@
+ from transformers import GPT2Config
+
+
+ class NomicBertConfig(GPT2Config):
+     model_type = "nomic_bert"
+
+     def __init__(
+         self,
+         prenorm=False,
+         parallel_block=False,
+         parallel_block_tied_norm=False,
+         rotary_emb_fraction=0.0,
+         fused_dropout_add_ln=False,
+         fused_bias_fc=False,
+         use_flash_attn=False,
+         use_xentropy=False,
+         qkv_proj_bias=True,
+         rotary_emb_base=10_000,
+         rotary_emb_scale_base=None,
+         rotary_emb_interleaved=False,
+         mlp_fc1_bias=True,
+         mlp_fc2_bias=True,
+         use_rms_norm=False,
+         causal=False,
+         type_vocab_size=2,
+         dense_seq_output=True,
+         pad_vocab_size_multiple=1,
+         tie_word_embeddings=True,
+         rotary_scaling_factor=None,
+         max_trained_positions=2048,
+         **kwargs,
+     ):
+         self.prenorm = prenorm
+         self.parallel_block = parallel_block
+         self.parallel_block_tied_norm = parallel_block_tied_norm
+         self.rotary_emb_fraction = rotary_emb_fraction
+         self.tie_word_embeddings = tie_word_embeddings
+         self.fused_dropout_add_ln = fused_dropout_add_ln
+         self.fused_bias_fc = fused_bias_fc
+         self.use_flash_attn = use_flash_attn
+         self.use_xentropy = use_xentropy
+         self.qkv_proj_bias = qkv_proj_bias
+         self.rotary_emb_base = rotary_emb_base
+         self.rotary_emb_scale_base = rotary_emb_scale_base
+         self.rotary_emb_interleaved = rotary_emb_interleaved
+         self.mlp_fc1_bias = mlp_fc1_bias
+         self.mlp_fc2_bias = mlp_fc2_bias
+         self.use_rms_norm = use_rms_norm
+         self.causal = causal
+         self.type_vocab_size = type_vocab_size
+         self.dense_seq_output = dense_seq_output
+         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+         self.rotary_scaling_factor = rotary_scaling_factor
+         self.max_trained_positions = max_trained_positions
+
+         super().__init__(**kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4a3ba911558d833fc7ede5bbf4cccc67f19a69e1f7b5495aab145106c839b8aa
+ size 546938168
modeling_hf_nomic_bert.py ADDED
@@ -0,0 +1,2071 @@
1
+ # Copyright (c) 2022, Tri Dao.
2
+ # This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
3
+ # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
4
+ # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
5
+
6
+ import logging
+ import warnings
7
+
8
+ # Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
9
+ import math
10
+ import numpy as np
11
+ import collections
+ import itertools
12
+ import os
13
+ import re
14
+ from collections import OrderedDict
15
+ from functools import partial
16
+ from typing import List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from einops import rearrange, repeat
22
+ from safetensors.torch import load_file as safe_load_file
23
+ from transformers import GPT2Config, PreTrainedModel, ViTModel, ViTConfig
24
+ from transformers.models.bert.modeling_bert import (
25
+ BaseModelOutputWithPoolingAndCrossAttentions,
26
+ MaskedLMOutput,
27
+ SequenceClassifierOutput,
28
+ )
29
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
30
+ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
31
+ from transformers.modeling_outputs import BaseModelOutputWithPast
32
+ from torch.nn.modules.utils import _pair
33
+
34
+ from .configuration_hf_nomic_bert import NomicBertConfig
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # adapted from flash attention, added safe serialization option for hf models
40
+ def state_dict_from_pretrained(model_name, safe_serialization=False, device=None, dtype=None):
41
+ # If not fp32, then we don't want to load directly to the GPU
42
+ mapped_device = "cpu" if dtype not in [torch.float32, None] else device
43
+ is_sharded = False
44
+ load_safe = False
45
+ resolved_archive_file = None
46
+
47
+ weights_path = os.path.join(model_name, WEIGHTS_NAME)
48
+ weights_index_path = os.path.join(model_name, WEIGHTS_INDEX_NAME)
49
+ safe_weights_path = os.path.join(model_name, SAFE_WEIGHTS_NAME)
50
+ safe_weights_index_path = os.path.join(model_name, SAFE_WEIGHTS_INDEX_NAME)
51
+
52
+ if os.path.isfile(weights_path):
53
+ resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
54
+ elif os.path.isfile(weights_index_path):
55
+ resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False)
56
+ is_sharded = True
57
+ elif os.path.isfile(safe_weights_path):
58
+ resolved_archive_file = cached_file(model_name, SAFE_WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
59
+ load_safe = True
60
+ elif os.path.isfile(safe_weights_index_path):
61
+ resolved_archive_file = cached_file(
62
+ model_name, SAFE_WEIGHTS_INDEX_NAME, _raise_exceptions_for_missing_entries=False
63
+ )
64
+ is_sharded = True
65
+ load_safe = True
66
+ else: # Try loading from HF hub instead of from local files
67
+ resolved_archive_file = None
68
+ for weight_name in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
69
+ resolved_archive_file = cached_file(
70
+ model_name, weight_name, _raise_exceptions_for_missing_entries=False
71
+ )
72
+ if resolved_archive_file is not None:
73
+ if weight_name in [SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME]:
74
+ load_safe = True
75
+ if weight_name in [WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME]:
76
+ is_sharded = True
77
+ break
78
+
79
+ if resolved_archive_file is None:
80
+ raise EnvironmentError(f"Model name {model_name} was not found.")
81
+
82
+ if load_safe:
83
+ loader = partial(safe_load_file, device=mapped_device)
84
+ else:
85
+ loader = partial(torch.load, map_location=mapped_device)
86
+
87
+ if is_sharded:
88
+ # resolved_archive_file becomes a list of files that point to the different
89
+ # checkpoint shards in this case.
90
+ resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(model_name, resolved_archive_file)
91
+ state_dict = {}
92
+ for sharded_file in resolved_archive_file:
93
+ state_dict.update(loader(sharded_file))
94
+ else:
95
+ state_dict = loader(resolved_archive_file)
96
+ # Convert dtype before moving to GPU to save memory
97
+ if dtype is not None:
98
+ state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
99
+ state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
100
+ return state_dict
101
+
102
+
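+ # Usage sketch for the loader above (illustrative only; assumes network access or a cached
+ # copy of the referenced checkpoint, and any repo with pytorch_model.bin or model.safetensors works):
+ #   >>> sd = state_dict_from_pretrained("bert-base-uncased")
+ #   >>> "bert.embeddings.word_embeddings.weight" in sd
+ #   True
+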
103
+ def filter_shapes(state_dict, model):
104
+ """
105
+ Filters the state dict to match the current model shape.
106
+ """
107
+ filtered_state_dict = {}
108
+ for key, value in state_dict.items():
109
+ if key in model.state_dict():
110
+ if value.shape == model.state_dict()[key].shape:
111
+ filtered_state_dict[key] = value
112
+ return filtered_state_dict
113
+
114
+
115
+ def remap_bert_state_dict(
116
+ state_dict,
117
+ config,
118
+ remove_bert=False,
119
+ remove_cls_weights=False,
120
+ add_pooling_layer=False,
121
+ ):
122
+ """
123
+ Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
124
+ """
125
+
126
+ def add_bert_prefix(key):
127
+ # prepend bert. to the key
128
+ if key.startswith("bert.") or key.startswith("cls."):
129
+ return key
130
+ return f"bert.{key}"
131
+
132
+ state_dict = OrderedDict((add_bert_prefix(k), v) for k, v in state_dict.items())
133
+
134
+ # LayerNorm
135
+ def key_mapping_ln_gamma_beta(key):
136
+ key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
137
+ key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
138
+ return key
139
+
140
+ state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
141
+
142
+ # Layers
143
+ def key_mapping_layers(key):
144
+ return re.sub(r"^bert.encoder.layer\.", "bert.encoder.layers.", key)
145
+
146
+ state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
147
+
148
+ # LayerNorm
149
+ def key_mapping_ln(key):
150
+ key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
151
+ key = re.sub(
152
+ r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
153
+ r"bert.encoder.layers.\1.norm1.\2",
154
+ key,
155
+ )
156
+ key = re.sub(
157
+ r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
158
+ r"bert.encoder.layers.\1.norm2.\2",
159
+ key,
160
+ )
161
+ key = re.sub(
162
+ r"^cls.predictions.transform.LayerNorm.(weight|bias)",
163
+ r"cls.predictions.transform.layer_norm.\1",
164
+ key,
165
+ )
166
+ return key
167
+
168
+ state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
169
+
170
+ # MLP
171
+ def key_mapping_mlp(key):
172
+ key = re.sub(
173
+ r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
174
+ r"bert.encoder.layers.\1.mlp.fc1.\2",
175
+ key,
176
+ )
177
+ key = re.sub(
178
+ r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
179
+ r"bert.encoder.layers.\1.mlp.fc2.\2",
180
+ key,
181
+ )
182
+ return key
183
+
184
+ state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
185
+
186
+ # Attention
187
+ last_layer_subset = getattr(config, "last_layer_subset", False)
188
+ for d in range(config.num_hidden_layers):
189
+ if f"bert.encoder.layers.{d}.attention.self.query.weight" not in state_dict:
190
+ continue
191
+ Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
192
+ Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
193
+ Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
194
+ bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
195
+ bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
196
+ bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
197
+ if not (last_layer_subset and d == config.num_hidden_layers - 1):
198
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
199
+ state_dict[f"bert.encoder.layers.{d}.attn.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
200
+ else:
201
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.weight"] = Wq
202
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
203
+ state_dict[f"bert.encoder.layers.{d}.attn.Wq.bias"] = bq
204
+ state_dict[f"bert.encoder.layers.{d}.attn.Wkv.bias"] = torch.cat([bk, bv], dim=0)
205
+
206
+ def key_mapping_attn(key):
207
+ return re.sub(
208
+ r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
209
+ r"bert.encoder.layers.\1.attn.out_proj.\2",
210
+ key,
211
+ )
212
+
213
+ state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
214
+
215
+ def key_mapping_decoder_bias(key):
216
+ return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
217
+
218
+ # remove nsp weights, we don't use
219
+ state_dict.pop("cls.seq_relationship.weight", None)
220
+ state_dict.pop("cls.seq_relationship.bias", None)
221
+ state_dict.pop("bert.embeddings.position_ids", None)
222
+
223
+ state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
224
+
225
+ if remove_cls_weights:
226
+ cls_weights = [
227
+ "cls.predictions.decoder.bias",
228
+ "cls.predictions.transform.dense.weight",
229
+ "cls.predictions.transform.dense.bias",
230
+ "cls.predictions.transform.layer_norm.weight",
231
+ "cls.predictions.transform.layer_norm.bias",
232
+ "cls.predictions.decoder.weight",
233
+ ]
234
+ for weight in cls_weights:
235
+ state_dict.pop(weight, None)
236
+
237
+ # Word embedding
238
+ pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
239
+ if pad_vocab_size_multiple > 1:
240
+ word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
241
+ state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
242
+ word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
243
+ )
244
+ if not remove_cls_weights:
245
+ decoder_weight = state_dict["cls.predictions.decoder.weight"]
246
+ state_dict["cls.predictions.decoder.weight"] = F.pad(
247
+ decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
248
+ )
249
+ # If the vocab was padded, we want to set the decoder bias for those padded indices to be
250
+ # strongly negative (i.e. the decoder shouldn't predict those indices).
251
+ # TD [2022-05-09]: I don't think it affects the MLPerf training.
252
+ if "cls.predictions.decoder.bias" in state_dict:
253
+ decoder_bias = state_dict["cls.predictions.decoder.bias"]
254
+ state_dict["cls.predictions.decoder.bias"] = F.pad(
255
+ decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
256
+ )
257
+
258
+ if add_pooling_layer is False:
259
+ pooler_weights = [
260
+ "bert.pooler.dense.weight",
261
+ "bert.pooler.dense.bias",
262
+ ]
263
+ for key in pooler_weights:
264
+ state_dict.pop(key, None)
265
+
266
+ if remove_bert:
267
+
268
+ def remove_bert_prefix(key):
269
+ key = re.sub(r"^bert.", "", key)
270
+ return key
271
+
272
+ state_dict = OrderedDict((remove_bert_prefix(k), v) for k, v in state_dict.items())
273
+
274
+ return state_dict
275
+
276
+
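+ # What the attention remap above does to Q/K/V weights, in isolation (a minimal sketch
+ # with an assumed hidden size of 768; not exercised by this module):
+ #   >>> Wq, Wk, Wv = (torch.randn(768, 768) for _ in range(3))
+ #   >>> Wqkv = torch.cat([Wq, Wk, Wv], dim=0)   # (2304, 768), rows ordered q, k, v
+ #   >>> q, k, v = Wqkv.chunk(3, dim=0)          # round-trips to the originals
+ #   >>> torch.equal(q, Wq) and torch.equal(k, Wk) and torch.equal(v, Wv)
+ #   True
+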
277
+ def _trunc_normal_(tensor, mean, std, a, b):
278
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
279
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
280
+ def norm_cdf(x):
281
+ # Computes standard normal cumulative distribution function
282
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
283
+
284
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
285
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
286
+ "The distribution of values may be incorrect.",
287
+ stacklevel=2)
288
+
289
+ # Values are generated by using a truncated uniform distribution and
290
+ # then using the inverse CDF for the normal distribution.
291
+ # Get upper and lower cdf values
292
+ l = norm_cdf((a - mean) / std)
293
+ u = norm_cdf((b - mean) / std)
294
+
295
+ # Uniformly fill tensor with values from [l, u], then translate to
296
+ # [2l-1, 2u-1].
297
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
298
+
299
+ # Use inverse cdf transform for normal distribution to get truncated
300
+ # standard normal
301
+ tensor.erfinv_()
302
+
303
+ # Transform to proper mean, std
304
+ tensor.mul_(std * math.sqrt(2.))
305
+ tensor.add_(mean)
306
+
307
+ # Clamp to ensure it's in the proper range
308
+ tensor.clamp_(min=a, max=b)
309
+ return tensor
310
+
311
+ def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
312
+ r"""Fills the input Tensor with values drawn from a truncated
313
+ normal distribution. The values are effectively drawn from the
314
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
315
+ with values outside :math:`[a, b]` redrawn until they are within
316
+ the bounds. The method used for generating the random values works
317
+ best when :math:`a \leq \text{mean} \leq b`.
318
+
319
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
320
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
321
+ and the result is subsequently scaled and shifted by the mean and std args.
322
+
323
+ Args:
324
+ tensor: an n-dimensional `torch.Tensor`
325
+ mean: the mean of the normal distribution
326
+ std: the standard deviation of the normal distribution
327
+ a: the minimum cutoff value
328
+ b: the maximum cutoff value
329
+ Examples:
330
+ >>> w = torch.empty(3, 5)
331
+ >>> nn.init.trunc_normal_(w)
332
+ """
333
+ with torch.no_grad():
334
+ _trunc_normal_(tensor, 0, 1.0, a, b)
335
+ tensor.mul_(std).add_(mean)
336
+ return tensor
337
+
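+ # Example: the tf variant samples within [a, b] first and scales afterwards, so the final
+ # values land in [mean + a * std, mean + b * std]:
+ #   >>> w = trunc_normal_tf_(torch.empty(3, 5), std=0.02)
+ #   >>> bool(w.abs().max() <= 2 * 0.02)
+ #   True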
338
+
339
+ class NomicBertPreTrainedModel(PreTrainedModel):
340
+ """An abstract class to handle weights initialization and
341
+ a simple interface for downloading and loading pretrained models.
342
+ """
343
+
344
+ config_class = NomicBertConfig
345
+ base_model_prefix = "model"
346
+ supports_gradient_checkpointing = True
347
+ _no_split_modules = ["Block"]
348
+ _skip_keys_device_placement = "past_key_values"
349
+
350
+ def __init__(self, config, *inputs, **kwargs):
351
+ super().__init__(config)
352
+ if not isinstance(config, GPT2Config):
353
+ raise ValueError(
354
+ "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
355
+ "To create a model from a Google pretrained model use "
356
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
357
+ self.__class__.__name__, self.__class__.__name__
358
+ )
359
+ )
360
+ self.config = config
361
+
362
+ @classmethod
363
+ def from_pretrained(cls, model_name, config=None, *inputs, **kwargs):
364
+ """
365
+ Instantiate a NomicBertPreTrainedModel from a pre-trained model file or a pytorch state dict.
366
+ Download and cache the pre-trained model file if needed.
367
+
368
+ Params:
369
+ pretrained_model_name_or_path: either:
370
+ - a path or url to a pretrained model archive containing:
371
+ . `bert_config.json` a configuration file for the model
372
+ . `pytorch_model.bin` a PyTorch dump of a NomicBertForPretraining instance
373
+ - a path or url to a pretrained model archive containing:
374
+ . `bert_config.json` a configuration file for the model
375
+ . `model.chkpt` a TensorFlow checkpoint
376
+ *inputs, **kwargs: additional input for the specific NomicBert class
377
+ (ex: num_labels for NomicBertForSequenceClassification)
378
+ """
379
+ # Instantiate model.
380
+ if config is None:
381
+ config = cls.config_class.from_pretrained(model_name)
382
+ remove_cls = cls != NomicBertForPreTraining
383
+ remove_bert_prefix = cls != NomicBertForPreTraining and cls != NomicBertForSequenceClassification
384
+ ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
385
+ num_labels = kwargs.pop("num_labels", None)
386
+ rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
387
+ strict = kwargs.pop("strict", True)
388
+ if rotary_scaling_factor:
389
+ config.rotary_scaling_factor = rotary_scaling_factor
390
+
391
+ if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
392
+ config.n_positions = 2048
393
+ if num_labels:
394
+ config.num_labels = num_labels
395
+
396
+ if "add_pooling_layer" in kwargs:
397
+ model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
398
+ else:
399
+ if cls == NomicBertModel:
400
+ model = cls(config, *inputs, add_pooling_layer=False)
401
+ else:
402
+ model = cls(config, *inputs)
403
+ # TODO: fix this
404
+ # Assuming we know what we're doing when loading from disk
405
+ # Probably a bad assumption, but it lets a local NomicBert checkpoint load directly without remapping
406
+ if os.path.exists(model_name):
407
+ model_path = f"{model_name}/pytorch_model.bin"
408
+ if os.path.exists(model_path):
409
+ state_dict = torch.load(f"{model_name}/pytorch_model.bin")
410
+ else:
411
+ model_path = f"{model_name}/model.safetensors"
412
+ if not os.path.exists(model_path):
413
+ raise ValueError(f"Model path {model_path} not found")
414
+ state_dict = safe_load_file(model_path)
415
+
416
+ if ignore_mismatched_shapes:
417
+ state_dict = filter_shapes(state_dict, model)
418
+ load_return = model.load_state_dict(state_dict, strict=False)
419
+ else:
420
+ # TODO: can probably check config class and see if we need to remap from a bert model
421
+ state_dict = state_dict_from_pretrained(model_name)
422
+ state_dict = remap_bert_state_dict(
423
+ state_dict,
424
+ config,
425
+ remove_bert=remove_bert_prefix,
426
+ remove_cls_weights=remove_cls,
427
+ add_pooling_layer=getattr(config, "add_pooling_layer", False),
428
+ )
429
+ if ignore_mismatched_shapes:
430
+ state_dict = filter_shapes(state_dict, model)
431
+
432
+ load_return = model.load_state_dict(state_dict, strict=strict)
433
+ logger.warning(load_return)
434
+ return model
435
+
436
+ def _set_gradient_checkpointing(self, module, value=False):
437
+ if isinstance(module, NomicBertEncoder):
438
+ module.gradient_checkpointing = value
439
+
440
+
441
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
442
+ def _init_weights(module, initializer_range=0.02):
443
+ if isinstance(module, nn.Linear):
444
+ nn.init.normal_(module.weight, std=initializer_range)
445
+ if module.bias is not None:
446
+ nn.init.zeros_(module.bias)
447
+ elif isinstance(module, nn.Embedding):
448
+ nn.init.normal_(module.weight, std=initializer_range)
449
+ if module.padding_idx is not None:
450
+ nn.init.zeros_(module.weight[module.padding_idx])
451
+
452
+ def _ntuple(n):
453
+ def parse(x):
454
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
455
+ return tuple(x)
456
+ return tuple(itertools.repeat(x, n))  # itertools.repeat, not the einops repeat imported above
457
+ return parse
458
+
459
+
460
+ to_1tuple = _ntuple(1)
461
+ to_2tuple = _ntuple(2)
462
+ to_3tuple = _ntuple(3)
463
+ to_4tuple = _ntuple(4)
464
+ to_ntuple = _ntuple
465
+
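+ # e.g. to_2tuple(16) -> (16, 16) via itertools.repeat, while iterables pass through
+ # unchanged: to_2tuple((14, 16)) -> (14, 16)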
466
+
467
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
468
+ """
469
+ Create 2D sin/cos positional embeddings.
470
+
471
+ Args:
472
+ embed_dim (`int`):
473
+ Embedding dimension.
474
+ grid_size (`int`):
475
+ The grid height and width.
476
+ add_cls_token (`bool`, *optional*, defaults to `False`):
477
+ Whether or not to add a classification (CLS) token.
478
+
479
+ Returns:
480
+ (`torch.FloatTensor` of shape (grid_size*grid_size, embed_dim) or (1+grid_size*grid_size, embed_dim): the
481
+ position embeddings (with or without classification token)
482
+ """
483
+ grid_h = np.arange(grid_size, dtype=np.float32)
484
+
485
+ grid_w = np.arange(grid_size, dtype=np.float32)
486
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
487
+ grid = np.stack(grid, axis=0)
488
+
489
+ grid = grid.reshape([2, 1, grid_size, grid_size])
490
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
491
+ if add_cls_token:
492
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
493
+ return pos_embed
494
+
495
+
496
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
497
+ if embed_dim % 2 != 0:
498
+ raise ValueError("embed_dim must be even")
499
+
500
+ # use half of dimensions to encode grid_h
501
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
502
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
503
+
504
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
505
+ return emb
506
+
507
+
508
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
509
+ """
510
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
511
+ """
512
+ if embed_dim % 2 != 0:
513
+ raise ValueError("embed_dim must be even")
514
+
515
+ omega = np.arange(embed_dim // 2, dtype=float)
516
+ omega /= embed_dim / 2.0
517
+ omega = 1.0 / 10000**omega # (D/2,)
518
+
519
+ pos = pos.reshape(-1) # (M,)
520
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
521
+
522
+ emb_sin = np.sin(out) # (M, D/2)
523
+ emb_cos = np.cos(out) # (M, D/2)
524
+
525
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
526
+ return emb
527
+
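+ # Shape check for the sin/cos helpers above (a sketch; a 14x14 ViT-style patch grid is assumed):
+ #   >>> get_2d_sincos_pos_embed(embed_dim=768, grid_size=14).shape
+ #   (196, 768)
+ #   >>> get_2d_sincos_pos_embed(768, 14, add_cls_token=True).shape
+ #   (197, 768)
+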
528
+ def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
529
+ """Generate an N-D grid in dimension order.
530
+
531
+ The ndgrid function is like meshgrid except that the order of the first two input arguments is switched.
532
+
533
+ That is, the statement
534
+ [X1,X2,X3] = ndgrid(x1,x2,x3)
535
+
536
+ produces the same result as
537
+
538
+ [X2,X1,X3] = meshgrid(x2,x1,x3)
539
+
540
+ This naming is based on MATLAB; the purpose is to avoid confusion now that torch.meshgrid
541
+ behaviour has moved from ndgrid-style ('ij') indexing to the numpy meshgrid default of 'xy'.
542
+
543
+ """
544
+ try:
545
+ return torch.meshgrid(*tensors, indexing='ij')
546
+ except TypeError:
547
+ # old PyTorch < 1.10 will follow this path as it does not have indexing arg,
548
+ # the old behaviour of meshgrid was 'ij'
549
+ return torch.meshgrid(*tensors)
550
+
551
+ def build_fourier_pos_embed(
552
+ feat_shape: List[int],
553
+ bands: Optional[torch.Tensor] = None,
554
+ num_bands: int = 64,
555
+ max_res: int = 224,
556
+ temperature: float = 10000.,
557
+ linear_bands: bool = False,
558
+ include_grid: bool = False,
559
+ in_pixels: bool = True,
560
+ ref_feat_shape: Optional[List[int]] = None,
561
+ dtype: torch.dtype = torch.float32,
562
+ device: Optional[torch.device] = None,
563
+ ) -> List[torch.Tensor]:
564
+ """
565
+
566
+ Args:
567
+ feat_shape: Feature shape for embedding.
568
+ bands: Pre-calculated frequency bands.
569
+ num_bands: Number of frequency bands (determines output dim).
570
+ max_res: Maximum resolution for pixel based freq.
571
+ temperature: Temperature for non-pixel freq.
572
+ linear_bands: Linear band spacing for pixel based freq.
573
+ include_grid: Include the spatial grid in output.
574
+ in_pixels: Output in pixel freq.
575
+ ref_feat_shape: Reference feature shape for resize / fine-tune.
576
+ dtype: Output dtype.
577
+ device: Output device.
578
+
579
+ Returns:
580
+ List of [pos_sin, pos_cos] tensors, with the spatial grid prepended when include_grid is True.
581
+ """
582
+ if bands is None:
583
+ if in_pixels:
584
+ bands = pixel_freq_bands(
585
+ num_bands,
586
+ float(max_res),
587
+ linear_bands=linear_bands,
588
+ device=device,
589
+ )
590
+ else:
591
+ bands = freq_bands(
592
+ num_bands,
593
+ temperature=temperature,
594
+ step=1,
595
+ device=device,
596
+ )
597
+ else:
598
+ if device is None:
599
+ device = bands.device
600
+ if dtype is None:
601
+ dtype = bands.dtype
602
+
603
+ if in_pixels:
604
+ t = [torch.linspace(-1., 1., steps=s, device=device, dtype=torch.float32) for s in feat_shape]
605
+ else:
606
+ t = [torch.arange(s, device=device, dtype=torch.int64).to(torch.float32) for s in feat_shape]
607
+
608
+ if ref_feat_shape is not None:
609
+ # eva's scheme for resizing rope embeddings (ref shape = pretrain)
610
+ t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
611
+
612
+ grid = torch.stack(ndgrid(t), dim=-1)
613
+ grid = grid.unsqueeze(-1)
614
+ pos = grid * bands
615
+
616
+ pos_sin, pos_cos = pos.sin().to(dtype=dtype), pos.cos().to(dtype)
617
+ out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
618
+ return out
619
+
620
+
621
+ def build_rotary_pos_embed(
622
+ feat_shape: List[int],
623
+ bands: Optional[torch.Tensor] = None,
624
+ dim: int = 64,
625
+ max_res: int = 224,
626
+ temperature: float = 10000.,
627
+ linear_bands: bool = False,
628
+ in_pixels: bool = True,
629
+ ref_feat_shape: Optional[List[int]] = None,
630
+ dtype: torch.dtype = torch.float32,
631
+ device: Optional[torch.device] = None,
632
+ ):
633
+ """
634
+
635
+ Args:
636
+ feat_shape: Spatial shape of the target tensor for embedding.
637
+ bands: Optional pre-generated frequency bands
638
+ dim: Output dimension of embedding tensor.
639
+ max_res: Maximum resolution for pixel mode.
640
+ temperature: Temperature (inv freq) for non-pixel mode
641
+ linear_bands: Linearly (instead of log) spaced bands for pixel mode
642
+ in_pixels: Pixel vs language (inv freq) mode.
643
+ dtype: Output dtype.
644
+ device: Output device.
645
+
646
+ Returns:
647
+ Tuple of (sin_emb, cos_emb), each of shape (num_positions, dim).
648
+ """
649
+ sin_emb, cos_emb = build_fourier_pos_embed(
650
+ feat_shape,
651
+ bands=bands,
652
+ num_bands=dim // 4,
653
+ max_res=max_res,
654
+ temperature=temperature,
655
+ linear_bands=linear_bands,
656
+ in_pixels=in_pixels,
657
+ ref_feat_shape=ref_feat_shape,
658
+ device=device,
659
+ dtype=dtype,
660
+ )
661
+ num_spatial_dim = 1
662
+ # this would be much nicer as a .numel() call on torch.Size(), but torchscript doesn't support it
663
+ for x in feat_shape:
664
+ num_spatial_dim *= x
665
+ sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
666
+ cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
667
+ return sin_emb, cos_emb
668
+
669
+ def freq_bands(
670
+ num_bands: int,
671
+ temperature: float = 10000.,
672
+ step: int = 2,
673
+ device: Optional[torch.device] = None,
674
+ ) -> torch.Tensor:
675
+ exp = torch.arange(0, num_bands, step, dtype=torch.int64, device=device).to(torch.float32) / num_bands
676
+ bands = 1. / (temperature ** exp)
677
+ return bands
678
+
679
+
680
+ def pixel_freq_bands(
681
+ num_bands: int,
682
+ max_freq: float = 224.,
683
+ linear_bands: bool = True,
684
+ device: Optional[torch.device] = None,
685
+ ):
686
+ if linear_bands:
687
+ bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=torch.float32, device=device)
688
+ else:
689
+ bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=torch.float32, device=device)
690
+ return bands * torch.pi
691
+
692
+ def rot(x):
693
+ return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
694
+
695
+ def apply_rot_embed_cat(x: torch.Tensor, emb):
696
+ sin_emb, cos_emb = emb.tensor_split(2, -1)
697
+ if sin_emb.ndim == 3:
698
+ return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
699
+ return x * cos_emb + rot(x) * sin_emb
700
+
701
+ # taken from https://github.com/huggingface/pytorch-image-models/blob/cb0e4391beedcc5ac3ae4bce16561b95c326f32c/timm/layers/pos_embed_sincos.py#L363
702
+ class NomicVisionRotaryEmbeddingCat(nn.Module):
703
+ """ Rotary position embedding w/ concatenated sin & cos
704
+
705
+ The following impl/resources were referenced for this impl:
706
+ * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
707
+ * https://blog.eleuther.ai/rotary-embeddings/
708
+ """
709
+
710
+ def __init__(
711
+ self,
712
+ dim,
713
+ max_res=224,
714
+ temperature=10000,
715
+ in_pixels=True,
716
+ linear_bands: bool = False,
717
+ feat_shape: Optional[List[int]] = None,
718
+ ref_feat_shape: Optional[List[int]] = None,
719
+ ):
720
+ super().__init__()
721
+ self.dim = dim
722
+ self.max_res = max_res
723
+ self.temperature = temperature
724
+ self.in_pixels = in_pixels
725
+ self.feat_shape = feat_shape
726
+ self.ref_feat_shape = ref_feat_shape
727
+
728
+ if feat_shape is None:
729
+ # only cache bands
730
+ if in_pixels:
731
+ bands = pixel_freq_bands(
732
+ dim // 4,
733
+ float(max_res),
734
+ linear_bands=linear_bands,
735
+ )
736
+ else:
737
+ bands = freq_bands(
738
+ dim // 4,
739
+ temperature=temperature,
740
+ step=1,
741
+ )
742
+ self.register_buffer(
743
+ 'bands',
744
+ bands,
745
+ persistent=False,
746
+ )
747
+ self.pos_embed = None
748
+ else:
749
+ # cache full sin/cos embeddings if shape provided up front
750
+ embeds = build_rotary_pos_embed(
751
+ feat_shape=feat_shape,
752
+ dim=dim,
753
+ max_res=max_res,
754
+ linear_bands=linear_bands,
755
+ in_pixels=in_pixels,
756
+ ref_feat_shape=self.ref_feat_shape,
757
+ )
758
+ self.bands = None
759
+ self.register_buffer(
760
+ 'pos_embed',
761
+ torch.cat(embeds, -1),
762
+ persistent=False,
763
+ )
764
+
765
+ def get_embed(self, shape: Optional[List[int]] = None):
766
+ if self.bands is not None and shape is not None:
767
+ # rebuild embeddings every call, use if target shape changes
768
+ embeds = build_rotary_pos_embed(
769
+ shape,
770
+ self.bands,
771
+ in_pixels=self.in_pixels,
772
+ ref_feat_shape=self.ref_feat_shape,
773
+ )
774
+ return torch.cat(embeds, -1)
775
+ elif self.pos_embed is not None:
776
+ return self.pos_embed
777
+ else:
778
+ assert False, "get_embed() requires pre-computed pos_embed or valid shape w/ pre-computed bands"
779
+
780
+ def forward(self, x):
781
+ # assuming channel-first tensor where spatial dims are >= 2
782
+ pos_embed = self.get_embed(x.shape[2:])
783
+ return apply_rot_embed_cat(x, pos_embed)
784
+
785
+ class NomicVisionPatchEmbeddings(nn.Module):
786
+ def __init__(
787
+ self,
788
+ config,
789
+ ):
790
+ super().__init__()
791
+ img_size = _pair(config.img_size)
792
+ patch_size = _pair(config.patch_size)
793
+ self.img_size = img_size
794
+ self.patch_size = patch_size
795
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
796
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
797
+
798
+ self.proj = nn.Linear(
799
+ config.num_channels * patch_size[0] * patch_size[1], config.n_embd, bias=config.patch_embed_bias
800
+ )
801
+
802
+ self.learned_pos_embedding = False
803
+ self.sinusoidal_pos_embedding = False
804
+ self.no_embed_class = getattr(config, "no_embed_class", False)
805
+
806
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.n_embd)) if not getattr(config, "no_cls_token", False) else None
807
+ if config.learned_pos_embedding:
808
+ # this is the default in DINO
809
+ self.learned_pos_embedding = True
810
+ # hack for timm dinov2 with registers
811
+ num_patches = self.num_patches if getattr(config, "register_tokens", 0) > 0 else self.num_patches + 1
812
+ self.pos_embed = nn.Parameter(torch.randn(1, num_patches, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
813
+ elif getattr(config, "sinusoidal_pos_embedding", False):
814
+ self.sinusoidal_pos_embedding = True
815
+ if getattr(config, "use_pos_embed", True):
816
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, config.n_embd), requires_grad=False)
817
+ pos_embed = get_2d_sincos_pos_embed(config.n_embd, self.grid_size[0], add_cls_token=True)
818
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).to(self.pos_embed))
819
+ else:
820
+ self.pos_embed = None
821
+ else:
822
+ self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, config.n_embd) * 0.02) if getattr(config, "use_pos_embed", True) else None
823
+
824
+ if getattr(config, "register_tokens", 0) > 0:
825
+ self.reg_token = nn.Parameter(torch.randn(1, config.register_tokens, config.n_embd) * 0.02)
826
+ else:
827
+ self.reg_token = None
828
+
829
+ if config.mask_token:
830
+ self.mask_token = nn.Parameter(torch.zeros(1, config.n_embd))
831
+
832
+ self.patch_dropout = nn.Identity()
833
+
834
+ if getattr(config, "use_rotary_pos_emb", False):
835
+ ref_feat_shape = getattr(config, "ref_feat_shape", None)
836
+ ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
837
+ self.rope = NomicVisionRotaryEmbeddingCat(
838
+ config.n_embd // config.n_head,
839
+ in_pixels=False,
840
+ feat_shape=self.grid_size,
841
+ ref_feat_shape=ref_feat_shape,
842
+ )
843
+ else:
844
+ self.rope = None
845
+
846
+
847
+ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
848
+ """
849
+ This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
850
+ resolution images.
851
+
852
+ Source:
853
+ https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
854
+ """
855
+ num_patches = embeddings.shape[1] - 1
856
+ num_positions = self.pos_embed.shape[1] - 1
857
+ if num_patches == num_positions and height == width:
858
+ return self.pos_embed
859
+ class_pos_embed = self.pos_embed[:, 0]
860
+ patch_pos_embed = self.pos_embed[:, 1:]
861
+ dim = embeddings.shape[-1]
862
+ height = height // self.patch_size[0]
863
+ width = width // self.patch_size[1]
864
+ # we add a small number to avoid floating point error in the interpolation
865
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
866
+ height, width = height + 0.1, width + 0.1
867
+ patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
868
+ patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
869
+ patch_pos_embed = nn.functional.interpolate(
870
+ patch_pos_embed,
871
+ scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
872
+ mode="bicubic",
873
+ align_corners=False,
874
+ )
875
+ if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
876
+ raise ValueError("Width or height does not match with the interpolated position embeddings")
877
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
878
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
879
+
880
+ def forward(self, x):
881
+ # deepspeed case where the input is in fp32
882
+ if x.dtype != self.proj.weight.dtype:
883
+ x = x.to(dtype=self.proj.weight.dtype)
884
+
885
+ _, _, height, width = x.shape
886
+ x = self.proj(
887
+ rearrange(
888
+ x,
889
+ "b c (h p1) (w p2) -> b h w (c p1 p2)",
890
+ p1=self.patch_size[0],
891
+ p2=self.patch_size[1],
892
+ )
893
+ )
894
+ embeddings = rearrange(x, "b h w c -> b (h w) c")
895
+
896
+ to_cat = []
897
+ if self.cls_token is not None:
898
+ if self.sinusoidal_pos_embedding:
899
+ cls_token = self.cls_token + self.pos_embed[:, 0]
900
+ cls_token = cls_token.expand(embeddings.shape[0], -1, -1)
901
+ to_cat += [cls_token]
902
+ else:
903
+ cls_token = self.cls_token.expand(embeddings.shape[0], 1, -1)
904
+ to_cat += [cls_token]
905
+
906
+ if self.reg_token is not None:
907
+ to_cat += [self.reg_token.expand(embeddings.shape[0], -1, -1)]
908
+
909
+ rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
910
+
911
+ if self.no_embed_class:
912
+ if self.learned_pos_embedding:
913
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
914
+ else:
915
+ if self.pos_embed is not None:
916
+ embeddings = embeddings + self.pos_embed
917
+ if to_cat:
918
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
919
+ else:
920
+ if to_cat:
921
+ embeddings = torch.cat(to_cat + [embeddings], dim=1)
922
+ if self.learned_pos_embedding:
923
+ if self.pos_embed is not None:
924
+ embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
925
+ else:
926
+ if self.pos_embed is not None:
927
+ embeddings = embeddings + self.pos_embed
928
+
929
+ embeddings = self.patch_dropout(embeddings)
930
+
931
+ return embeddings, rot_pos_embed
932
+
933
+
934
+ class NomicBertEmbeddings(nn.Module):
935
+ def __init__(self, config):
936
+ """
937
+ If max_position_embeddings <= 0, there are no position embeddings.
938
+ If type_vocab_size <= 0, there are no token type embeddings.
939
+ """
940
+ super().__init__()
941
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
942
+ self.max_position_embeddings = config.max_position_embeddings if config.rotary_emb_fraction <= 0 else 0
943
+ self.type_vocab_size = config.type_vocab_size
944
+ if self.max_position_embeddings > 0 and config.rotary_emb_fraction <= 0:
945
+ self.position_embeddings = nn.Embedding(
946
+ config.max_position_embeddings,
947
+ config.hidden_size,
948
+ )
949
+ if self.type_vocab_size > 0:
950
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
951
+
952
+ def forward(self, input_ids, position_ids=None, token_type_ids=None):
953
+ """
954
+ input_ids: (batch, seqlen)
955
+ position_ids: (batch, seqlen)
956
+ token_type_ids: (batch, seqlen)
957
+ """
958
+ batch_size, seqlen = input_ids.shape
959
+ embeddings = self.word_embeddings(input_ids)
960
+
961
+ if self.type_vocab_size > 0:
962
+ if token_type_ids is None:
963
+ token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
964
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
965
+ embeddings = embeddings + token_type_embeddings
966
+
967
+ if self.max_position_embeddings > 0:
968
+ if position_ids is None:
969
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
970
+ position_embeddings = self.position_embeddings(position_ids)
971
+ embeddings = embeddings + position_embeddings
972
+ return embeddings
973
+
974
+
975
+ class NomicBertMLP(nn.Module):
976
+ def __init__(
977
+ self,
978
+ in_features,
979
+ hidden_features=None,
980
+ out_features=None,
981
+ activation=F.gelu,
982
+ bias1=True,
983
+ bias2=True,
984
+ return_residual=False,
985
+ fused_bias_fc=False,
986
+ ):
987
+ super().__init__()
988
+ out_features = out_features if out_features is not None else in_features
989
+ hidden_features = hidden_features if hidden_features is not None else in_features * 4
990
+ self.return_residual = return_residual
991
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1)
992
+ approximate = "tanh" if activation in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
993
+ self.activation = nn.GELU(approximate=approximate) if activation == "gelu" else activation
994
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
995
+
996
+ def forward(self, x):
997
+ y = self.fc1(x)
998
+ y = self.activation(y)
999
+ y = self.fc2(y)
1000
+ return y if not self.return_residual else (y, x)
1001
+
1002
+
1003
+ class NomciBertGatedMLP(nn.Module):
1004
+ def __init__(
1005
+ self,
1006
+ in_features,
1007
+ hidden_features=None,
1008
+ out_features=None,
1009
+ activation=F.sigmoid,
1010
+ bias1=True,
1011
+ bias2=True,
1012
+ multiple_of=256,
1013
+ return_residual=False,
1014
+ fused_bias_fc=True,
1015
+ device=None,
1016
+ dtype=None,
1017
+ norm_layer=False,
1018
+ ):
1019
+ super().__init__()
1020
+ out_features = out_features if out_features is not None else in_features
1021
+ hidden_features = hidden_features if hidden_features is not None else int(8 * in_features / 3)
1022
+ hidden_features = int((hidden_features + multiple_of - 1) // multiple_of * multiple_of)
1023
+ self.return_residual = return_residual
1024
+
1025
+ self.fc11 = nn.Linear(in_features, hidden_features, bias=bias1)
1026
+ self.fc12 = nn.Linear(in_features, hidden_features, bias=bias1)
1027
+ self.activation = activation
1028
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2)
1029
+ self.norm = nn.LayerNorm(hidden_features) if norm_layer else nn.Identity()
1030
+
1031
+ def forward(self, x):
1032
+ y = self.fc11(x)
1033
+ gate = self.fc12(x)
1034
+ if self.activation == F.sigmoid: # Special case for GLU
1035
+ y = F.glu(torch.cat([y, gate], dim=-1), dim=-1)
1036
+ else:
1037
+ y = y * self.activation(gate)
1038
+
1039
+ # eva uses layer norm after the activation
1040
+ y = self.norm(y)
1041
+
1042
+ y = self.fc2(y)
1043
+ return y if not self.return_residual else (y, x)
1044
+
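+ # The gated block above computes fc2(fc11(x) * act(fc12(x))); with act = F.silu this is the
+ # SwiGLU variant. Sketch with assumed sizes:
+ #   >>> mlp = NomciBertGatedMLP(768, hidden_features=3072, activation=F.silu)
+ #   >>> mlp(torch.randn(2, 16, 768)).shape
+ #   torch.Size([2, 16, 768])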
1045
+
1046
+ def rotate_half(x, interleaved=False):
1047
+ if not interleaved:
1048
+ x1, x2 = x.chunk(2, dim=-1)
1049
+ return torch.cat((-x2, x1), dim=-1)
1050
+ else:
1051
+ x1, x2 = x[..., ::2], x[..., 1::2]
1052
+ return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
1053
+
1054
+
1055
+ def apply_rotary_emb(x, cos, sin, offset=0, interleaved=False):
1056
+ """
1057
+ x: (batch_size, seqlen, nheads, headdim)
1058
+ cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
1059
+ """
1060
+ ro_dim = cos.shape[-1] * 2
1061
+ assert ro_dim <= x.shape[-1]
1062
+ cos, sin = (
1063
+ cos[offset : offset + x.shape[1]],
1064
+ sin[offset : offset + x.shape[1]],
1065
+ )
1066
+ cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
1067
+ sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
1068
+ return torch.cat(
1069
+ [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
1070
+ dim=-1,
1071
+ )
1072
+
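+ # A quick invariant for apply_rotary_emb: the rotation leaves per-head norms unchanged.
+ # Sketch with assumed shapes (batch 1, seqlen 8, 2 heads, head dim 16, full rotary dim):
+ #   >>> t = torch.arange(8, dtype=torch.float32)
+ #   >>> inv_freq = 1.0 / 10000 ** (torch.arange(0, 16, 2, dtype=torch.float32) / 16)
+ #   >>> freqs = torch.outer(t, inv_freq)
+ #   >>> x = torch.randn(1, 8, 2, 16)
+ #   >>> y = apply_rotary_emb(x, freqs.cos(), freqs.sin())
+ #   >>> torch.allclose(x.norm(dim=-1), y.norm(dim=-1), atol=1e-5)
+ #   True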
1073
+
1074
+ class NomicBertRotaryEmbedding(nn.Module):
1075
+ def __init__(
1076
+ self,
1077
+ dim: int,
1078
+ base=10000.0,
1079
+ interleaved=False,
1080
+ scale_base=None,
1081
+ pos_idx_in_fp32=True,
1082
+ device=None,
1083
+ ):
1084
+ """
1085
+ interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
1086
+ of 1st half and 2nd half (GPT-NeoX style).
1087
+ pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
1088
+ otherwise they might be in lower precision.
1089
+ This option was added because previously (before 2023-07-02), when we construct
1090
+ the position indices, we use the dtype of self.inv_freq. In most cases this would
1091
+ be fp32, but if the model is trained in pure bf16 (not mixed precision), then
1092
+ self.inv_freq would be bf16, and the position indices are also in bf16.
1093
+ Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
1094
+ embeddings for some positions will coincide.
1095
+ To maintain compatibility with models previously trained in pure bf16,
1096
+ we add this option.
1097
+ """
1098
+ super().__init__()
1099
+ self.dim = dim
1100
+ self.base = float(base)
1101
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
1102
+ # Generate and save the inverse frequency buffer (non trainable)
1103
+ inv_freq = self._compute_inv_freq(device)
1104
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1105
+ self.interleaved = interleaved
1106
+ self.scale_base = scale_base
1107
+ scale = (
1108
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
1109
+ if scale_base is not None
1110
+ else None
1111
+ )
1112
+ self.register_buffer("scale", scale, persistent=False)
1113
+
1114
+ self._seq_len_cached = 0
1115
+ self._cos_cached = None
1116
+ self._sin_cached = None
1117
+ self._cos_k_cached = None
1118
+ self._sin_k_cached = None
1119
+
1120
+ def _compute_inv_freq(self, device=None):
1121
+ return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
1122
+
1123
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
1124
+ # Reset the tables if the sequence length has changed,
1125
+ # if we're on a new device (possibly due to tracing for instance),
1126
+ # or if we're switching from inference mode to training
1127
+ if (
1128
+ seqlen > self._seq_len_cached
1129
+ or self._cos_cached is None
1130
+ or self._cos_cached.device != device
1131
+ or self._cos_cached.dtype != dtype
1132
+ or (self.training and self._cos_cached.is_inference())
1133
+ ):
1134
+ self._seq_len_cached = seqlen
1135
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
1136
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
1137
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
1138
+ if self.pos_idx_in_fp32:
1139
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
1140
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
1141
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
1142
+ # cos & sin output to change significantly.
1143
+ # We want to recompute self.inv_freq if it was not loaded in fp32
1144
+ if self.inv_freq.dtype != torch.float32:
1145
+ inv_freq = self._compute_inv_freq(device=device)
1146
+ else:
1147
+ inv_freq = self.inv_freq
1148
+ else:
1149
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
1150
+ inv_freq = self.inv_freq
1151
+ # Don't do einsum, it converts fp32 to fp16 under AMP
1152
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1153
+ freqs = torch.outer(t, inv_freq)
1154
+ self._cos_cached = torch.cos(freqs).to(dtype)
1155
+ self._sin_cached = torch.sin(freqs).to(dtype)
1156
+
1157
+ def forward(
1158
+ self,
1159
+ qkv: torch.Tensor,
1160
+ kv: Optional[torch.Tensor] = None,
1161
+ seqlen_offset: Union[int, torch.Tensor] = 0,
1162
+ max_seqlen: Optional[int] = None,
1163
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1164
+ """
1165
+ qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
1166
+ else it's just q of shape (batch, seqlen, nheads, headdim)
1167
+ kv: (batch, seqlen, 2, nheads, headdim)
1168
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
1169
+ Most commonly used in inference when we have KV cache.
1170
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
1171
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
1172
+ Returns qkv with rotary embedding applied to q and k; v is passed through unchanged.
1173
+ """
1174
+ seqlen = qkv.shape[1]
1175
+ if seqlen > self._seq_len_cached:
1176
+ self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
1177
+ elif max_seqlen is not None:
1178
+ self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
1179
+ elif isinstance(seqlen_offset, int):
1180
+ self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
1181
+
1182
+ q_rot = apply_rotary_emb(qkv[:, :, 0], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
1183
+ k_rot = apply_rotary_emb(qkv[:, :, 1], self._cos_cached, self._sin_cached, seqlen_offset, self.interleaved)
1184
+ return torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
1185
+
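+ # Usage sketch for the module above (shapes per the forward docstring; values are placeholders):
+ #   >>> rot = NomicBertRotaryEmbedding(dim=64)
+ #   >>> qkv = torch.randn(2, 128, 3, 12, 64)   # (batch, seqlen, 3, nheads, headdim)
+ #   >>> rot(qkv).shape                         # q and k rotated, v passed through
+ #   torch.Size([2, 128, 3, 12, 64])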
1186
+
1187
+ class NomicBertDynamicNTKRotaryEmbedding(NomicBertRotaryEmbedding):
1188
+ def __init__(self, rotary_scaling_factor, max_position_embeddings, **kwargs):
1189
+ super().__init__(**kwargs)
1190
+ self.rotary_scaling_factor = rotary_scaling_factor
1191
+ self.max_position_embeddings = max_position_embeddings
1192
+
1193
+ def _compute_inv_freq(self, base=None, device=None):
1194
+ if base is None:
1195
+ base = self.base
1196
+ return 1.0 / (base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
1197
+
1198
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
1199
+ # Reset the tables if the sequence length has changed,
1200
+ # if we're on a new device (possibly due to tracing for instance),
1201
+ # or if we're switching from inference mode to training
1202
+ if seqlen > self.max_position_embeddings:
1203
+ base = self.base * (
1204
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
1205
+ ) ** (self.dim / (self.dim - 2))
1206
+ inv_freq = self._compute_inv_freq(base=base, device=device)
1207
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
1208
+
1209
+ if (
1210
+ seqlen > self._seq_len_cached
1211
+ or self._cos_cached is None
1212
+ or self._cos_cached.device != device
1213
+ or self._cos_cached.dtype != dtype
1214
+ or (self.training and self._cos_cached.is_inference())
1215
+ ):
1216
+ self._seq_len_cached = seqlen
1217
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
1218
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
1219
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
1220
+ if self.pos_idx_in_fp32:
1221
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
1222
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
1223
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
1224
+ # cos & sin output to change significantly.
1225
+ # We want to recompute self.inv_freq if it was not loaded in fp32
1226
+ if self.inv_freq.dtype != torch.float32:
1227
+ if seqlen > self.max_position_embeddings:
1228
+ base = self.base * (
1229
+ (self.rotary_scaling_factor * seqlen / self.max_position_embeddings) - (self.rotary_scaling_factor - 1)
1230
+ ) ** (self.dim / (self.dim - 2))
1231
+ else:
1232
+ base = self.base
1233
+ inv_freq = self._compute_inv_freq(device=device, base=base)
1234
+ else:
1235
+ inv_freq = self.inv_freq
1236
+ else:
1237
+ t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
1238
+ inv_freq = self.inv_freq
1239
+ # Don't do einsum, it converts fp32 to fp16 under AMP
1240
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
1241
+ freqs = torch.outer(t, inv_freq)
1242
+ if self.scale is None:
1243
+ self._cos_cached = torch.cos(freqs).to(dtype)
1244
+ self._sin_cached = torch.sin(freqs).to(dtype)
1245
+ else:
1246
+ power = (
1247
+ torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
1248
+ ) / self.scale_base
1249
+ scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
1250
+ # We want the multiplication by scale to happen in fp32
1251
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
1252
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
1253
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
1254
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
1255
+
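+ # The dynamic NTK rule above as a formula: once seqlen exceeds max_position_embeddings,
+ # the rotary base grows so the longest wavelength stretches to the new length:
+ #   base' = base * (factor * seqlen / max_pos - (factor - 1)) ** (dim / (dim - 2))
+ # e.g. with base=10000, factor=2, max_pos=2048, dim=64, seqlen=4096 (assumed values):
+ #   base' = 10000 * 3 ** (64 / 62) ~= 3.1e4, i.e. lower frequencies across the board.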
1256
+
1257
+ class NomicBertAttention(nn.Module):
1258
+ """Multi-head self-attention and cross-attention"""
1259
+
1260
+ def __init__(
1261
+ self,
1262
+ config,
1263
+ ) -> None:
1264
+ """
1265
+ num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
1266
+ return_residual: whether to return the input x along with the output. This is for
1267
+ performance reason: for post-norm architecture, returning the input allows us
1268
+ to fuse the backward of nn.Linear with the residual connection.
1269
+ """
1270
+ super().__init__()
1271
+ self.embed_dim = config.n_embd
1272
+ self.use_flash_attn = config.use_flash_attn
1273
+ self.fused_bias_fc = config.fused_bias_fc
1274
+
1275
+ self.num_heads = config.n_head
1276
+ self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
1277
+ assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
1278
+ self.head_dim = self.embed_dim // self.num_heads
1279
+ # we don't really support mqa / gqa for now
1280
+ qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
1281
+
1282
+ self.register_buffer(
1283
+ "norm_factor",
1284
+ torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
1285
+ persistent=False,
1286
+ )
1287
+
1288
+ self.rotary_emb_dim = self.head_dim * config.rotary_emb_fraction
1289
+ if self.rotary_emb_dim > 0:
1290
+ if getattr(config, "rotary_scaling_factor", None):
1291
+ self.rotary_emb = NomicBertDynamicNTKRotaryEmbedding(
1292
+ dim=self.rotary_emb_dim,
1293
+ base=config.rotary_emb_base,
1294
+ scale_base=config.rotary_emb_scale_base,
1295
+ interleaved=config.rotary_emb_interleaved,
1296
+ rotary_scaling_factor=config.rotary_scaling_factor,
1297
+ max_position_embeddings=config.max_trained_positions,
1298
+ )
1299
+ else:
1300
+ self.rotary_emb = NomicBertRotaryEmbedding(
1301
+ dim=self.rotary_emb_dim,
1302
+ base=config.rotary_emb_base,
1303
+ scale_base=config.rotary_emb_scale_base,
1304
+ interleaved=config.rotary_emb_interleaved,
1305
+ )
1306
+ # bug in xformers: https://github.com/facebookresearch/xformers/issues/841
1307
+ # uses the head dimension instead of the sequence dimension
1308
+ self.rotary_head_dim = getattr(config, "rotary_head_dim", False)
1309
+
1310
+ self.Wqkv = nn.Linear(self.embed_dim, qkv_dim, bias=config.qkv_proj_bias)
1311
+
1312
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
1313
+ self.causal = config.causal
1314
+ self.drop = nn.Dropout(config.attn_pdrop)
1315
+ self.num_prefix_tokens = max(getattr(config, "register_tokens", 1), 1)
1316
+
1317
+ def forward(
1318
+ self,
1319
+ hidden_states: torch.Tensor,
1320
+ attention_mask: Optional[torch.Tensor] = None,
1321
+ position_ids: Optional[torch.LongTensor] = None,
1322
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
1323
+ output_attentions: bool = False,
1324
+ use_cache: bool = False,
1325
+ is_padded_inputs: Optional[bool] = True,
1326
+ cu_seqlens: Optional[torch.Tensor] = None,
1327
+ max_seq_len: Optional[int] = None,
1328
+ rope: Optional[torch.Tensor] = None,
1329
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
1330
+
1331
+ has_layer_past = past_key_value is not None
1332
+
1333
+ if has_layer_past:
1334
+ past_key_value, past_len = past_key_value
1336
+ else:
1337
+ past_len = 0
1338
+
1339
+ qkv = self.Wqkv(hidden_states)
1340
+ qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
1341
+
1342
+ past_key_value = (past_key_value, past_len + qkv.size(1)) if use_cache else None
1343
+
1344
+ if self.rotary_emb_dim > 0:
1345
+ if self.rotary_head_dim:
1346
+ qkv = rearrange(qkv, "b s three h d -> b h three s d")
1347
+ qkv = self.rotary_emb(qkv, seqlen_offset=past_len)
1348
+
1349
+ if self.rotary_head_dim:
1350
+ qkv = rearrange(qkv, "b h three s d -> b s three h d")
1351
+ elif rope is not None:
1352
+ q, k, v = qkv.permute(0, 3, 1, 2, 4).unbind(dim=-2)
1353
+ q = torch.cat([q[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(q[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1354
+ k = torch.cat([k[:, :, :self.num_prefix_tokens], apply_rot_embed_cat(k[:, :, self.num_prefix_tokens:], rope)], dim=2).type_as(q)
1355
+
1356
+ qkv = torch.stack([q, k, v], dim=-2)
1357
+ qkv = rearrange(qkv, "b h s three d -> b s three h d")
1358
+
1359
+ query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
1360
+
1361
+ query = query.permute(0, 2, 1, 3)
1362
+ key = key.permute(0, 2, 1, 3)
1363
+ value = value.permute(0, 2, 1, 3)
1364
+
1365
+         attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
+         if attention_mask is not None:
+             attention_scores = attention_scores + attention_mask
+
+         attentions_probs = F.softmax(attention_scores, dim=-1)
+         attentions_probs = self.drop(attentions_probs)
+
+         attn_output = torch.matmul(attentions_probs, value)
+         attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class NomicBertBlock(NomicBertPreTrainedModel):
+     def __init__(
+         self,
+         config,
+     ):
+         super().__init__(config=config)
+         self.prenorm = config.prenorm
+         self.fused_dropout_add_ln = config.fused_dropout_add_ln
+
+         self.attn = NomicBertAttention(config)
+         activation = (
+             F.sigmoid
+             if config.activation_function == "glu"
+             else (F.silu if config.activation_function == "swiglu" else F.gelu)
+         )
+         if config.activation_function in ["glu", "swiglu", "geglu"]:
+             self.mlp = NomciBertGatedMLP(
+                 config.n_embd,
+                 hidden_features=config.n_inner,
+                 bias1=config.mlp_fc1_bias,
+                 bias2=config.mlp_fc2_bias,
+                 activation=activation,
+                 fused_bias_fc=config.fused_bias_fc,
+                 norm_layer=getattr(config, "norm_mlp", False),
+             )
+         else:
+             self.mlp = NomicBertMLP(
+                 config.n_embd,
+                 hidden_features=config.n_inner,
+                 bias1=config.mlp_fc1_bias,
+                 bias2=config.mlp_fc2_bias,
+                 activation=activation,
+                 fused_bias_fc=config.fused_bias_fc,
+             )
+
+         self.dropout1 = nn.Dropout(config.resid_pdrop)
+         self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.norm2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.dropout2 = nn.Dropout(config.resid_pdrop)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         hidden_states2: torch.Tensor,
+         residual: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         is_padded_inputs: Optional[bool] = True,
+         output_attentions: Optional[bool] = False,
+         use_cache: Optional[bool] = False,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         max_seq_len: Optional[int] = None,
+         rope: Optional[torch.Tensor] = None,
+     ):
+         r"""Pass the input through the encoder layer.
+
+         Args:
+             hidden_states: the sequence to the encoder layer (required).
+             residual: if postnorm, residual=None. If prenorm, hidden_states = Attn/MLP(LN(residual)).
+             mixer_subset: for cross-attention only. If not None, will take a subset of x
+                 before applying the query projection. Useful for e.g., ViT where we only care
+                 about the CLS token in the last layer.
+         """
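+         # prenorm: LayerNorm feeds attn/MLP while an explicit residual stream is
+         # threaded through; postnorm: LayerNorm is applied after each residual add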
+         if self.prenorm:
+             dropped = self.dropout1(hidden_states)
+             residual = (dropped + residual) if residual is not None else dropped
+             hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+             hidden_states = self.attn(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 is_padded_inputs=is_padded_inputs,
+                 cu_seqlens=cu_seqlens,
+                 max_seq_len=max_seq_len,
+                 rope=rope,
+             )
+
+             dropped = self.dropout2(hidden_states)
+             residual = (dropped + residual) if residual is not None else dropped
+             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+             hidden_states = self.mlp(hidden_states)
+
+             return hidden_states, None, residual
+         else:
+             assert residual is None
+             attn_outputs = self.attn(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 is_padded_inputs=is_padded_inputs,
+                 cu_seqlens=cu_seqlens,
+                 max_seq_len=max_seq_len,
+                 rope=rope,
+             )
+             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
+             mlp_out = self.mlp(hidden_states)
+
+             hidden_states = self.norm2((self.dropout2(mlp_out) + hidden_states).to(dtype=self.norm2.weight.dtype))
+             return hidden_states, None, None
+
+
+ class NomicBertEncoder(nn.Module):
+     def __init__(self, config: GPT2Config):
+         super().__init__()
+         self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
+         self.gradient_checkpointing = False
+         self.config = config
+
+     def forward(
+         self,
+         hidden_states: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         is_padded_inputs: Optional[bool] = True,
+         rope: Optional[torch.Tensor] = None,
+     ):
+         """If subset_mask is not None, we only want output for the subset of the sequence.
+         This means that we only compute the last layer output for these tokens.
+         subset_mask: (batch, seqlen), dtype=torch.bool
+         """
+         hidden_states2 = None
+         residual = None
+
+         for _, layer in enumerate(self.layers):
+             if self.gradient_checkpointing and self.training:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         # None for past_key_value
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 hidden_states, hidden_states2, residual = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(layer),
+                     hidden_states,
+                     hidden_states2,
+                     residual,
+                     attention_mask,
+                     position_ids,
+                     past_key_values,
+                     is_padded_inputs,
+                     output_attentions,
+                     use_cache,
+                     None,
+                     None,
+                     rope,
+                     # if you freeze ANY layers, you need `use_reentrant=False`
+                     # https://github.com/huggingface/transformers/issues/21381
+                     # https://discuss.pytorch.org/t/checkpoint-with-no-grad-requiring-inputs-problem/19117/7
+                     use_reentrant=False,
+                 )
+
+             else:
+                 hidden_states, hidden_states2, residual = layer(
+                     hidden_states,
+                     hidden_states2,
+                     residual,
+                     attention_mask,
+                     position_ids,
+                     None,
+                     is_padded_inputs,
+                     output_attentions,
+                     use_cache,
+                     rope=rope,
+                 )
+         return hidden_states
+
+
+ class NomicBertPooler(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.n_embd, config.n_embd)
+         self.activation = nn.Tanh()
+
+     def forward(self, hidden_states, pool=True):
+         # We "pool" the model by simply taking the hidden state corresponding
+         # to the first token.
+         first_token_tensor = hidden_states[:, 0] if pool else hidden_states
+         pooled_output = self.dense(first_token_tensor)
+         pooled_output = self.activation(pooled_output)
+         return pooled_output
+
+
+ class NomicBertPredictionHeadTransform(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.n_embd, config.n_embd, bias=config.mlp_fc1_bias)
+         approximate = "tanh" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_pytorch_tanh"] else "none"
+         if config.activation_function == "swiglu":
+             self.transform_act_fn = F.silu
+         else:
+             self.transform_act_fn = nn.GELU(approximate=approximate)
+
+         self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.layer_norm(hidden_states)
+
+         return hidden_states
+
+
+ class NomicBertLMPredictionHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+
+         self.transform = NomicBertPredictionHeadTransform(config)
+
+         self.decoder = nn.Linear(config.n_embd, config.vocab_size, bias=config.mlp_fc1_bias)
+
+     def forward(self, hidden_states):
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ class NomicBertPreTrainingHeads(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.predictions = NomicBertLMPredictionHead(config)
+
+     def forward(self, sequence_output):
+         prediction_scores = self.predictions(sequence_output)
+         return prediction_scores
+
+
+ class NomicBertModel(NomicBertPreTrainedModel):
+     def __init__(self, config: GPT2Config, add_pooling_layer=True):
+         super().__init__(config)
+         self.pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
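+         # round vocab_size up to a multiple of pad_vocab_size_multiple so the
+         # embedding and LM-head matrices map onto efficient GPU kernel shapes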
+         if config.vocab_size % self.pad_vocab_size_multiple != 0:
+             config.vocab_size += self.pad_vocab_size_multiple - (config.vocab_size % self.pad_vocab_size_multiple)
+
+         assert config.activation_function in [
+             "gelu",
+             "gelu_new",
+             "gelu_fast",
+             "gelu_pytorch_tanh",
+             "swiglu",
+             "geglu",
+             "glu",
+         ]
+
+         self.embeddings = NomicBertEmbeddings(config)
+         self.emb_drop = nn.Dropout(config.resid_pdrop)
+         self.emb_ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.encoder = NomicBertEncoder(config)
+         self.pooler = NomicBertPooler(config) if add_pooling_layer else None
+
+         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         return_dict=None,
+         matryoshka_dim=None,
+     ):
+         if token_type_ids is None:
+             token_type_ids = torch.zeros_like(input_ids)
+         hidden_states = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
+         hidden_states = self.emb_ln(hidden_states)
+         hidden_states = self.emb_drop(hidden_states)
+
+         attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.shape)
+         sequence_output = self.encoder(hidden_states, attention_mask=attention_mask, return_dict=return_dict)
+
+         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
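+         # Matryoshka-style truncation: keeping only the first matryoshka_dim
+         # features yields a smaller embedding the model was trained to keep useful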
+         if matryoshka_dim:
+             sequence_output = sequence_output[:, :matryoshka_dim]
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+         )
+
+
+ class NomicBertForPreTraining(NomicBertPreTrainedModel):
+     _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+     def __init__(self, config: GPT2Config):
+         super().__init__(config)
+
+         self.bert = NomicBertModel(config, add_pooling_layer=getattr(config, "add_pooling_layer", False))
+         self.cls = NomicBertPreTrainingHeads(config)
+         self.mlm_loss = nn.CrossEntropyLoss()
+
+         # Initialize weights and apply final processing
+         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+         self.tie_weights()
+
+     def tie_weights(self):
+         self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight
+
+     def forward(
+         self,
+         input_ids,
+         position_ids=None,
+         token_type_ids=None,
+         attention_mask=None,
+         labels=None,
+     ):
+         """
+         If `labels` are provided, they must be -100 for masked-out tokens (as specified in the
+         attention mask).
+         Outputs:
+             if `labels` is not `None`:
+                 Outputs the masked language modeling loss together with the logits.
+             if `labels` is `None`:
+                 Outputs the masked language modeling logits of shape
+                 [batch_size, sequence_length, vocab_size].
+         """
+         outputs = self.bert(
+             input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             attention_mask=attention_mask.bool() if attention_mask is not None else None,
+         )
+         sequence_output, _ = outputs.last_hidden_state, outputs.pooler_output
+
+         prediction_scores = self.cls(sequence_output)
+
+         total_loss = None
+         if labels is not None:
+             masked_lm_loss = self.mlm_loss(
+                 rearrange(prediction_scores, "... v -> (...) v"),
+                 rearrange(labels, "... -> (...)"),
+             )
+             total_loss = masked_lm_loss.float()
+
+         return MaskedLMOutput(
+             loss=total_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=None,
+         )
+
+
+
1731
+ class NomicBertForSequenceClassification(NomicBertPreTrainedModel):
1732
+ def __init__(self, config):
1733
+ super().__init__(config)
1734
+ self.num_labels = config.num_labels
1735
+ self.config = config
1736
+
1737
+ self.bert = NomicBertModel(config)
1738
+ classifier_dropout = getattr(config, "classifier_dropout", config.embd_pdrop)
1739
+ self.dropout = nn.Dropout(classifier_dropout)
1740
+ self.classifier = nn.Linear(config.n_embd, config.num_labels)
1741
+
1742
+ # Initialize weights and apply final processing
1743
+ self.post_init()
1744
+
1745
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         token_type_ids: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.Tensor] = None,
+         head_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         labels: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+             config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+             `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         outputs = self.bert(
+             input_ids,
+             position_ids=position_ids,
+             token_type_ids=token_type_ids,
+             attention_mask=attention_mask.bool() if attention_mask is not None else None,
+         )
+
+         pooled_output = outputs[1]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+
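+         # infer problem_type on first use: one label -> regression (MSE),
+         # integer labels -> single-label cross-entropy, otherwise multi-label BCE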
+         loss = None
+         if labels is not None:
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = nn.MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = nn.CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = nn.BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ def hf_vit_config_to_vit_config(vit_config: ViTConfig) -> GPT2Config:
+     return GPT2Config(
+         n_embd=vit_config.hidden_size,
+         n_layer=vit_config.num_hidden_layers,
+         n_head=vit_config.num_attention_heads,
+         n_inner=vit_config.intermediate_size,
+         activation_function=vit_config.hidden_act,
+         vocab_size=0,  # no vocab since using patches
+         n_positions=0,  # no absolute position embedding
+         resid_pdrop=0.0,  # no dropout
+         embd_pdrop=getattr(vit_config, "dropout", 0.0),
+         attn_pdrop=vit_config.attention_probs_dropout_prob,
+         layer_norm_epsilon=vit_config.layer_norm_eps,
+         initializer_range=vit_config.initializer_range,
+         bos_token_id=None,
+         eos_token_id=None,
+         # These are new arguments not in the original GPT2Config
+         drop_path_rate=0.0,
+         # Why is there double layer norm??
+         prepre_layernom=False,
+         layer_scale=False,
+         layer_scale_init=None,
+         img_size=vit_config.image_size,
+         patch_size=vit_config.patch_size,
+         num_channels=vit_config.num_channels,
+         prenorm=True,
+         parallel_block=False,
+         parallel_block_tied_norm=False,
+         rotary_emb_fraction=0,
+         tie_word_embeddings=False,
+         fused_dropout_add_ln=True,
+         fused_bias_fc=True,
+         patch_embed_bias=True,
+         use_flash_attn=True,
+         qkv_proj_bias=True,
+         mlp_fc1_bias=getattr(vit_config, "mlp_fc1_bias", True),
+         mlp_fc2_bias=getattr(vit_config, "mlp_fc2_bias", True),
+         use_rms_norm=False,
+         causal=False,
+         hidden_features_scaling_factor=1.0,
+         mask_token=False,
+         learned_pos_embedding=False,
+         patch_dropout=0,
+         sinusoidal_pos_embedding=vit_config.model_type == "vit_mae",
+     )
+
+
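+ # Attention pooling with a single learned latent query: the latent attends over
+ # all token embeddings and produces one pooled vector per sequence.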
+ class NomicAttentionPooling(nn.Module):
+     def __init__(
+         self,
+         config,
+     ):
+         super().__init__()
+         self.embed_dim = config.n_embd
+         self.use_flash_attn = config.use_flash_attn
+         self.fused_bias_fc = config.fused_bias_fc
+
+         self.num_heads = config.n_head
+         self.num_heads_kv = config.num_heads_kv if getattr(config, "num_heads_kv", None) is not None else self.num_heads
+         assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+         self.head_dim = self.embed_dim // self.num_heads
+         # we don't really support mqa / gqa for now
+         kv_dim = 2 * self.head_dim * self.num_heads_kv
+
+         self.register_buffer(
+             "norm_factor",
+             torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()),
+             persistent=False,
+         )
+
+         self.Wq = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
+         self.Wkv = nn.Linear(self.embed_dim, kv_dim, bias=config.qkv_proj_bias)
+
+         self.latent = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+
+         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.qkv_proj_bias)
+         self.causal = config.causal
+         self.drop = nn.Dropout(config.attn_pdrop)
+
+     def init_weights(self):
+         trunc_normal_tf_(self.latent, std=self.embed_dim ** -0.5)
+
+     def forward(
+         self,
+         kv,
+         attention_mask=None,
+         cu_seqlens_k=None,
+         max_seqlen_k=None,
+         is_padded_inputs: Optional[bool] = True,
+         output_attentions: bool = False,
+     ):
+         """Implements multihead softmax attention with a learned latent query.
+         Arguments
+         ---------
+         kv: The tensor containing the keys and values. (B, Sk, D)
+         attention_mask: additive mask, broadcastable to the attention scores.
+         cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+             of the sequences in the batch, used to index into kv.
+         max_seqlen_k: int. Maximum sequence length in the batch of k and v.
+         """
+         q_latent = self.latent.expand(kv.size(0), -1, -1)
+         q = self.Wq(q_latent)
+         bsz, q_len, h_size = q.shape
+         kv = self.Wkv(kv)
+         query = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+         kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
+
+         key, value = kv[:, :, 0], kv[:, :, 1]
+
+         query = query.permute(0, 2, 1, 3)
+         key = key.permute(0, 2, 1, 3)
+         value = value.permute(0, 2, 1, 3)
+
+         attention_scores = torch.matmul(query, key.transpose(-1, -2)) / self.norm_factor
+         if attention_mask is not None:
+             attention_scores = attention_scores + attention_mask
+
+         attentions_probs = F.softmax(attention_scores, dim=-1)
+         attentions_probs = self.drop(attentions_probs)
+
+         attn_output = torch.matmul(attentions_probs, value)
+         attn_output = rearrange(attn_output.permute(0, 2, 1, 3), "... h d -> ... (h d)")
+
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class NomicMultiHeadAttentionPooling(nn.Module):
+     def __init__(
+         self,
+         config,
+     ):
+         super().__init__()
+         self.prenorm = config.prenorm
+         self.fused_dropout_add_ln = config.fused_dropout_add_ln
+
+         self.attn = NomicAttentionPooling(config)
+         activation = (
+             F.sigmoid
+             if config.activation_function == "glu"
+             else (F.silu if config.activation_function == "swiglu" else F.gelu)
+         )
+         if config.activation_function in ["glu", "swiglu", "geglu"]:
+             self.mlp = NomciBertGatedMLP(
+                 config.n_embd,
+                 hidden_features=config.n_inner,
+                 bias1=config.mlp_fc1_bias,
+                 bias2=config.mlp_fc2_bias,
+                 activation=activation,
+                 fused_bias_fc=config.fused_bias_fc,
+             )
+         else:
+             self.mlp = NomicBertMLP(
+                 config.n_embd,
+                 hidden_features=config.n_inner,
+                 bias1=config.mlp_fc1_bias,
+                 bias2=config.mlp_fc2_bias,
+                 activation=activation,
+                 fused_bias_fc=config.fused_bias_fc,
+             )
+
+         self.dropout1 = nn.Dropout(config.resid_pdrop)
+         self.norm1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
+         self.dropout2 = nn.Dropout(config.resid_pdrop)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+     ):
+         r"""Pass the input through the attention-pooling layer.
+
+         Args:
+             hidden_states: the sequence of token embeddings to pool (required).
+             attention_mask: optional additive mask applied to the attention scores.
+         """
+
+         attn_outputs = self.attn(
+             hidden_states,
+             attention_mask=attention_mask,
+         )
+
+         normed = self.norm1(attn_outputs)
+         hidden_states = hidden_states + self.mlp(normed)
+
+         return hidden_states
+
+ class NomicVisionPreTrainedModel(PreTrainedModel):
+     """An abstract class to handle weights initialization and
+     a simple interface for downloading and loading pretrained models.
+     """
+
+     config_class = NomicBertConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["Block"]
+     _skip_keys_device_placement = "past_key_values"
+
+     def __init__(self, config, *inputs, **kwargs):
+         super().__init__(config)
+         if not isinstance(config, GPT2Config):
+             raise ValueError(
+                 "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
+                 "To create a model from a Google pretrained model use "
+                 "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+                     self.__class__.__name__, self.__class__.__name__
+                 )
+             )
+         self.config = config
+
+ class NomicVisionModel(NomicVisionPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.embeddings = NomicVisionPatchEmbeddings(config)
+         self.layers = nn.ModuleList([NomicBertBlock(config) for _ in range(config.n_layer)])
+
+         self.selector = NomicMultiHeadAttentionPooling(config)
+
+         self.global_pool = getattr(config, "global_pool", None)
+         self.num_prefix_tokens = (1 if not getattr(config, "no_cls_token", False) else 0) + getattr(config, "register_tokens", 0)
+
+         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+
+     def forward(
+         self,
+         pixel_values,
+         attention_mask=None,
+         position_ids=None,
+         token_type_ids=None,
+         return_dict=None,
+         matryoshka_dim=None,
+     ):
+         embeddings, rope = self.embeddings(pixel_values)
+
+         original_dtype = embeddings.dtype
+
+         hidden_states = embeddings
+         # unused, but easier to pass through gradient checkpointing as a positional arg
+         residual = None
+         for layer in self.layers:
+             # need to pass None for backward compatibility
+             hidden_states, _, residual = layer(hidden_states, None, residual=residual, is_padded_inputs=False, rope=rope)
+
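+         # add back the residual stream, then optionally mean-pool the patch
+         # tokens (prefix/CLS/register tokens are excluded from the average)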
+         hidden_states = hidden_states + residual
+         if self.global_pool == "avg":
+             hidden_states = hidden_states[:, self.num_prefix_tokens:].mean(dim=1)
+
+         pooled_output = self.selector(hidden_states)
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=pooled_output,
+             hidden_states=hidden_states,
+         )
modules.json ADDED
@@ -0,0 +1,14 @@
+ [
+   {
+     "idx": 0,
+     "name": "0",
+     "path": "",
+     "type": "sentence_transformers.models.Transformer"
+   },
+   {
+     "idx": 1,
+     "name": "1",
+     "path": "1_Pooling",
+     "type": "sentence_transformers.models.Pooling"
+   }
+ ]
sentence_bert_config.json ADDED
@@ -0,0 +1,4 @@
+ {
+   "max_seq_length": 8192,
+   "do_lower_case": false
+ }
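Together, modules.json and sentence_bert_config.json define the sentence-transformers pipeline for this model: a Transformer module followed by the pooling module under 1_Pooling, with inputs truncated at max_seq_length 8192. A minimal loading sketch (the repo id below is a placeholder, and trust_remote_code is required because the model ships its own NomicBert modeling code):

    from sentence_transformers import SentenceTransformer

    # placeholder repo id; substitute the actual model repository
    model = SentenceTransformer("<namespace>/<model-name>", trust_remote_code=True)
    embeddings = model.encode(["an example search query"])
    print(embeddings.shape)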
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+ {
+   "cls_token": {
+     "content": "[CLS]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "mask_token": {
+     "content": "[MASK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "[PAD]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "content": "[SEP]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "unk_token": {
+     "content": "[UNK]",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_lower_case": true,
+   "mask_token": "[MASK]",
+   "max_length": 8192,
+   "model_max_length": 8192,
+   "pad_to_multiple_of": null,
+   "pad_token": "[PAD]",
+   "pad_token_type_id": 0,
+   "padding_side": "right",
+   "sep_token": "[SEP]",
+   "stride": 0,
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "truncation_side": "right",
+   "truncation_strategy": "longest_first",
+   "unk_token": "[UNK]"
+ }
vocab.txt ADDED
The diff for this file is too large to render. See raw diff