Pringled committed
Commit df8f870 · verified · 1 Parent(s): 5cf2056

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +355 -46
README.md CHANGED
@@ -1,92 +1,401 @@
  ---
- library_name: model2vec
  license: mit
- model_name: static-coderankembed-potion-code-16m-contrastive
  tags:
  - embeddings
  - static-embeddings
- - sentence-transformers
  ---

- # static-coderankembed-potion-code-16m-contrastive Model Card

- This [Model2Vec](https://github.com/MinishLab/model2vec) model is a distilled version of a Sentence Transformer. It uses static embeddings, allowing text embeddings to be computed orders of magnitude faster on both GPU and CPU. It is designed for applications where computational resources are limited or where real-time performance is critical. Model2Vec models are the smallest, fastest, and most performant static embedders available. The distilled models are up to 50 times smaller and 500 times faster than traditional Sentence Transformers.


  ## Installation

- Install model2vec using pip:
- ```
  pip install model2vec
  ```

  ## Usage

- ### Using Model2Vec
-
- The [Model2Vec library](https://github.com/MinishLab/model2vec) is the fastest and most lightweight way to run Model2Vec models.
-
- Load this model using the `from_pretrained` method:
  ```python
  from model2vec import StaticModel

- # Load a pretrained Model2Vec model
- model = StaticModel.from_pretrained("static-coderankembed-potion-code-16m-contrastive")

- # Compute text embeddings
- embeddings = model.encode(["Example sentence"])
  ```

- ### Using Sentence Transformers

- You can also use the [Sentence Transformers library](https://github.com/UKPLab/sentence-transformers) to load and use the model:

- ```python
- from sentence_transformers import SentenceTransformer

- # Load a pretrained Sentence Transformer model
- model = SentenceTransformer("static-coderankembed-potion-code-16m-contrastive")

- # Compute text embeddings
- embeddings = model.encode(["Example sentence"])
- ```

- ### Distilling a Model2Vec model

- You can distill a Model2Vec model from a Sentence Transformer model using the `distill` method. First, install the `distill` extra with `pip install model2vec[distill]`. Then, run the following code:

  ```python
- from model2vec.distill import distill

- # Distill a Sentence Transformer model, in this case the BAAI/bge-base-en-v1.5 model
- m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", pca_dims=256)

- # Save the model
- m2v_model.save_pretrained("m2v_model")
- ```

- ## How it works

- Model2vec creates a small, fast, and powerful model that outperforms other static embedding models by a large margin on all tasks we could find, while being much faster to create than traditional static embedding models such as GloVe. Best of all, you don't need any data to distill a model using Model2Vec.

- It works by passing a vocabulary through a sentence transformer model, then reducing the dimensionality of the resulting embeddings using PCA, and finally weighting the embeddings using [SIF weighting](https://openreview.net/pdf?id=SyK00v5xx). During inference, we simply take the mean of all token embeddings occurring in a sentence.

- ## Additional Resources

- - [Model2Vec Repo](https://github.com/MinishLab/model2vec)
- - [Model2Vec Base Models](https://huggingface.co/collections/minishlab/model2vec-base-models-66fd9dd9b7c3b3c0f25ca90e)
- - [Model2Vec Results](https://github.com/MinishLab/model2vec/tree/main/results)
- - [Model2Vec Docs](https://minish.ai/packages/model2vec/introduction)


- ## Library Authors

- Model2Vec was developed by the [Minish Lab](https://github.com/MinishLab) team consisting of [Stephan Tulkens](https://github.com/stephantul) and [Thomas van Dongen](https://github.com/Pringled).

- ## Citation

- Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) if you use this model in your work.
  ```
  @software{minishlab2024model2vec,
  author = {Stephan Tulkens and {van Dongen}, Thomas},
  title = {Model2Vec: Fast State-of-the-Art Static Embeddings},
@@ -96,4 +405,4 @@ Please cite the [Model2Vec repository](https://github.com/MinishLab/model2vec) i
  url = {https://github.com/MinishLab/model2vec},
  license = {MIT}
  }
- ```
 
  ---
+ language:
+ - code
  license: mit
+ library_name: model2vec
  tags:
+ - model2vec
  - embeddings
+ - code
+ - retrieval
  - static-embeddings
  ---

+ # potion-code-16M Model Card
+
+ ## Overview

+ **potion-code-16M** is a fast static code embedding model optimized for code retrieval tasks. It is distilled from [nomic-ai/CodeRankEmbed](https://huggingface.co/nomic-ai/CodeRankEmbed) and trained on the [CornStack](https://huggingface.co/datasets/nomic-ai/cornstack-python-v1) code corpus using [Tokenlearn](https://github.com/MinishLab/tokenlearn) and contrastive fine-tuning.

+ It uses static embeddings, allowing text and code embeddings to be computed orders of magnitude faster than transformer-based models on both GPU and CPU.

  ## Installation

+ ```bash
  pip install model2vec
  ```

  ## Usage

  ```python
  from model2vec import StaticModel

+ model = StaticModel.from_pretrained("Pringled/potion-code-16M")
+
+ # Embed natural language queries
+ query_embeddings = model.encode(["How to read a file in Python?"])

+ # Embed code documents
+ code_embeddings = model.encode(["def read_file(path):\n    with open(path) as f:\n        return f.read()"])
  ```
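
For retrieval, query and code embeddings are usually compared with cosine similarity and the documents sorted by score. The following is a minimal sketch of that ranking step and not part of the Model2Vec API: `rank_by_cosine` is a hypothetical helper, and the small hand-written vectors are stand-ins for the arrays that `model.encode` would return.

```python
import numpy as np


def rank_by_cosine(query_embedding: np.ndarray, doc_embeddings: np.ndarray) -> np.ndarray:
    """Return document indices sorted from most to least similar to the query."""
    # Normalize to unit length so the dot product equals cosine similarity.
    q = query_embedding / np.linalg.norm(query_embedding)
    d = doc_embeddings / np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    scores = d @ q
    # Negate so that argsort yields descending similarity.
    return np.argsort(-scores)


# Stand-in vectors (assumption); real ones would come from model.encode(...).
query = np.array([1.0, 0.0, 0.0])
docs = np.array([
    [0.0, 1.0, 0.0],  # orthogonal to the query
    [2.0, 0.1, 0.0],  # nearly parallel to the query
    [1.0, 1.0, 0.0],  # 45 degrees from the query
])

order = rank_by_cosine(query, docs)
print(order)  # [1 2 0]
```

Normalizing first means vector length does not affect the ranking, so only direction matters.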

+ ## How it works

+ potion-code-16M is created using the following pipeline:

+ 1. **Vocabulary mining**: code-specific tokens are mined from CornStack and added to the base CodeRankEmbed tokenizer (42k extra tokens → ~62.5k total)
+ 2. **Distillation**: the extended vocabulary is distilled from CodeRankEmbed using Model2Vec (256-dimensional embeddings, PCA whitening)
+ 3. **Tokenlearn**: the distilled model is fine-tuned on 240k CornStack samples (20k documents and 20k queries for each of 6 languages) and their precomputed embeddings, using a cosine similarity loss
+ 4. **Contrastive fine-tuning**: the model is further fine-tuned using MultipleNegativesRankingLoss on 120k CornStack query-document pairs
+ 5. **Post-SIF re-regularization**: token weights are re-regularized using SIF weighting after each training stage
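
To make the SIF step in the pipeline more concrete, here is a minimal sketch of the generic SIF formula, where a token with probability `p(w)` gets weight `a / (a + p(w))`. The helper names are hypothetical, and the Zipf-shaped prior (probability proportional to 1/rank) is an assumption made for illustration. It is not confirmed to match what `post_process_embeddings` computes internally.

```python
import numpy as np


def zipf_probs(vocab_size: int) -> np.ndarray:
    """Assumed prior: token probability proportional to 1 / frequency rank."""
    inv_rank = 1.0 / np.arange(1, vocab_size + 1)
    return inv_rank / inv_rank.sum()


def sif_weights(token_probs: np.ndarray, sif_coefficient: float = 1e-4) -> np.ndarray:
    """Generic SIF weight a / (a + p(w)) for each token probability p(w)."""
    return sif_coefficient / (sif_coefficient + token_probs)


probs = zipf_probs(62_500)  # vocabulary size taken from the model card
weights = sif_weights(probs, sif_coefficient=1e-4)

# Frequent tokens (low rank) are down-weighted, rare tokens are kept near 1.
print(weights[0] < weights[-1])  # True

# Applying the weights scales each token's row of a (vocab, dims) matrix.
embeddings = np.ones((62_500, 4), dtype=np.float32)
weighted = embeddings * weights[:, None]
```

The effect is that very common tokens contribute less to the mean-pooled sentence vector than rare, more informative ones.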

+ ## Results

+ Results on the [CoIR benchmark](https://github.com/CoIR-team/coir) (NDCG@10, `mteb>=2.10`):
+
+ | Model | Params | AppsRetrieval | COIRCodeSearchNet | CodeFeedbackMT | CodeFeedbackST | CodeSearchNetCC | CodeTransContest | CodeTransDL | CosQA | StackOverflow | Text2SQL | **AVG** |
+ |---|---|---|---|---|---|---|---|---|---|---|---|---|
+ | CodeRankEmbed | 137M | - | - | - | - | - | - | - | - | - | - | - |
+ | BM25 | — | 4.76 | 32.45 | 59.69 | 67.85 | 33.00 | 47.29 | 32.97 | 15.53 | 69.54 | 28.07 | 39.11 |
+ | **potion-code-16M** | **16M** | **3.97** | **42.99** | **36.26** | **50.27** | **43.40** | **39.76** | **31.72** | **21.37** | **57.47** | **43.34** | **37.05** |
+
+ *Results for CodeRankEmbed coming soon.*
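
For readers unfamiliar with the metric, the sketch below shows how NDCG@10 is computed for a single query. It is an illustration under stated assumptions, not the `mteb` implementation: the function names are hypothetical, gains are linear in the relevance label, and the resulting 0 to 1 score would need to be multiplied by 100 to match the scale the table appears to use.

```python
import math


def dcg_at_k(relevances: list[float], k: int = 10) -> float:
    """Discounted cumulative gain over the first k results (linear gain)."""
    return sum(rel / math.log2(rank + 2) for rank, rel in enumerate(relevances[:k]))


def ndcg_at_k(ranked_relevances: list[float], all_relevances: list[float], k: int = 10) -> float:
    """DCG of the returned ranking divided by the DCG of the ideal ranking."""
    ideal_dcg = dcg_at_k(sorted(all_relevances, reverse=True), k)
    if ideal_dcg == 0:
        return 0.0
    return dcg_at_k(ranked_relevances, k) / ideal_dcg


# One relevant document exists, and the system returned it at position 3.
ranked = [0.0, 0.0, 1.0, 0.0, 0.0]
score = ndcg_at_k(ranked, all_relevances=[1.0], k=10)
print(score)  # 0.5
```

A benchmark score is then the mean of this value over all queries in a task.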
+
+ ## Model Details
+
+ | Property | Value |
+ |---|---|
+ | Parameters | ~16M |
+ | Embedding dimensions | 256 |
+ | Vocabulary size | ~62,500 |
+ | Teacher model | nomic-ai/CodeRankEmbed |
+ | Training corpus | CornStack (6 languages: Python, Java, JavaScript, Go, PHP, Ruby) |
+ | Max sequence length | 1,000,000 tokens (static, no limit in practice) |
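
As a quick sanity check on the table, the parameter count follows from the vocabulary size and the embedding width. This is a back-of-the-envelope sketch that assumes the embedding matrix accounts for essentially all parameters and is stored as float32; the vocabulary size is the approximate figure from the table, so the result is approximate as well.

```python
vocab_size = 62_500   # approximate vocabulary size from the table
embedding_dim = 256   # embedding dimensions from the table

# One float per (token, dimension) entry in the embedding matrix.
params = vocab_size * embedding_dim

# Assumption: 4 bytes per parameter (float32).
size_mb_fp32 = params * 4 / 1e6

print(params)        # 16000000
print(size_mb_fp32)  # 64.0
```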
+
+ ## Additional Resources
+
+ - [Model2Vec repository](https://github.com/MinishLab/model2vec)
+ - [Tokenlearn repository](https://github.com/MinishLab/tokenlearn)
+ - [CornStack dataset](https://huggingface.co/datasets/nomic-ai/cornstack-python-v1)
+ - [CoIR benchmark](https://github.com/CoIR-team/coir)

+ ## Reproducibility

+ The following script reproduces this model end-to-end. It requires the tokenlearn training data from `Pringled/cornstack-docs-tokenlearn` and `Pringled/cornstack-queries-tokenlearn` (20k samples per language are used from each dataset).

  ```python
+ """Reproduction script for potion-code-16M.

+ Runs the full pipeline: distill → tokenlearn → contrastive fine-tuning.

+ Requirements:
+     pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops

+ The three model checkpoints are saved to:
+     ./models/potion-code-16M-distilled
+     ./models/potion-code-16M-tokenlearn
+     ./models/potion-code-16M-contrastive  ← final model
+ """

+ from __future__ import annotations

+ import logging
+ import random

+ import numpy as np
+ import torch
+ from datasets import Dataset, concatenate_datasets, load_dataset
+ from huggingface_hub import snapshot_download
+ from model2vec import StaticModel
+ from model2vec.distill import distill_from_model
+ from model2vec.distill.inference import post_process_embeddings
+ from pathlib import Path
+ from sentence_transformers import (
+     SentenceTransformer,
+     SentenceTransformerTrainer,
+     SentenceTransformerTrainingArguments,
+ )
+ from sentence_transformers.losses import MultipleNegativesRankingLoss
+ from sentence_transformers.models import StaticEmbedding
+ from sentence_transformers.training_args import BatchSamplers
+ from skeletoken import TokenizerModel
+ from sklearn.decomposition import PCA
+ from tokenlearn.losses import Loss
+ from tokenlearn.model import StaticModelForFineTuning
+ from tokenlearn.utils import create_vocab
+ from transformers import AutoModel, AutoTokenizer

+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+ logger = logging.getLogger(__name__)

+ # ---------------------------------------------------------------------------
+ # Hyperparameters
+ # ---------------------------------------------------------------------------

+ TEACHER_MODEL = "nomic-ai/CodeRankEmbed"
+ OUTPUT_DIR = Path("models")

+ # Distill
+ VOCAB_SIZE = 42_000  # extra tokens mined from CornStack → ~62.5k total → ~16M params
+ PCA_DIMS = 256
+ SIF_COEFFICIENT = 1e-4

+ # Tokenlearn
+ TOKENLEARN_DOCS_DATASET = "Pringled/cornstack-docs-tokenlearn"
+ TOKENLEARN_QUERIES_DATASET = "Pringled/cornstack-queries-tokenlearn"
+ TOKENLEARN_LANGUAGES = ["go", "java", "javascript", "php", "python", "ruby"]
+ TOKENLEARN_MAX_PER_LANGUAGE = 20_000  # (20k docs + 20k queries) × 6 langs = 240k total
+ TOKENLEARN_LR = 1e-3
+ TOKENLEARN_MAX_EPOCHS = 20  # early stopping (patience=5) typically kicks in earlier
+ TOKENLEARN_BATCH_SIZE = 128
+
+ # Contrastive
+ CORNSTACK_DATASETS = {
+     "python": "nomic-ai/cornstack-python-v1",
+     "java": "nomic-ai/cornstack-java-v1",
+     "php": "nomic-ai/cornstack-php-v1",
+     "go": "nomic-ai/cornstack-go-v1",
+     "javascript": "nomic-ai/cornstack-javascript-v1",
+     "ruby": "nomic-ai/cornstack-ruby-v1",
+ }
+ CONTRASTIVE_MAX_PER_LANGUAGE = 20_000  # 20k × 6 langs = 120k pairs total
+ CONTRASTIVE_LR = 5e-3
+ CONTRASTIVE_EPOCHS = 3
+ CONTRASTIVE_BATCH_SIZE = 512
+ CONTRASTIVE_SEED = 42
+
+
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+
+ def apply_post_sif(model: StaticModel, pca_dims: int, sif_coefficient: float) -> StaticModel:
+     embeddings_np = model.embedding.astype(np.float32)
+     processed, weights = post_process_embeddings(
+         embeddings_np, pca_dims=pca_dims, sif_coefficient=sif_coefficient
+     )
+     logger.info("post_process_embeddings: %s → %s", embeddings_np.shape, processed.shape)
+     model.embedding = processed
+     model.weights = weights
+     return model
+
+
+ # ---------------------------------------------------------------------------
+ # Step 1: Distill
+ # ---------------------------------------------------------------------------
+
+ def run_distill(save_path: Path) -> None:
+     logger.info("Downloading %s ...", TEACHER_MODEL)
+     local_path = snapshot_download(TEACHER_MODEL)
+     model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True, use_fast=True)
+
+     # Load tokenlearn corpus texts for vocab mining (docs + queries, 20k/lang)
+     logger.info("Loading texts for vocabulary mining ...")
+     shards = []
+     for lang in TOKENLEARN_LANGUAGES:
+         docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         shards.extend([docs, queries])
+     corpus = concatenate_datasets(shards)
+     texts: list[str] = list(corpus["text"])
+     logger.info("Loaded %d texts for vocab mining.", len(texts))
+
+     logger.info("Mining vocabulary (target size=%d) ...", VOCAB_SIZE)
+     vocab = create_vocab(texts=texts, vocab_size=VOCAB_SIZE)
+     logger.info("Mined %d tokens.", len(vocab))
+
+     # Filter: keep only new single-token entries not already in CodeRankEmbed vocabulary.
+     tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer).prune_added_tokens()
+     preprocessor = tokenizer_model.preprocessor
+     seen = set(tokenizer_model.sorted_vocabulary)
+     filtered = []
+     for token in vocab:
+         preprocessed = preprocessor.preprocess(token)
+         if len(preprocessed) == 1 and preprocessed[0] not in seen:
+             seen.add(preprocessed[0])
+             filtered.append(preprocessed[0])
+     logger.info("Vocabulary after filtering: %d tokens added to CodeRankEmbed.", len(filtered))

+     # NomicBERT requires monkey-patched embedding accessors.
+     model.get_input_embeddings = lambda: model.embeddings.word_embeddings
+     model.set_input_embeddings = lambda v: setattr(model.embeddings, "word_embeddings", v)
+
+     logger.info("Distilling (pca_dims=%d, sif=%g) ...", PCA_DIMS, SIF_COEFFICIENT)
+     static_model = distill_from_model(
+         model=model,
+         tokenizer=tokenizer,
+         vocabulary=filtered,
+         pca_dims=PCA_DIMS,
+         sif_coefficient=SIF_COEFFICIENT,
+         pooling="mean",
+         quantize_to="float32",
+     )
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     static_model.save_pretrained(str(save_path))
+     logger.info("Distilled model saved to %s (vocab=%d, dims=%d)",
+                 save_path, static_model.embedding.shape[0], static_model.embedding.shape[1])
+
+
+ # ---------------------------------------------------------------------------
+ # Step 2: Tokenlearn
+ # ---------------------------------------------------------------------------
+
+ def run_tokenlearn(base_model_path: Path, save_path: Path) -> None:
+     # Load 20k docs + 20k queries per language → 240k total
+     logger.info("Loading tokenlearn data (docs + queries, %d/lang × %d langs) ...",
+                 TOKENLEARN_MAX_PER_LANGUAGE, len(TOKENLEARN_LANGUAGES))
+     shards = []
+     for lang in TOKENLEARN_LANGUAGES:
+         docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         shards.extend([docs, queries])
+     dataset = concatenate_datasets(shards)
+     logger.info("Total samples: %d", len(dataset))
+
+     train_txt: list[str] = list(dataset["text"])
+     train_vec = np.array(dataset["embedding"], dtype=np.float32)
+     non_nan_mask = ~np.isnan(train_vec).any(axis=1)
+     train_txt = np.array(train_txt)[non_nan_mask].tolist()
+     train_vec = train_vec[non_nan_mask]
+     logger.info("Loaded %d samples, raw vector shape: %s", len(train_txt), train_vec.shape)
+
+     logger.info("Fitting PCA to %d dims ...", PCA_DIMS)
+     pca = PCA(n_components=PCA_DIMS)
+     train_vec = pca.fit_transform(train_vec)
+     logger.info("Explained variance: %.4f. Shape: %s",
+                 pca.explained_variance_ratio_.cumsum()[-1], train_vec.shape)
+
+     logger.info("Loading base model from %s ...", base_model_path)
+     base_model = StaticModel.from_pretrained(str(base_model_path), force_download=False)
+     if base_model.embedding.dtype != np.float32:
+         base_model.embedding = base_model.embedding.astype(np.float32)
+
+     trainable = StaticModelForFineTuning.from_static_model(
+         model=base_model,
+         out_dim=PCA_DIMS,
+         loss=Loss("cosine"),
+     )
+     logger.info("Training tokenlearn (lr=%g, max_epochs=%d, batch=%d) ...",
+                 TOKENLEARN_LR, TOKENLEARN_MAX_EPOCHS, TOKENLEARN_BATCH_SIZE)
+     trainable.fit(
+         X=train_txt,
+         y=torch.from_numpy(train_vec.astype(np.float32)),
+         batch_size=TOKENLEARN_BATCH_SIZE,
+         learning_rate=TOKENLEARN_LR,
+         max_epochs=TOKENLEARN_MAX_EPOCHS,
+         early_stopping_patience=5,
+         use_wandb=False,
+     )
+     logger.info("Tokenlearn training complete.")
+
+     trained_model = trainable.to_static_model()
+     trained_model = apply_post_sif(trained_model, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     trained_model.save_pretrained(str(save_path))
+     logger.info("Tokenlearn model saved to %s", save_path)
+
+
+ # ---------------------------------------------------------------------------
+ # Step 3: Contrastive fine-tuning (MNRL)
+ # ---------------------------------------------------------------------------
+
+ def run_contrastive(base_model_path: Path, save_path: Path) -> None:
+     random.seed(CONTRASTIVE_SEED)
+
+     logger.info("Streaming CornStack pairs (%d/lang × %d langs) ...",
+                 CONTRASTIVE_MAX_PER_LANGUAGE, len(CORNSTACK_DATASETS))
+     all_queries: list[str] = []
+     all_docs: list[str] = []
+     for lang, hf_name in CORNSTACK_DATASETS.items():
+         hf_ds = load_dataset(hf_name, split="train", streaming=True)
+         hf_ds = hf_ds.shuffle(seed=CONTRASTIVE_SEED, buffer_size=10_000)
+         kept = 0
+         seen_q: set[str] = set()
+         seen_d: set[str] = set()
+         for row in hf_ds:
+             q, d = row.get("query"), row.get("document")
+             if not isinstance(q, str) or not isinstance(d, str):
+                 continue
+             if len(q) < 32 or len(d) < 32:
+                 continue
+             if q in seen_q or d in seen_d:
+                 continue
+             seen_q.add(q)
+             seen_d.add(d)
+             all_queries.append(q)
+             all_docs.append(d)
+             kept += 1
+             if kept >= CONTRASTIVE_MAX_PER_LANGUAGE:
+                 break
+         logger.info(" %s: %d pairs", lang, kept)
+
+     logger.info("Total pairs: %d", len(all_queries))
+     train_dataset = Dataset.from_dict({"anchor": all_queries, "positive": all_docs})
+
+     static_embedding = StaticEmbedding.from_model2vec(str(base_model_path))
+     model = SentenceTransformer(modules=[static_embedding])
+     loss = MultipleNegativesRankingLoss(model)
+
+     training_args = SentenceTransformerTrainingArguments(
+         output_dir=str(save_path) + "-checkpoints",
+         num_train_epochs=CONTRASTIVE_EPOCHS,
+         per_device_train_batch_size=CONTRASTIVE_BATCH_SIZE,
+         learning_rate=CONTRASTIVE_LR,
+         warmup_steps=0.1,
+         fp16=False,
+         bf16=False,
+         batch_sampler=BatchSamplers.NO_DUPLICATES,
+         save_strategy="no",
+         logging_steps=100,
+         logging_first_step=True,
+         report_to=[],
+     )
+     logger.info("Training contrastive (lr=%g, epochs=%d, batch=%d) ...",
+                 CONTRASTIVE_LR, CONTRASTIVE_EPOCHS, CONTRASTIVE_BATCH_SIZE)
+
+     trainer = SentenceTransformerTrainer(
+         model=model, args=training_args, train_dataset=train_dataset, loss=loss,
+     )
+     trainer.train()
+     logger.info("Contrastive training complete.")
+
+     base_m2v = StaticModel.from_pretrained(str(base_model_path), force_download=False)
+     base_m2v.embedding = model[0].embedding.weight.detach().cpu().float().numpy()
+
+     final_model = apply_post_sif(base_m2v, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     final_model.save_pretrained(str(save_path))
+     logger.info("Final model saved to %s", save_path)
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     distilled_path = OUTPUT_DIR / "potion-code-16M-distilled"
+     tokenlearn_path = OUTPUT_DIR / "potion-code-16M-tokenlearn"
+     contrastive_path = OUTPUT_DIR / "potion-code-16M-contrastive"
+
+     logger.info("=== Step 1/3: Distill ===")
+     run_distill(save_path=distilled_path)
+
+     logger.info("=== Step 2/3: Tokenlearn ===")
+     run_tokenlearn(base_model_path=distilled_path, save_path=tokenlearn_path)
+
+     logger.info("=== Step 3/3: Contrastive ===")
+     run_contrastive(base_model_path=tokenlearn_path, save_path=contrastive_path)
+
+     logger.info("Done. Final model: %s", contrastive_path)
  ```
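
The contrastive step in the script relies on `MultipleNegativesRankingLoss`, which treats every other document in a batch as a negative for a given query. The NumPy sketch below illustrates that idea under stated assumptions: `mnrl_loss` is a hypothetical name, and the cosine similarity with a scale of 20 mirrors the sentence-transformers defaults rather than anything set explicitly in the script. It is meant to show the shape of the objective, not to replace the library implementation.

```python
import numpy as np


def mnrl_loss(anchors: np.ndarray, positives: np.ndarray, scale: float = 20.0) -> float:
    """In-batch-negatives loss: row i of `positives` is the only match for anchor i."""
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    # (batch, batch) matrix of scaled cosine similarities.
    logits = scale * (a @ p.T)
    # Subtract the row maximum for a numerically stable log-softmax.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Cross-entropy with the diagonal (the true pairs) as targets.
    return float(-np.mean(np.diag(log_probs)))


queries = np.eye(3)
aligned_loss = mnrl_loss(queries, np.eye(3))                       # each query matches its document
shuffled_loss = mnrl_loss(queries, np.roll(np.eye(3), 1, axis=0))  # every pair is mismatched
print(aligned_loss < shuffled_loss)  # True
```

Because every other row in the batch acts as a negative, a larger batch gives each query more negatives to discriminate against, and duplicate texts inside a batch would act as false negatives.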
+
+ ## Citation
+
+ ```bibtex
  @software{minishlab2024model2vec,
  author = {Stephan Tulkens and {van Dongen}, Thomas},
  title = {Model2Vec: Fast State-of-the-Art Static Embeddings},

  url = {https://github.com/MinishLab/model2vec},
  license = {MIT}
  }
+ ```