---
license: apache-2.0
base_model: jetmoe/jetmoe-8b
tags:
- alignment-handbook
- generated_from_trainer
datasets:
- HuggingFaceH4/ultrachat_200k
- HuggingFaceH4/airoboros-3.2
- HuggingFaceH4/Code-Feedback
- HuggingFaceH4/orca-math-word-problems-200k
- HuggingFaceH4/SystemChat
- HuggingFaceH4/capybara
model-index:
- name: jetmoe-8b-sft
  results: []
---

<div align="center">
  <div>&nbsp;</div>
  <img src="https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/ieHnwuczidNNoGRA_FN2y.png" width="500"/> 
  <img src="https://cdn-uploads.huggingface.co/production/uploads/641de0213239b631552713e4/UOsk9_zcbHpCCy6kmryYM.png" width="530"/> 
</div>

# JetMoE: Reaching LLaMA2 Performance with 0.1M Dollars

## Key Messages

1. JetMoE-8B is **trained at a cost of under $0.1 million**<sup>1</sup> **but outperforms LLaMA2-7B from Meta AI**, which has multi-billion-dollar training resources. LLM training can be **much cheaper than previously thought**.

2. JetMoE-8B is **fully open-sourced and academia-friendly** because:
    - It **only uses public datasets** for training, and the code is open-sourced. No proprietary resource is needed.
    - It **can be finetuned with very limited compute budget** (e.g., consumer-grade GPU) that most labs can afford.

3. JetMoE-8B **only has 2.2B active parameters** during inference, which drastically lowers the computational cost. Compared to a model with similar inference computation, like Gemma-2B, JetMoE-8B achieves consistently better performance.

<sup>1</sup> We used a 96×H100 GPU cluster for 2 weeks, which cost ~$0.08 million.
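The ~$0.08 million figure can be sanity-checked with back-of-the-envelope arithmetic. The per-GPU-hour rate below is an assumption for illustration, not a figure from the JetMoE report:

```python
# Back-of-the-envelope check of the ~$0.08M training cost.
num_gpus = 96
days = 14
assumed_rate_per_gpu_hour = 2.5  # USD; hypothetical cloud rate, not from the report

gpu_hours = num_gpus * days * 24  # 32,256 GPU-hours
total_cost = gpu_hours * assumed_rate_per_gpu_hour

print(f"{gpu_hours} GPU-hours -> ${total_cost:,.0f}")
```

At that assumed rate, 96×H100 for two weeks comes out to roughly $80K, consistent with the stated ~$0.08 million.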

Website: [https://research.myshell.ai/jetmoe](https://research.myshell.ai/jetmoe)

HuggingFace: [https://huggingface.co/jetmoe/jetmoe-8b](https://huggingface.co/jetmoe/jetmoe-8b)

Online Demo on Lepton AI: [https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat](https://www.lepton.ai/playground/chat?model=jetmoe-8b-chat)

Technical Report: [https://arxiv.org/pdf/2404.07413.pdf](https://arxiv.org/pdf/2404.07413.pdf)

## Authors

The project is contributed by [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ), [Zhen Guo](https://zguo0525.github.io/), [Tianle Cai](https://www.tianle.website/#/) and [Zengyi Qin](https://www.qinzy.tech/). For technical inquiries, please contact [Yikang Shen](https://scholar.google.com.hk/citations?user=qff5rRYAAAAJ). For media and collaboration inquiries, please contact [Zengyi Qin](https://www.qinzy.tech/).

## Collaboration
**If you have great ideas but need more resources (GPU, data, funding, etc.)**, feel free to contact **MyShell.ai** via [Zengyi Qin](https://www.qinzy.tech/). **MyShell.ai** is open to collaborations and actively supports high-quality open-source projects.

## Benchmarks
We use the same evaluation methodology as the Open LLM Leaderboard. For the MBPP code benchmark, we follow the evaluation methodology of the LLaMA2 and DeepSeek-MoE papers. The results are shown below:

|Model|Active Params|Training Tokens|Open LLM Leaderboard Avg|ARC|Hellaswag|MMLU|TruthfulQA|WinoGrande|GSM8k|MBPP|HumanEval|
|---|---|---|---|---|---|---|---|---|---|---|---|
|Shot||||25|10|5|0|5|5|3|0|
|Metric||||acc_norm|acc_norm|acc|mc2|acc|acc|Pass@1|Pass@1|
|LLaMA2-7B|7B|2T|51.0|53.1|78.6|46.9|38.8|74.0|14.5|20.8|12.8|
|LLaMA-13B|13B|1T|51.4|**56.2**|**80.9**|47.7|39.5|**76.2**|7.6|22.0|15.8|
|DeepseekMoE-16B|2.8B|2T|51.1|53.2|79.8|46.3|36.1|73.7|17.3|34.0|**25.0**|
|Gemma-2B|2B|2T|46.4|48.4|71.8|41.8|33.1|66.3|16.9|28.0|24.4|
|JetMoE-8B|2.2B|1.25T|**53.0**|48.7|80.5|**49.2**|**41.7**|70.2|**27.8**|**34.2**|14.6|

| Model               | MT-Bench Score     |
|---------------------|-----------|
| GPT-4               | 9.014     |
| GPT-3.5-turbo       | 7.995     |
| Claude-v1           | 7.923     |
| **JetMoE-8B-chat**  | **6.681** |
| Llama-2-13b-chat    | 6.650     |
| Vicuna-13b-v1.3     | 6.413     |
| Wizardlm-13b        | 6.353     |
| Llama-2-7b-chat     | 6.269     |



To our surprise, despite the lower training cost and computation, JetMoE-8B performs even better than LLaMA2-7B, LLaMA-13B, and DeepseekMoE-16B. Compared to a model with similar training and inference computation, like Gemma-2B, JetMoE-8B achieves better performance.
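The Open LLM Leaderboard average is simply the mean of the six task scores. Reproducing the JetMoE-8B row from the benchmark table above:

```python
# Reproduce the Open LLM Leaderboard average for JetMoE-8B
# from the six per-task scores in the benchmark table.
scores = {
    "ARC": 48.7,
    "HellaSwag": 80.5,
    "MMLU": 49.2,
    "TruthfulQA": 41.7,
    "WinoGrande": 70.2,
    "GSM8k": 27.8,
}
avg = sum(scores.values()) / len(scores)
print(f"Open LLM Leaderboard Avg: {avg:.1f}")  # -> 53.0
```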

## Model Usage

Here's a quick example to get you started with JetMoE-8B-chat:

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initialize the tokenizer and model
model_name = "jetmoe/jetmoe-8b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
    trust_remote_code=True,
)

# Move the model to GPU if one is available
if torch.cuda.is_available():
    model = model.cuda()
    print("Using GPU:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("GPU is not available, using CPU instead.")

# Build the chat prompt and tokenize it
messages = [
    {"role": "system", "content": "You are a friendly chatbot"},
    {"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
input_ids = tokenizer.apply_chat_template(
    messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)

# Move the input IDs to the same device as the model
input_ids = input_ids.to(model.device)

# Generate text
output = model.generate(input_ids, max_length=500, num_return_sequences=1, no_repeat_ngram_size=2)

# Decode the generated text (move back to CPU first)
generated_text = tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
print(generated_text)
```

## Model Details
JetMoE-8B has 24 blocks. 
Each block has two MoE layers: a Mixture of Attention heads (MoA) and a Mixture of MLP Experts (MoE).
Each MoA and MoE layer has 8 experts, of which 2 are activated for each input token.
The model has 8 billion parameters in total, with 2.2B active during inference.
JetMoE-8B is trained on 1.25T tokens from publicly available datasets, with a learning rate of 5.0 x 10<sup>-4</sup> and a global batch size of 4M tokens.
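The top-2-of-8 expert selection described above can be sketched as a softmax router that keeps the two highest-probability experts and renormalizes their weights. This is a minimal illustration, not JetMoE's actual router code, and the logits below are made-up numbers:

```python
import math

def top2_routing(router_logits, k=2):
    """Pick the top-k experts from one token's router logits and renormalize."""
    # Numerically stable softmax over the 8 expert logits
    m = max(router_logits)
    exps = [math.exp(x - m) for x in router_logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Keep the k highest-probability experts; renormalize their weights to sum to 1
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    top_sum = sum(probs[i] for i in top)
    return [(i, probs[i] / top_sum) for i in top]

logits = [0.1, 1.2, -0.3, 0.8, 2.0, -1.0, 0.4, 0.0]  # one token, 8 experts
print(top2_routing(logits))  # two (expert_index, weight) pairs
```

Because only 2 of 8 experts run per token in each layer, only ~2.2B of the 8B parameters participate in any single forward pass.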

<figure>
<center>
<img src="images/jetmoe_architecture.png" width="40%">
<figcaption>JetMoE Architecture</figcaption>
</center>
</figure>

## Training Details
Our training recipe follows the [MiniCPM](https://shengdinghu.notion.site/MiniCPM-Unveiling-the-Potential-of-End-side-Large-Language-Models-d4d3a8c426424654a4e80e42a711cb20?pvs=4) two-phase training method. Phase 1 uses a constant learning rate with linear warmup and is trained on 1 trillion tokens from large-scale open-source pretraining datasets, including RefinedWeb, Pile, GitHub data, etc. Phase 2 uses exponential learning rate decay and is trained on 250 billion tokens from the phase 1 datasets plus extra high-quality open-source datasets.

<figure>
<center>
<img src="images/Phase1_data.png" width="60%">
<img src="images/Phase2_data.png" width="60%">
</center>
</figure>

## Technical Report
For more details, please refer to the [JetMoE Technical Report](https://arxiv.org/abs/2404.07413).

## JetMoE Model Index
|Model|Index|
|---|---|
|JetMoE-8B-Base| [Link](https://huggingface.co/jetmoe/jetmoe-8B) |
|JetMoE-8B-SFT| [Link](https://huggingface.co/jetmoe/jetmoe-8B-sft) |
|JetMoE-8B-Chat| [Link](https://huggingface.co/jetmoe/jetmoe-8B-chat) |

## Acknowledgement
We express our gratitude to [Shengding Hu](https://shengdinghu.github.io/) for his valuable advice on the Phase 2 data mixture. We also express our gratitude to [Exabits](https://www.exabits.ai/) for their assistance in setting up the GPU clusters, and to [Lepton AI](https://www.lepton.ai/) for their support in setting up the chat demo.