Pringled committed on
Commit 49f83e2 · verified · 1 Parent(s): 643eff7

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +314 -0
README.md CHANGED
@@ -79,6 +79,320 @@ Results on the [CoIR benchmark](https://github.com/CoIR-team/coir) (NDCG@10, `mt
  - [CornStack dataset](https://huggingface.co/datasets/nomic-ai/cornstack-python-v1)
  - [CoIR benchmark](https://github.com/CoIR-team/coir)
 
+ ## Reproducibility
+
+ The following script reproduces this model end-to-end. It requires the tokenlearn training data from `Pringled/cornstack-docs-tokenlearn` and `Pringled/cornstack-queries-tokenlearn` (20k samples per language are used). A short usage sketch follows the script.
+
+ ```python
+ """Reproduction script for potion-code-16M.
+
+ Runs the full pipeline: distill → tokenlearn → contrastive fine-tuning.
+
+ Requirements:
+     pip install model2vec tokenlearn sentence-transformers datasets skeletoken einops scikit-learn
+
+ The three model checkpoints are saved to:
+     ./models/potion-code-16M-distilled
+     ./models/potion-code-16M-tokenlearn
+     ./models/potion-code-16M-contrastive  ← final model
+ """
+
+ from __future__ import annotations
+
+ import logging
+ import random
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ from datasets import Dataset, concatenate_datasets, load_dataset
+ from huggingface_hub import snapshot_download
+ from model2vec import StaticModel
+ from model2vec.distill import distill_from_model
+ from model2vec.distill.inference import post_process_embeddings
+ from sentence_transformers import (
+     SentenceTransformer,
+     SentenceTransformerTrainer,
+     SentenceTransformerTrainingArguments,
+ )
+ from sentence_transformers.losses import MultipleNegativesRankingLoss
+ from sentence_transformers.models import StaticEmbedding
+ from sentence_transformers.training_args import BatchSamplers
+ from skeletoken import TokenizerModel
+ from sklearn.decomposition import PCA
+ from tokenlearn.losses import Loss
+ from tokenlearn.model import StaticModelForFineTuning
+ from tokenlearn.utils import create_vocab
+ from transformers import AutoModel, AutoTokenizer
+
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
+ logger = logging.getLogger(__name__)
+
+ # ---------------------------------------------------------------------------
+ # Hyperparameters
+ # ---------------------------------------------------------------------------
+
+ TEACHER_MODEL = "nomic-ai/CodeRankEmbed"
+ OUTPUT_DIR = Path("models")
+
+ # Distill
+ VOCAB_SIZE = 42_000  # extra tokens mined from CornStack → ~62.5k total → ~16M params
+ PCA_DIMS = 256
+ SIF_COEFFICIENT = 1e-4
+
+ # Tokenlearn
+ TOKENLEARN_DOCS_DATASET = "Pringled/cornstack-docs-tokenlearn"
+ TOKENLEARN_QUERIES_DATASET = "Pringled/cornstack-queries-tokenlearn"
+ TOKENLEARN_LANGUAGES = ["go", "java", "javascript", "php", "python", "ruby"]
+ TOKENLEARN_MAX_PER_LANGUAGE = 20_000  # 20k docs + 20k queries × 6 langs = 240k total
+ TOKENLEARN_LR = 1e-3
+ TOKENLEARN_MAX_EPOCHS = 20  # early stopping (patience=5) typically kicks in earlier
+ TOKENLEARN_BATCH_SIZE = 128
+
+ # Contrastive
+ CORNSTACK_DATASETS = {
+     "python": "nomic-ai/cornstack-python-v1",
+     "java": "nomic-ai/cornstack-java-v1",
+     "php": "nomic-ai/cornstack-php-v1",
+     "go": "nomic-ai/cornstack-go-v1",
+     "javascript": "nomic-ai/cornstack-javascript-v1",
+     "ruby": "nomic-ai/cornstack-ruby-v1",
+ }
+ CONTRASTIVE_MAX_PER_LANGUAGE = 20_000  # 20k × 6 langs = 120k pairs total
+ CONTRASTIVE_LR = 5e-3
+ CONTRASTIVE_EPOCHS = 3
+ CONTRASTIVE_BATCH_SIZE = 512
+ CONTRASTIVE_SEED = 42
+
+
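+ # Size sanity check: ~62.5k tokens × 256 dims ≈ 16.0M embedding parameters,
+ # which is where the "16M" in the model name comes from.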
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+
+ def apply_post_sif(model: StaticModel, pca_dims: int, sif_coefficient: float) -> StaticModel:
+     embeddings_np = model.embedding.astype(np.float32)
+     processed, weights = post_process_embeddings(
+         embeddings_np, pca_dims=pca_dims, sif_coefficient=sif_coefficient
+     )
+     logger.info("post_process_embeddings: %s → %s", embeddings_np.shape, processed.shape)
+     model.embedding = processed
+     model.weights = weights
+     return model
+
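+ # Explanatory note (our reading of model2vec, not part of the original script):
+ # post_process_embeddings re-applies PCA and SIF-style frequency re-weighting
+ # (roughly w ∝ a / (a + p(token)) with a = SIF_COEFFICIENT), so running this
+ # helper after each training stage keeps the checkpoints directly comparable.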
+
+ # ---------------------------------------------------------------------------
+ # Step 1: Distill
+ # ---------------------------------------------------------------------------
+
+ def run_distill(save_path: Path) -> None:
+     logger.info("Downloading %s ...", TEACHER_MODEL)
+     local_path = snapshot_download(TEACHER_MODEL)
+     model = AutoModel.from_pretrained(local_path, trust_remote_code=True)
+     tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True, use_fast=True)
+
+     # Load tokenlearn corpus texts for vocab mining (docs + queries, 20k/lang)
+     logger.info("Loading texts for vocabulary mining ...")
+     shards = []
+     for lang in TOKENLEARN_LANGUAGES:
+         docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         shards.extend([docs, queries])
+     corpus = concatenate_datasets(shards)
+     texts: list[str] = list(corpus["text"])
+     logger.info("Loaded %d texts for vocab mining.", len(texts))
+
+     logger.info("Mining vocabulary (target size=%d) ...", VOCAB_SIZE)
+     vocab = create_vocab(texts=texts, vocab_size=VOCAB_SIZE)
+     logger.info("Mined %d tokens.", len(vocab))
+
+     # Filter: keep only new single-token entries not already in CodeRankEmbed vocabulary.
+     tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer).prune_added_tokens()
+     preprocessor = tokenizer_model.preprocessor
+     seen = set(tokenizer_model.sorted_vocabulary)
+     filtered = []
+     for token in vocab:
+         preprocessed = preprocessor.preprocess(token)
+         if len(preprocessed) == 1 and preprocessed[0] not in seen:
+             seen.add(preprocessed[0])
+             filtered.append(preprocessed[0])
+     logger.info("Vocabulary after filtering: %d tokens added to CodeRankEmbed.", len(filtered))
+
+     # NomicBERT requires monkey-patched embedding accessors.
+     model.get_input_embeddings = lambda: model.embeddings.word_embeddings
+     model.set_input_embeddings = lambda v: setattr(model.embeddings, "word_embeddings", v)
+
+     logger.info("Distilling (pca_dims=%d, sif=%g) ...", PCA_DIMS, SIF_COEFFICIENT)
+     static_model = distill_from_model(
+         model=model,
+         tokenizer=tokenizer,
+         vocabulary=filtered,
+         pca_dims=PCA_DIMS,
+         sif_coefficient=SIF_COEFFICIENT,
+         pooling="mean",
+         quantize_to="float32",
+     )
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     static_model.save_pretrained(str(save_path))
+     logger.info("Distilled model saved to %s (vocab=%d, dims=%d)",
+                 save_path, static_model.embedding.shape[0], static_model.embedding.shape[1])
+
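+ # How the distillation step works (a summary of model2vec's distill_from_model):
+ # each vocabulary token is passed through the teacher once and mean-pooled,
+ # yielding one static vector per token. Encoding at inference time is then
+ # just an embedding lookup plus a mean, which is what makes the student fast.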
+
+ # ---------------------------------------------------------------------------
+ # Step 2: Tokenlearn
+ # ---------------------------------------------------------------------------
+
+ def run_tokenlearn(base_model_path: Path, save_path: Path) -> None:
+     # Load 20k docs + 20k queries per language → 240k total
+     logger.info("Loading tokenlearn data (docs + queries, %d/lang × %d langs) ...",
+                 TOKENLEARN_MAX_PER_LANGUAGE, len(TOKENLEARN_LANGUAGES))
+     shards = []
+     for lang in TOKENLEARN_LANGUAGES:
+         docs = load_dataset(TOKENLEARN_DOCS_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         queries = load_dataset(TOKENLEARN_QUERIES_DATASET, name=lang, split=f"train[:{TOKENLEARN_MAX_PER_LANGUAGE}]")
+         shards.extend([docs, queries])
+     dataset = concatenate_datasets(shards)
+     logger.info("Total samples: %d", len(dataset))
+
+     train_txt: list[str] = list(dataset["text"])
+     train_vec = np.array(dataset["embedding"], dtype=np.float32)
+     # Drop rows whose teacher embedding contains NaNs.
+     non_nan_mask = ~np.isnan(train_vec).any(axis=1)
+     train_txt = np.array(train_txt)[non_nan_mask].tolist()
+     train_vec = train_vec[non_nan_mask]
+     logger.info("Loaded %d samples, raw vector shape: %s", len(train_txt), train_vec.shape)
+
+     logger.info("Fitting PCA to %d dims ...", PCA_DIMS)
+     pca = PCA(n_components=PCA_DIMS)
+     train_vec = pca.fit_transform(train_vec)
+     logger.info("Explained variance: %.4f. Shape: %s",
+                 pca.explained_variance_ratio_.cumsum()[-1], train_vec.shape)
+
+     logger.info("Loading base model from %s ...", base_model_path)
+     base_model = StaticModel.from_pretrained(str(base_model_path), force_download=False)
+     if base_model.embedding.dtype != np.float32:
+         base_model.embedding = base_model.embedding.astype(np.float32)
+
+     trainable = StaticModelForFineTuning.from_static_model(
+         model=base_model,
+         out_dim=PCA_DIMS,
+         loss=Loss("cosine"),
+     )
+     logger.info("Training tokenlearn (lr=%g, max_epochs=%d, batch=%d) ...",
+                 TOKENLEARN_LR, TOKENLEARN_MAX_EPOCHS, TOKENLEARN_BATCH_SIZE)
+     trainable.fit(
+         X=train_txt,
+         y=torch.from_numpy(train_vec.astype(np.float32)),
+         batch_size=TOKENLEARN_BATCH_SIZE,
+         learning_rate=TOKENLEARN_LR,
+         max_epochs=TOKENLEARN_MAX_EPOCHS,
+         early_stopping_patience=5,
+         use_wandb=False,
+     )
+     logger.info("Tokenlearn training complete.")
+
+     trained_model = trainable.to_static_model()
+     trained_model = apply_post_sif(trained_model, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     trained_model.save_pretrained(str(save_path))
+     logger.info("Tokenlearn model saved to %s", save_path)
+
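+ # What tokenlearn optimizes here: the token embeddings are tuned so that the
+ # mean-pooled student vector for each text matches its PCA-reduced teacher
+ # embedding under the cosine loss configured above.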
+
+ # ---------------------------------------------------------------------------
+ # Step 3: Contrastive fine-tuning (MNRL)
+ # ---------------------------------------------------------------------------
+
+ def run_contrastive(base_model_path: Path, save_path: Path) -> None:
+     random.seed(CONTRASTIVE_SEED)
+
+     logger.info("Streaming CornStack pairs (%d/lang × %d langs) ...",
+                 CONTRASTIVE_MAX_PER_LANGUAGE, len(CORNSTACK_DATASETS))
+     all_queries: list[str] = []
+     all_docs: list[str] = []
+     for lang, hf_name in CORNSTACK_DATASETS.items():
+         hf_ds = load_dataset(hf_name, split="train", streaming=True)
+         hf_ds = hf_ds.shuffle(seed=CONTRASTIVE_SEED, buffer_size=10_000)
+         kept = 0
+         seen_q: set[str] = set()
+         seen_d: set[str] = set()
+         for row in hf_ds:
+             q, d = row.get("query"), row.get("document")
+             if not isinstance(q, str) or not isinstance(d, str):
+                 continue
+             if len(q) < 32 or len(d) < 32:
+                 continue
+             if q in seen_q or d in seen_d:
+                 continue
+             seen_q.add(q)
+             seen_d.add(d)
+             all_queries.append(q)
+             all_docs.append(d)
+             kept += 1
+             if kept >= CONTRASTIVE_MAX_PER_LANGUAGE:
+                 break
+         logger.info("  %s: %d pairs", lang, kept)
+
+     logger.info("Total pairs: %d", len(all_queries))
+     train_dataset = Dataset.from_dict({"anchor": all_queries, "positive": all_docs})
+
+     static_embedding = StaticEmbedding.from_model2vec(str(base_model_path))
+     model = SentenceTransformer(modules=[static_embedding])
+     loss = MultipleNegativesRankingLoss(model)
+
+     training_args = SentenceTransformerTrainingArguments(
+         output_dir=str(save_path) + "-checkpoints",
+         num_train_epochs=CONTRASTIVE_EPOCHS,
+         per_device_train_batch_size=CONTRASTIVE_BATCH_SIZE,
+         learning_rate=CONTRASTIVE_LR,
+         warmup_ratio=0.1,
+         fp16=False,
+         bf16=False,
+         batch_sampler=BatchSamplers.NO_DUPLICATES,
+         save_strategy="no",
+         logging_steps=100,
+         logging_first_step=True,
+         report_to=[],
+     )
+     logger.info("Training contrastive (lr=%g, epochs=%d, batch=%d) ...",
+                 CONTRASTIVE_LR, CONTRASTIVE_EPOCHS, CONTRASTIVE_BATCH_SIZE)
+
+     trainer = SentenceTransformerTrainer(
+         model=model, args=training_args, train_dataset=train_dataset, loss=loss,
+     )
+     trainer.train()
+     logger.info("Contrastive training complete.")
+
+     # Copy the trained embedding matrix back into a Model2Vec StaticModel.
+     base_m2v = StaticModel.from_pretrained(str(base_model_path), force_download=False)
+     base_m2v.embedding = model[0].embedding.weight.detach().cpu().float().numpy()
+
+     final_model = apply_post_sif(base_m2v, pca_dims=PCA_DIMS, sif_coefficient=SIF_COEFFICIENT)
+
+     save_path.mkdir(parents=True, exist_ok=True)
+     final_model.save_pretrained(str(save_path))
+     logger.info("Final model saved to %s", save_path)
+
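+ # MultipleNegativesRankingLoss treats every other document in the batch as a
+ # negative for each query, which is why the contrastive batch is large (512)
+ # and why BatchSamplers.NO_DUPLICATES matters: a duplicate document in the
+ # same batch would act as a false negative.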
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ if __name__ == "__main__":
+     distilled_path = OUTPUT_DIR / "potion-code-16M-distilled"
+     tokenlearn_path = OUTPUT_DIR / "potion-code-16M-tokenlearn"
+     contrastive_path = OUTPUT_DIR / "potion-code-16M-contrastive"
+
+     logger.info("=== Step 1/3: Distill ===")
+     run_distill(save_path=distilled_path)
+
+     logger.info("=== Step 2/3: Tokenlearn ===")
+     run_tokenlearn(base_model_path=distilled_path, save_path=tokenlearn_path)
+
+     logger.info("=== Step 3/3: Contrastive ===")
+     run_contrastive(base_model_path=tokenlearn_path, save_path=contrastive_path)
+
+     logger.info("Done. Final model: %s", contrastive_path)
+ ```
+
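+ Once the pipeline finishes, the final checkpoint loads like any other Model2Vec model. A minimal usage sketch (the query and snippet below are illustrative, not taken from the training data):
+
+ ```python
+ import numpy as np
+ from model2vec import StaticModel
+
+ # Load the final checkpoint written by step 3.
+ model = StaticModel.from_pretrained("models/potion-code-16M-contrastive")
+
+ query_emb = model.encode(["How do I reverse a list in Python?"])
+ doc_emb = model.encode(["def reverse_list(xs):\n    return xs[::-1]"])
+
+ # Cosine similarity between the query and the snippet.
+ sim = (query_emb @ doc_emb.T) / (np.linalg.norm(query_emb) * np.linalg.norm(doc_emb))
+ print(float(sim[0, 0]))
+ ```
+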
  ## Citation
 
  ```bibtex