
LLaMA-8x265M-MoE

💻 Code

👋 Very nice to meet you here~

❤️ This repo contains the model LLaMA-8x265M-MoE (970M parameters in total), which activates 2 out of 8 experts (332M activated parameters). The model is trained from scratch in FP32 precision. We first train it on the Wikipedia dataset for 1 epoch, and then on 10% of the C4 dataset (10 of its 1024 data shards) for 1 epoch. It is NOT fine-tuned on instruction pairs, so it may not be good enough to act as a chatbot.

📢 This series also includes a dense version (without the MoE structure); see 🤗 this repo.

1. 🚀QuickStart

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_dir = "JuncaiL/llama-8x265m-moe"

# trust_remote_code is required because the MoE model uses custom modeling code.
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
model.eval()
model.to("cuda:0")

input_text = "Beijing is a famous city"
inputs = tokenizer(input_text, return_tensors="pt", return_token_type_ids=False)
inputs = inputs.to("cuda:0")

# do_sample defaults to False, so generation is greedy.
pred = model.generate(**inputs, max_length=50, temperature=0.0)
print(tokenizer.decode(pred.cpu()[0], skip_special_tokens=True))
# Beijing is a famous city in China. It is the capital of the Beijing Province and the largest city in China. It is also the home of the world’s largest city, Beijing.
# The city is the
```

2. 📑Checkpoint Details and Evaluation

Model Parameter

| Model | #Experts | #Activated Experts | #Params | #Activated Params | FLOPs (T) per sample (seq=2048) | Model Weights |
|---|---|---|---|---|---|---|
| 265M | - | - | 265M | 265M | 0.48 | 🤗 llama-265m |
| 8 $\times$ 265M MoE | 8 | 2 | 970M | 332M | 0.76 | 🤗 llama-8x265m-moe |
| llama-7b | - | - | 7B | 7B | 25.29 | - |
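
As a quick sanity check of the total parameter counts above, they can be recomputed from the loaded checkpoint. Below is a minimal sketch that reuses the QuickStart loading call; the activated-parameter figure depends on the routing configuration and is not recomputed here.

```python
from transformers import AutoModelForCausalLM

# Reuse the QuickStart loading call; trust_remote_code is needed for the custom MoE modules.
model = AutoModelForCausalLM.from_pretrained(
    "JuncaiL/llama-8x265m-moe", trust_remote_code=True
)

# Total parameters stored in the checkpoint (reported as ~970M in the table above).
total = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total / 1e6:.0f}M")
```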

Model Evaluation

We use the "Average number of tokens verified" $N$ (see reference link) as the metric to evaluate these models. Given the same input to the small speculative model and to llama-7b, this metric measures how many successive tokens in the small model's output, counting from the first predicted token, match the output of llama-7b.

  • Average number of tokens verified
| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
|---|---|---|
| tatsu-lab/alpaca | 3.2362 | 3.0334 |
| alespalla/chatbot_instruction_prompts | 3.2031 | 3.0823 |
| web_questions | 2.7201 | 2.5541 |
| MohamedRashad/ChatGPT-prompts | 3.0954 | 2.9768 |
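
As a rough illustration of how $N$ can be measured (a hedged sketch, not the original evaluation script): greedily decode both models from the same prompt and count how many leading generated tokens agree. The `draft_model`, `target_model`, and `prompts` names below are placeholders, and both models are assumed to sit on the same device.

```python
import torch

def count_verified_tokens(draft_model, target_model, tokenizer, prompt, max_new_tokens=50):
    """Count how many leading generated tokens of the draft model match the target model."""
    inputs = tokenizer(prompt, return_tensors="pt", return_token_type_ids=False).to(draft_model.device)
    prompt_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        draft_out = draft_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        target_out = target_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    matched = 0
    for d, t in zip(draft_out[0, prompt_len:].tolist(), target_out[0, prompt_len:].tolist()):
        if d != t:
            break
        matched += 1
    return matched

# N = average over a prompt dataset, e.g.:
# N = sum(count_verified_tokens(draft_model, target_model, tokenizer, p) for p in prompts) / len(prompts)
```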

Suppose that the small speculative model has a per-token hit rate $p$ when given the same input. Then we have

$$p + 2p^2 + 3p^3 + \dots = N$$

We can then solve for the hit rate as follows.

$$p = 1 + \frac{1-\sqrt{1+4N}}{2N}$$
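
For completeness, this closed form follows from the series identity $\sum_{k \ge 1} k p^k = \frac{p}{(1-p)^2}$: setting it equal to $N$ and clearing the denominator gives

$$N p^2 - (2N+1) p + N = 0 \;\Longrightarrow\; p = \frac{2N + 1 - \sqrt{4N+1}}{2N} = 1 + \frac{1 - \sqrt{1+4N}}{2N},$$

taking the root that lies in $(0, 1)$.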

  • Hit Rate
| Dataset | 8 $\times$ 265M MoE | GPT without MoE |
|---|---|---|
| tatsu-lab/alpaca | 0.578 | 0.567 |
| alespalla/chatbot_instruction_prompts | 0.576 | 0.570 |
| web_questions | 0.550 | 0.540 |
| MohamedRashad/ChatGPT-prompts | 0.571 | 0.565 |
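
The hit rates in the table can be reproduced directly from the verified-token averages with the closed-form expression above; a small sketch (not part of the original evaluation code):

```python
import math

def hit_rate(n: float) -> float:
    """Solve p + 2p^2 + 3p^3 + ... = n for the per-token hit rate p."""
    return 1 + (1 - math.sqrt(1 + 4 * n)) / (2 * n)

print(f"{hit_rate(3.2362):.3f}")  # ~0.578 (tatsu-lab/alpaca, 8x265M MoE)
print(f"{hit_rate(3.0334):.3f}")  # ~0.567 (tatsu-lab/alpaca, GPT without MoE)
```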

3. 🚧Limitation and Future Plans

For the MoE model, we only report how well this small speculative model approximates the output of llama-7b. In practice, to actually achieve low latency, the implementation of our MoE layer needs to be improved: in this version we compute the MoE output expert by expert (sequentially), and the computation across experts still needs to be fused.
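
For concreteness, the expert-by-expert computation looks roughly like the loop below. This is a simplified sketch of a generic top-2 routed MoE feed-forward layer, not the exact code in this repo; fusing it would mean replacing the Python loop over experts with grouped or batched matrix multiplications.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class NaiveTop2MoE(nn.Module):
    """Sequential (unfused) top-2 MoE feed-forward layer."""

    def __init__(self, hidden_size, ffn_size, num_experts=8, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(hidden_size, ffn_size), nn.SiLU(), nn.Linear(ffn_size, hidden_size))
            for _ in range(num_experts)
        )

    def forward(self, x):  # x: (num_tokens, hidden_size)
        scores = F.softmax(self.gate(x), dim=-1)                # (num_tokens, num_experts)
        weights, indices = scores.topk(self.top_k, dim=-1)      # route each token to its top-2 experts
        weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize routing weights
        out = torch.zeros_like(x)
        # Unfused: visit experts one by one; each expert only processes the tokens routed to it.
        for e, expert in enumerate(self.experts):
            token_ids, slot = (indices == e).nonzero(as_tuple=True)
            if token_ids.numel() == 0:
                continue
            out[token_ids] += weights[token_ids, slot].unsqueeze(-1) * expert(x[token_ids])
        return out

# e.g. (hypothetical sizes): y = NaiveTop2MoE(hidden_size=1024, ffn_size=2816)(torch.randn(16, 1024))
```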

Acknowledgment

  1. My implementation of the MoE structure is based on the repo https://huggingface.co/llama-moe/LLaMA-MoE-v1-3_5B-2_8
  2. My inspiration for Speculative Inference comes from the paper "SpecInfer: Accelerating Generative Large Language Model Serving with Tree-based Speculative Inference and Verification" (link) . I am very appreciative of the help and suggestions from the SpecInfer group. ❤️

Citation

```bibtex
@misc{specmoe-2024,
  title={SpecMoE: Building A Speculative MoE Model To Accelerate Inference},
  author={Juncai Liu},
  year={2024},
  month={March},
  url={https://github.com/JuncaiL/SpecMoE/}
}
```

Contact

If you have any interest or question about this project, please feel free to contact me.

liujc19@mails.tsinghua.edu.cn (before June 30, 2024) or liujc19@tsinghua.org.cn (after June 30, 2024)
