Memphis-CoT-3B
metadata
license: cc-by-sa-3.0
datasets:
  - euclaise/TinyCoT
  - euclaise/reddit-instruct
  - sablo/oasst2_curated
library_name: transformers
tags:
  - supertrainer2000
  - human-data
metrics:
  - accuracy


Memphis-CoT is a finetune of StableLM 3B 4E1T on TinyCoT, along with a 5,000-example subset of reddit-instruct (excluding posts with brackets in the title) and a curated subset of oasst2.

Memphis was trained only on human data! No GPT generations here.

Finetuning was performed with my supertrainer2000 framework, using my Adalite optimizer.

Training Procedure

I finetuned the model using an iterative rationale-bootstrapping procedure inspired by STaR (Self-Taught Reasoner) and SPIN (Self-Play Fine-Tuning).

First, I finetuned the model on all three datasets for 2 epochs, using a MixCE loss and NEFTune.
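As a rough illustration of the MixCE objective, here is a minimal sketch over one sequence's target-token log-probabilities. It assumes the token-level approximation described in the MixCE paper, where the reverse cross-entropy term is approximated by weighting each token's negative log-likelihood by the model's own probability for that token, and it assumes the mixing ratio (0.75 in the hyperparameters below) weights the forward, i.e. standard, cross-entropy term. The name `mixce_loss` is hypothetical, and supertrainer2000's actual implementation may differ.

```python
import math
from typing import Sequence


def mixce_loss(token_logprobs: Sequence[float], eta: float = 0.75) -> float:
    """Mean MixCE loss over one sequence (sketch, assumed approximation).

    token_logprobs: log p(x_t | x_<t) for each target (response) token.
    eta: mixing ratio, assumed here to weight the forward (standard) CE term.
    """
    if not token_logprobs:
        raise ValueError("need at least one target token")
    total = 0.0
    for logp in token_logprobs:
        # In an autograd framework this probability would be detached, so it
        # only acts as a per-token weight and receives no gradient itself.
        q = math.exp(logp)
        weight = eta + (1.0 - eta) * q
        total += -weight * logp
    return total / len(token_logprobs)
```

With `eta=1.0` this reduces to the ordinary mean negative log-likelihood; lowering `eta` down-weights target tokens to which the model currently assigns low probability.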

I then performed the following steps 4 times:

  1. Generate responses for each question in TinyCoT using the current model, check each response for correctness, and build a dataset of (correct, incorrect) pairs. Surplus responses are discarded, so that each correct and each incorrect response is used in at most one pair.
  2. Finetune the model for 1 epoch using a ranking loss over the length-normalized log-probabilities of each sequence, similar to Preference Ranking Optimization (PRO), comparing the correct and incorrect generated responses. A standard cross-entropy (CE) loss over the ground-truth responses was also included to prevent excessive drift.
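The two steps above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the supertrainer2000 code: pairs are formed by deduplicating generations and zipping correct with incorrect ones (so surplus responses are dropped), the ranking term uses a pairwise logistic form, which is what a PRO-style softmax over candidates reduces to when there are only two, and the rank loss weight of 5 from the hyperparameters below is assumed to multiply the ranking term. All function names are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def make_pairs(responses: Sequence[str], is_correct: Sequence[bool]) -> List[Tuple[str, str]]:
    """Build (correct, incorrect) pairs for one question, using each response at most once."""
    seen = set()
    correct: List[str] = []
    incorrect: List[str] = []
    for text, ok in zip(responses, is_correct):
        if text in seen:  # skip duplicate generations
            continue
        seen.add(text)
        (correct if ok else incorrect).append(text)
    # zip stops at the shorter list, so surplus responses are discarded
    return list(zip(correct, incorrect))


def length_normalized_logprob(token_logprobs: Sequence[float]) -> float:
    """Sequence log-probability divided by its number of tokens."""
    return sum(token_logprobs) / len(token_logprobs)


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def pairwise_rank_loss(correct_logprobs: Sequence[float],
                       incorrect_logprobs: Sequence[float]) -> float:
    """-log sigmoid(s_correct - s_incorrect) on length-normalized scores."""
    margin = (length_normalized_logprob(correct_logprobs)
              - length_normalized_logprob(incorrect_logprobs))
    return softplus(-margin)


def total_loss(reference_logprobs: Sequence[float],
               correct_logprobs: Sequence[float],
               incorrect_logprobs: Sequence[float],
               rank_weight: float = 5.0) -> float:
    """Mean token CE on the ground-truth response plus the weighted ranking term."""
    ce = -length_normalized_logprob(reference_logprobs)
    return ce + rank_weight * pairwise_rank_loss(correct_logprobs, incorrect_logprobs)
```

Length normalization keeps the comparison from simply favoring whichever response is shorter, since unnormalized sequence log-probabilities shrink with every extra token.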

This should be more efficient than either STaR or SPIN, as it uses a ranking loss rather than rejection sampling (unlike STaR), and verifies correctness instead of assuming all model responses are incorrect (unlike SPIN).

Prompt formats

The format for reddit-instruct and oasst2 was:

### User:
[insert instruction here]
### Assistant:
[insert response here]
### User:
...

The format for TinyCoT was:

### User:
[insert instruction here]
### Rationale:
[insert reasoning here]
### Answer:
[insert direct answer here]
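A minimal sketch of helpers that build strings in the two formats above. The exact whitespace (one newline after each header and between turns) is an assumption, since the card does not specify it, and the function names are hypothetical.

```python
from typing import Iterable, Tuple

HEADERS = {"user": "### User:", "assistant": "### Assistant:"}


def format_chat(turns: Iterable[Tuple[str, str]], add_generation_prompt: bool = False) -> str:
    """Render (role, text) turns in the reddit-instruct / oasst2 format."""
    prompt = "\n".join(f"{HEADERS[role]}\n{text}" for role, text in turns)
    if add_generation_prompt:
        # Leave an open assistant turn for the model to complete.
        prompt += "\n### Assistant:\n"
    return prompt


def format_cot_example(question: str, rationale: str, answer: str) -> str:
    """Render a full TinyCoT training example."""
    return f"### User:\n{question}\n### Rationale:\n{rationale}\n### Answer:\n{answer}"


def format_cot_prompt(question: str) -> str:
    """Render a TinyCoT-style prompt that ends where the rationale should begin."""
    return f"### User:\n{question}\n### Rationale:\n"
```

In this sketch, `format_cot_prompt` stops right after the `### Rationale:` header, on the assumption that the model then writes its reasoning followed by its own `### Answer:` section.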

Benchmarks

| Model | Size | Data | Method | GSM8K (5-shot) | AGIEval (English/Nous subset, acc_norm) |
|---|---|---|---|---|---|
| StableLM 3B Base | 3B | Base | Base | 2.05% | 25.14% |
| StableHermes 3B | 3B | GPT | SFT | 3.64% | 24.31% |
| MPT 7B Instruct | 7B | Human+Anthropic | SFT | 2.05% | 24.12% |
| OpenLLaMA 7B v2 open-instruct | 7B | Human (nearly: ecqa is an exception) | SFT | 8.64% | 23.21% |
| StableLM Zephyr 3B | 3B | GPT | DPO | 45.72% | 33.31% |
| Memphis-CoT 3B | 3B | Human | Self-teaching | 13.8% | 26.24% |

Memphis outperforms the human-data models that are over twice its size, as well as the SFT models of its size, but still trails the Zephyr DPO model (13.8% vs. 45.72% on GSM8K, 26.24% vs. 33.31% on AGIEval). That said, Zephyr uses synthetic data, and much more of it.

Notes:

  • Evaluations were performed using the agieval branch of lm-evaluation-harness (commit 0bef5c9c273b1c2f68e6018d4bb9c32b9aaff298), with the vllm model backend.
  • I tried to find human-data-trained StableLM models, but couldn't find any. I did find a few OpenLLaMA models, but they wouldn't load with LM Eval Harness and vllm.
  • OpenLLaMA 7B v2 open-instruct is a particularly relevant comparison, as it was trained on a very similar dataset.

Hyperparameters

For the initial supervised finetuning step:

  • Adalite optimizer, default hyperparameters of supertrainer2000 unless otherwise specified
  • Lambda (Adalite's analogue to weight decay) of 0.01
  • LR of 1e-5
  • MixCE ratio of 0.75
  • Sequence length of 4096
  • Cosine decay with a 20% warmup
  • Frozen embeddings
  • No training on inputs
  • Accumulated batch size of 128
  • NEFTune with an alpha of 10
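NEFTune adds uniform noise to the token embeddings during training, scaled by alpha / sqrt(L * d), where L is the sequence length and d the embedding dimension. The sketch below applies it to a plain list of embedding vectors with alpha = 10, as listed above; the function name is hypothetical, and it assumes an unpadded sequence, so L is simply the number of vectors. Because the noise is added to the embedding layer's outputs rather than its weights, it works alongside frozen embeddings.

```python
import math
import random
from typing import List, Optional


def neftune_noise(embeddings: List[List[float]], alpha: float = 10.0,
                  rng: Optional[random.Random] = None) -> List[List[float]]:
    """Return embeddings plus NEFTune noise: Uniform(-1, 1) * alpha / sqrt(L * d).

    embeddings: L token vectors of dimension d for one (unpadded) sequence.
    Apply during training only; evaluation uses the clean embeddings.
    """
    rng = rng or random.Random()
    seq_len, dim = len(embeddings), len(embeddings[0])
    scale = alpha / math.sqrt(seq_len * dim)
    # Build a new list so the caller's embeddings are left untouched.
    return [[x + scale * rng.uniform(-1.0, 1.0) for x in vec] for vec in embeddings]
```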

For the generations:

  • Generated using the current git version of vllm
  • N=8
  • Temperature of 0.5
  • top_p of 0.8
  • Maximum of 512 generated tokens, discarding responses that do not have a valid rationale and answer
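The filtering step can be sketched as a small parser for the TinyCoT format. It assumes the completion is whatever the model generated after the `### Rationale:` header, and it treats a completion as invalid when the `### Answer:` section is missing (for example because the 512-token limit was reached first) or when either part is empty. The correctness check here is a hypothetical normalized exact match, since the card does not say how answers were compared. In vLLM, the sampling settings above correspond to `SamplingParams(n=8, temperature=0.5, top_p=0.8, max_tokens=512)`.

```python
import re
from typing import Optional, Tuple


def parse_cot_completion(completion: str) -> Optional[Tuple[str, str]]:
    """Split text generated after '### Rationale:' into (rationale, answer).

    Returns None if the answer header is missing (e.g. the token limit was
    hit first) or if either the rationale or the answer is empty.
    """
    rationale, header, rest = completion.partition("### Answer:")
    if not header:
        return None
    # If the model kept generating, stop the answer at the next "###" header.
    answer = re.split(r"\n###", rest, maxsplit=1)[0]
    rationale, answer = rationale.strip(), answer.strip()
    if not rationale or not answer:
        return None
    return rationale, answer


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace (hypothetical normalization)."""
    return " ".join(text.lower().split())


def is_correct(answer: str, reference: str) -> bool:
    """Hypothetical correctness check: normalized exact match."""
    return normalize(answer) == normalize(reference)
```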

For the rank finetuning:

  • Adalite optimizer, default hyperparameters of supertrainer2000 unless otherwise specified
  • Lambda of 0.01
  • LR of 5e-7
  • Rank loss weight of 5
  • Sequence length of 1024
  • Cosine schedule with 10% warmup
  • Frozen embeddings
  • No training on inputs
  • Accumulated batch size of 128
  • NEFTune with an alpha of 10
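Both stages use a warmup followed by cosine decay (20% warmup at a peak LR of 1e-5 for the initial finetune, 10% warmup at 5e-7 for the rank finetuning). Below is a minimal sketch of such a schedule, assuming linear warmup and decay to zero, since the card does not state a minimum LR; `warmup_cosine_lr` is a hypothetical name, not the supertrainer2000 API.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, peak_lr: float, warmup_frac: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero at total_steps."""
    warmup_steps = round(total_steps * warmup_frac)
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at zero past the end of the schedule
    return 0.5 * peak_lr * (1.0 + math.cos(math.pi * progress))
```

For example, the rank finetuning would call `warmup_cosine_lr(step, total_steps, 5e-7, 0.10)`, where `total_steps` depends on the dataset size and the accumulated batch size of 128.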