Spico's picture
Upload folder using huggingface_hub
b386a77 verified
diff --git a/.gitignore b/.gitignore
index c243024..8c28ce3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,6 +175,7 @@ debug.py
wandb/
nohup.out
lm-evaluation-harness/
+bigcode-evaluation-harness/
results/**/*.json
results/**/*.jsonl
results/**/*.db
diff --git a/README.md b/README.md
index 8813a32..b276a78 100644
--- a/README.md
+++ b/README.md
@@ -26,6 +26,11 @@ bash scripts/data.sh
git clone https://github.com/EleutherAI/lm-evaluation-harness.git
cd lm-evaluation-harness
pip install -e .
+# commit: 9cfa52b
+git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git
+cd bigcode-evaluation-harness
+# change `pyext==0.5` in `bigcode-evaluation-harness/requirements.txt`, ref: https://github.com/bigcode-project/bigcode-evaluation-harness/pull/181
+pip install -e .
```
## 📃 TODO
diff --git a/scripts/eval.sh b/scripts/eval.sh
deleted file mode 100644
index 4f41b37..0000000
--- a/scripts/eval.sh
+++ /dev/null
@@ -1,96 +0,0 @@
-# nohup srun -p MoE --gres gpu:1 bash scripts/eval.sh all /mnt/petrelfs/share_data/quxiaoye/models/Sheared-LLaMA-2.7B True results/Sheared-LLaMA-2.7B 1>logs/eval-all-Sheared-LLaMA-2.7B.log 2>&1 &
-
-mmlu() {
- # MMLU: https://github.com/princeton-nlp/LLM-Shearing/blob/20ebd2645a8ff5fa65874e1347f9891b80e01805/icl_eval/run_eval.sh#L18
- MODEL=$1
- TRUST_REMOTE_CODE=$2
- RESULT_DIR=$3
- mkdir -p $RESULT_DIR
-
- lm_eval \
- --model hf \
- --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
- --tasks mmlu_computer_security,mmlu_high_school_chemistry,mmlu_philosophy,mmlu_elementary_mathematics,mmlu_prehistory,mmlu_formal_logic,mmlu_high_school_mathematics,mmlu_econometrics,mmlu_moral_scenarios,mmlu_college_mathematics,mmlu_high_school_government_and_politics,mmlu_us_foreign_policy,mmlu_high_school_world_history,mmlu_conceptual_physics,mmlu_college_medicine,mmlu_international_law,mmlu_abstract_algebra,mmlu_logical_fallacies,mmlu_machine_learning,mmlu_medical_genetics,mmlu_public_relations,mmlu_college_biology,mmlu_marketing,mmlu_electrical_engineering,mmlu_anatomy,mmlu_high_school_us_history,mmlu_high_school_biology,mmlu_miscellaneous,mmlu_high_school_psychology,mmlu_sociology,mmlu_business_ethics,mmlu_high_school_geography,mmlu_human_aging,mmlu_high_school_statistics,mmlu_moral_disputes,mmlu_professional_psychology,mmlu_global_facts,mmlu_college_physics,mmlu_nutrition,mmlu_high_school_macroeconomics,mmlu_world_religions,mmlu_professional_medicine,mmlu_high_school_computer_science,mmlu_college_chemistry,mmlu_human_sexuality,mmlu_high_school_microeconomics,mmlu_astronomy,mmlu_professional_accounting,mmlu_high_school_european_history,mmlu_jurisprudence,mmlu_professional_law,mmlu_high_school_physics,mmlu_virology,mmlu_management,mmlu_college_computer_science,mmlu_clinical_knowledge,mmlu_security_studies \
- --num_fewshot 5 \
- --device cuda:0 \
- --batch_size auto \
- --verbosity DEBUG \
- --output_path $RESULT_DIR/mmlu.json
-}
-
-bbh() {
- # Big Bench Hard (BBH): https://arxiv.org/pdf/2210.09261.pdf
- MODEL=$1
- TRUST_REMOTE_CODE=$2
- RESULT_DIR=$3
- mkdir -p $RESULT_DIR
-
- lm_eval \
- --log_samples \
- --model hf \
- --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
- --tasks bbh_fewshot_boolean_expressions,bbh_fewshot_causal_judgement,bbh_fewshot_date_understanding,bbh_fewshot_disambiguation_qa,bbh_fewshot_dyck_languages,bbh_fewshot_formal_fallacies,bbh_fewshot_geometric_shapes,bbh_fewshot_hyperbaton,bbh_fewshot_logical_deduction_five_objects,bbh_fewshot_logical_deduction_seven_objects,bbh_fewshot_logical_deduction_three_objects,bbh_fewshot_movie_recommendation,bbh_fewshot_multistep_arithmetic_two,bbh_fewshot_navigate,bbh_fewshot_object_counting,bbh_fewshot_penguins_in_a_table,bbh_fewshot_reasoning_about_colored_objects,bbh_fewshot_ruin_names,bbh_fewshot_salient_translation_error_detection,bbh_fewshot_snarks,bbh_fewshot_sports_understanding,bbh_fewshot_temporal_sequences,bbh_fewshot_tracking_shuffled_objects_five_objects,bbh_fewshot_tracking_shuffled_objects_seven_objects,bbh_fewshot_tracking_shuffled_objects_three_objects,bbh_fewshot_web_of_lies,bbh_fewshot_word_sorting \
- --device cuda:0 \
- --batch_size auto \
- --verbosity DEBUG \
- --output_path $RESULT_DIR/bbh.json
-}
-
-reasoning() {
- MODEL=$1
- TRUST_REMOTE_CODE=$2
- RESULT_DIR=$3
- mkdir -p $RESULT_DIR
-
- lm_eval \
- --log_samples \
- --model hf \
- --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
- --tasks gsm8k_cot \
- --device cuda:0 \
- --batch_size auto \
- --verbosity DEBUG \
- --output_path $RESULT_DIR/reasoning.json
-}
-
-qa() {
- MODEL=$1
- TRUST_REMOTE_CODE=$2
- RESULT_DIR=$3
- mkdir -p $RESULT_DIR
-
- lm_eval \
- --log_samples \
- --model hf \
- --model_args pretrained=$MODEL,trust_remote_code=$TRUST_REMOTE_CODE \
- --tasks arc_easy,arc_challenge,boolq \
- --num_fewshot 0 \
- --device cuda:0 \
- --batch_size auto \
- --verbosity DEBUG \
- --output_path $RESULT_DIR/qa.json
-}
-
-EVAL_TASK=$1
-shift 1
-start=$(date +%s)
-case $EVAL_TASK in
- mmlu)
- mmlu $* ;;
- bbh)
- bbh $* ;;
- reasoning)
- reasoning $* ;;
- qa)
- qa $* ;;
- all)
- mmlu $*
- bbh $*
- reasoning $*
- qa $*
- ;;
- *)
- echo "$EVAL_TASK not recognized!";;
-esac
-end=$(date +%s)
-echo "Elapsed Time: $(($end-$start)) seconds"
diff --git a/scripts/four_mix/freeze_gate.sh b/scripts/four_mix/freeze_gate.sh
index d94d78c..70afb8e 100644
--- a/scripts/four_mix/freeze_gate.sh
+++ b/scripts/four_mix/freeze_gate.sh
@@ -83,8 +83,11 @@ num_gpus=4
python -m src.eval.gen_mt_ans \
--model-path $output_dir \
- --model-id $task_name \
- --num-gpus-total $num_gpus
+ --model-id $task_name
+
+ python -m src.eval.gen_alpaca_eval_ans \
+ --model-path $output_dir \
+ --model-id $task_name
}
# nohup srun -p MoE --ntasks-per-node=1 --cpus-per-task=16 --mem=128G --nodes=1 --gres=gpu:4 bash "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/scripts/one_data_steps_dynamic.sh" "llama_moe_orca_epochs_cluster_4" "auto" "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new" "data/open_orca_clustered/4" "data/open_orca_clustered_eval/4" 1>logs/llama_moe_orca_cluster_4_dynamic.log 2>&1 &
diff --git a/scripts/gen_mt_bench_ans.sh b/scripts/gen_mt_bench_ans.sh
deleted file mode 100644
index f251644..0000000
--- a/scripts/gen_mt_bench_ans.sh
+++ /dev/null
@@ -1,32 +0,0 @@
-#!/usr/bin/bash
-
-#SBATCH --job-name=moe_gen
-#SBATCH --output=logs/%x-%j.log
-#SBATCH --error=logs/%x-%j.log
-
-#SBATCH --partition=MoE
-#SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=16
-#SBATCH --mem=64G
-
-#SBATCH --nodes=1
-#SBATCH --gres=gpu:1
-#SBATCH --quotatype=auto
-
-{
- # python -m fastchat.llm_judge.gen_model_answer \
- # --model-path outputs/sheared_llama_sharegpt/moe_sft-2411306 \
- # --model-id sheared_llama_sharegpt
-
- # python -m fastchat.llm_judge.gen_model_answer \
- # --model-path outputs/sheared_llama_uniform_mix/moe_sft-2421072 \
- # --model-id sheared_llama_uniform_mix
-
- bash scripts/cp_model_files.sh outputs/llama_moe/moe_sft-2409782
- python -m fastchat.llm_judge.gen_model_answer \
- --model-path outputs/llama_moe/moe_sft-2409782 \
- --model-id llama_moe_uniform_mix
-}
-
-# nohup srun -p MoE -n1 -N1 --gres=gpu:1 --quotatype spot python -m fastchat.llm_judge.gen_model_answer --model-path outputs/sheared_llama_sharegpt/moe_sft-2411306 --model-id sheared_llama_sharegpt 1>logs/mt_bench_gen_sheared_llama_sharegpt.log 2>&1 &
-# nohup srun -p MoE -n1 -N1 --gres=gpu:1 --quotatype spot python -m fastchat.llm_judge.gen_model_answer --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/llama_moe_sharegpt/moe_sft-2411309 --model-id llama_moe_sharegpt 1>logs/mt_bench_gen_llama_moe_sharegpt.log 2>&1 &
diff --git a/scripts/multi.sh b/scripts/multi.sh
index bcd83b8..e399761 100644
--- a/scripts/multi.sh
+++ b/scripts/multi.sh
@@ -100,5 +100,8 @@ nohup srun -p MoE --ntasks-per-node=1 --cpus-per-task=16 --mem=128G --nodes=1 --
nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_mt_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_mt_ans-llama_moe_four_mix_uniform.log 2>&1 &
nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_mt_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_mt_ans-sheared_four_mix_uniform.log 2>&1 &
-nohup srun -p MoE --gres gpu:1 python -m src.eval.get_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_alpaca_eval-llama_moe_four_mix_uniform.log 2>&1 &
-nohup srun -p MoE --gres gpu:1 python -m src.eval.get_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_alpaca_eval-sheared_four_mix_uniform.log 2>&1 &
+nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/llama_moe_four_mix_uniform/bash-2485396 --model-id llama_moe_four_mix_uniform 1>logs/gen_alpaca_eval-llama_moe_four_mix_uniform.log 2>&1 &
+nohup srun -p MoE --gres gpu:1 python -m src.eval.gen_alpaca_eval_ans --model-path /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048/sheared_four_mix_uniform/bash-2485397 --model-id sheared_four_mix_uniform 1>logs/gen_alpaca_eval-sheared_four_mix_uniform.log 2>&1 &
+
+nohup srun -p MoE --gres gpu:1 bash scripts/eval/eval.sh reasoning /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_wo_gate_noise/moe_sft-2492650 True results/llama_moe_four_mix_wo_pad_wo_gate_noise 1>logs/eval-reasoning-llama_moe_four_mix_wo_pad_wo_gate_noise.log 2>&1 &
+nohup srun -p MoE --gres gpu:1 bash scripts/eval/eval.sh reasoning /mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad/moe_sft-2491633 True results/llama_moe_four_mix_wo_pad 1>logs/eval-reasoning-llama_moe_four_mix_wo_pad.log 2>&1 &
diff --git a/src/callbacks.py b/src/callbacks.py
index a750f69..e9d0c04 100644
--- a/src/callbacks.py
+++ b/src/callbacks.py
@@ -6,6 +6,7 @@ import torch
import numpy as np
from loguru import logger
from transformers.trainer_callback import TrainerCallback, TrainerState, TrainerControl
+from transformers.utils import is_flash_attn_2_available
from src.utils.config import TrainingArguments
from src.utils.io import append_jsonlines
@@ -22,6 +23,7 @@ class AdaptiveSamplingCallback(TrainerCallback):
criterion: Optional[Literal["min", "max", "mean"]] = "mean",
sim_type: Optional[Literal["cos", "l2"]] = "cos",
):
+ assert is_flash_attn_2_available(), "Make sure you have flash-attn installed"
self.criterion = criterion
self.sim_type = sim_type
self.prob_map = {}
@@ -74,8 +76,8 @@ class AdaptiveSamplingCallback(TrainerCallback):
cls,
ori_weights: np.ndarray,
delta: np.ndarray,
- eta: float = 1.0,
- c: float = 1e-4,
+ eta: float = 10.0,
+ c: float = 5e-2,
) -> np.ndarray:
def _softmax(vec: np.ndarray) -> np.ndarray:
exps = np.exp(vec - np.max(vec))
diff --git a/src/core/train.py b/src/core/train.py
index 2be5558..9b1f694 100644
--- a/src/core/train.py
+++ b/src/core/train.py
@@ -7,13 +7,12 @@ from loguru import logger
from src.utils.config import ModelArguments, DataArguments, TrainingArguments
from src.data import (
SubDirWeightedPackedJsonlDataset,
- get_uniform_sampling_ratio,
fault_tolerance_data_collator,
CachedJsonlDataset,
get_cached_datasets_from_dir,
)
from src.utils.io import trainer_save_model_safe
-from src.models import LlamaMoEForCausalLM, LlamaMoEConfig
+from src.models import LlamaMoEForCausalLM, LlamaMoEConfig, DeepseekConfig, DeepseekForCausalLM
from src.trainer import GateLoadRecordingTrainer
from src.callbacks import AdaptiveSamplingCallback
@@ -36,6 +35,9 @@ def get_model_and_tokenizer(
elif model_type == "llama_moe":
ConfigClass = LlamaMoEConfig
ModelClass = LlamaMoEForCausalLM
+ elif model_type == "deepseek":
+ ConfigClass = DeepseekConfig
+ ModelClass = DeepseekForCausalLM
else:
raise ValueError(f"Unknown model type: {model_type}")
@@ -54,6 +56,21 @@ def get_model_and_tokenizer(
config.update(additional_config)
logger.info("Config ready")
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ model_name_or_path,
+ cache_dir=cache_dir,
+ model_max_length=model_max_length,
+ padding_side=padding_side,
+ use_fast=False,
+ trust_remote_code=trust_remote_code,
+ )
+ if tokenizer.pad_token is None:
+ if tokenizer.unk_token is not None:
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer.pad_token = tokenizer.eos_token
+ logger.info(f"tokenizer ready, pad_token: {tokenizer.pad_token}")
+
# Load model and tokenizer
model = ModelClass.from_pretrained(
model_name_or_path,
@@ -65,18 +82,6 @@ def get_model_and_tokenizer(
)
logger.info("model ready")
- tokenizer = transformers.AutoTokenizer.from_pretrained(
- model_name_or_path,
- cache_dir=cache_dir,
- model_max_length=model_max_length,
- padding_side=padding_side,
- use_fast=False,
- trust_remote_code=trust_remote_code,
- )
- if tokenizer.pad_token != tokenizer.unk_token:
- tokenizer.pad_token = tokenizer.unk_token
- logger.info("tokenizer ready")
-
return model, tokenizer
@@ -117,7 +122,9 @@ def train():
train_dataset = SubDirWeightedPackedJsonlDataset(
data_args.dataset_dir_or_path,
tokenizer,
- prob_map=get_uniform_sampling_ratio(data_args.dataset_dir_or_path),
+ # prob_map=get_uniform_sampling_ratio(data_args.dataset_dir_or_path),
+ # prob_map={"code": 0.25119094959816823, "math": 0.2674581878910902, "orca": 0.243050776175138, "sharegpt": 0.23830008633560357},
+ prob_map=data_args.prob_map,
seed=training_args.seed,
)
elif datapath.is_file():
diff --git a/src/data.py b/src/data.py
index d783a21..a1a8ff7 100644
--- a/src/data.py
+++ b/src/data.py
@@ -20,6 +20,7 @@ def preprocess(
instances,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
+ tokenizer_legacy = getattr(tokenizer, "legacy", None)
conv = Conversation()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
@@ -72,7 +73,7 @@ def preprocess(
# "-2" is hardcoded for the Llama tokenizer to make the offset correct.
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
- if i != 0 and not tokenizer.legacy:
+ if i != 0 and not tokenizer_legacy:
# The legacy and non-legacy modes handle special tokens differently
instruction_len -= 1
@@ -80,7 +81,7 @@ def preprocess(
target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len
- if i != 0 and not tokenizer.legacy:
+ if i != 0 and not tokenizer_legacy:
# The legacy and non-legacy modes handle special tokens differently
cur_len -= 1
diff --git a/src/eval/get_alpaca_eval_ans.py b/src/eval/get_alpaca_eval_ans.py
deleted file mode 100644
index 1ff3e5e..0000000
--- a/src/eval/get_alpaca_eval_ans.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import argparse
-from pathlib import Path
-
-import torch
-import datasets
-from tqdm import tqdm
-
-from src.core.train import get_model_and_tokenizer
-from src.utils.conversation import Conversation
-from src.utils.io import dump_json
-
-
-@torch.inference_mode()
-def run_eval(model_path, model_id, max_new_tokens):
- model, tokenizer = get_model_and_tokenizer(
- "auto",
- model_path,
- torch_dtype=torch.bfloat16,
- trust_remote_code=True,
- )
- model.cuda()
- model.eval()
-
- conv = Conversation()
- outputs = []
- eval_set = datasets.load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval")["eval"]
- for example in tqdm(eval_set, desc="Eval"):
- conv.append_message(conv.roles[0], example["instruction"])
- conv.append_message(conv.roles[1], None)
- prompt = conv.get_prompt()
- input_ids = tokenizer([prompt], return_tensors="pt").input_ids
- conv.clear_msg()
- # generate here is a placeholder for your models generations
- output_ids = model.generate(
- input_ids.cuda(),
- do_sample=False,
- temperature=0.0,
- max_new_tokens=max_new_tokens,
- )
- if model.config.is_encoder_decoder:
- output_ids = output_ids[0]
- else:
- output_ids = output_ids[0][len(input_ids[0]) :] # noqa: E203
- # be consistent with the template's stop_token_ids
- if conv.stop_token_ids:
- stop_token_ids_index = [
- i
- for i, id in enumerate(output_ids)
- if id in conv.stop_token_ids
- ]
- if len(stop_token_ids_index) > 0:
- output_ids = output_ids[: stop_token_ids_index[0]]
-
- output = tokenizer.decode(
- output_ids,
- spaces_between_special_tokens=False,
- )
- if conv.stop_str and isinstance(conv.stop_str, list):
- stop_str_indices = sorted(
- [
- output.find(stop_str)
- for stop_str in conv.stop_str
- if output.find(stop_str) > 0
- ]
- )
- if len(stop_str_indices) > 0:
- output = output[: stop_str_indices[0]]
- elif conv.stop_str and output.find(conv.stop_str) > 0:
- output = output[: output.find(conv.stop_str)]
-
- for special_token in tokenizer.special_tokens_map.values():
- if isinstance(special_token, list):
- for special_tok in special_token:
- output = output.replace(special_tok, "")
- else:
- output = output.replace(special_token, "")
- output = output.strip()
-
- if conv.name == "xgen" and output.startswith("Assistant:"):
- output = output.replace("Assistant:", "", 1).strip()
-
- example["output"] = output
- outputs.append(example)
-
- outpath = Path("results/alpaca_eval") / f"{model_id}.json"
- dump_json(outputs, outpath, indent=2)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--model-path",
- type=str,
- required=True,
- help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
- )
- parser.add_argument(
- "--model-id", type=str, required=True, help="A custom name for the model."
- )
- parser.add_argument(
- "--max-new-token",
- type=int,
- default=1024,
- help="The maximum number of new generated tokens.",
- )
-
- args = parser.parse_args()
-
- run_eval(
- model_path=args.model_path,
- model_id=args.model_id,
- max_new_tokens=args.max_new_token,
- )
diff --git a/src/eval/show.py b/src/eval/show.py
index d500054..ea0c210 100644
--- a/src/eval/show.py
+++ b/src/eval/show.py
@@ -55,13 +55,13 @@ def collect_results(result_dir: str, verbose: bool = True) -> dict:
avg = sum(vals) / len(vals)
tot_vals.append(avg)
if verbose:
- logger.info(f"task: {name}, num: {len(tasks.split(','))}, avg: {avg:.3%}")
+ logger.info(f"task: {name}, num: {len(tasks.split(','))}, avg: {100 * avg:.3f} %")
if len(tot_vals) == 0:
tot_avg = 0.0
else:
tot_avg = sum(tot_vals) / len(tot_vals)
- logger.info(f"total avg: {tot_avg:.3%}")
+ logger.info(f"total avg: {100 * tot_avg:.3f} %")
if __name__ == "__main__":
diff --git a/src/models/deepseek/modeling_deepseek.py b/src/models/deepseek/modeling_deepseek.py
index 1dae56e..20498b2 100644
--- a/src/models/deepseek/modeling_deepseek.py
+++ b/src/models/deepseek/modeling_deepseek.py
@@ -20,6 +20,7 @@
""" PyTorch DeepSeek model."""
import math
import warnings
+from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
@@ -297,7 +298,7 @@ class DeepseekMLP(nn.Module):
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
- def forward(self, x):
+ def forward(self, x, **kwargs):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
@@ -328,7 +329,9 @@ class DeepseekMLP(nn.Module):
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
+ bsz, seq_len, _ = x.shape
+ load = torch.zeros(bsz * seq_len, self.config.n_routed_experts)
+ return down_proj, load
class MoEGate(nn.Module):
@@ -356,7 +359,10 @@ class MoEGate(nn.Module):
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
def forward(self, hidden_states):
- bsz, seq_len, h = hidden_states.shape
+ if len(hidden_states.shape) == 2:
+ bsz, h = hidden_states.shape
+ else:
+ bsz, seq_len, h = hidden_states.shape
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
@@ -404,7 +410,10 @@ class MoEGate(nn.Module):
aux_loss = (Pi * fi).sum() * self.alpha
else:
aux_loss = None
- return topk_idx, topk_weight, aux_loss
+ _zeros = torch.zeros_like(logits)
+ _scores_filtered = _zeros.scatter(dim=1, index=topk_idx, src=topk_weight)
+ load = (_scores_filtered > 0).sum(0)
+ return topk_idx, topk_weight, aux_loss, load
class AddAuxiliaryLoss(torch.autograd.Function):
@@ -450,10 +459,19 @@ class DeepseekMoE(nn.Module):
config=config, intermediate_size=intermediate_size
)
- def forward(self, hidden_states):
+ def forward(self, hidden_states, attention_mask=None):
+ bsz, seq_len, hsz = hidden_states.shape
+ hidden_states = hidden_states.reshape(-1, hsz)
+ flattened_mask = None
+ flattened_shape = None
+ if attention_mask is not None and len(attention_mask.shape) == 2:
+ flattened_mask = attention_mask.flatten()
+ flattened_shape = flattened_mask.shape
+ hidden_states = hidden_states[flattened_mask.bool()]
+
identity = hidden_states
orig_shape = hidden_states.shape
- topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
+ topk_idx, topk_weight, aux_loss, load = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
@@ -472,7 +490,15 @@ class DeepseekMoE(nn.Module):
).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
- return y
+
+ if flattened_mask is not None:
+ _y = torch.zeros(flattened_shape + (hsz,), dtype=y.dtype, device=y.device)
+ _y[flattened_mask.bool()] = y
+ y = _y
+
+ y = y.reshape(bsz, seq_len, hsz)
+
+ return y, load
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
@@ -1163,7 +1189,7 @@ class DeepseekDecoderLayer(nn.Module):
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
+ hidden_states, load = self.mlp(hidden_states, attention_mask=attention_mask)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
@@ -1174,6 +1200,8 @@ class DeepseekDecoderLayer(nn.Module):
if use_cache:
outputs += (present_key_value,)
+ outputs += (load,)
+
return outputs
@@ -1220,6 +1248,11 @@ class DeepseekPreTrainedModel(PreTrainedModel):
module.weight.data[module.padding_idx].zero_()
+@dataclass
+class BaseMoEModelOutputWithPast(BaseModelOutputWithPast):
+ gate_load: Optional[torch.Tensor] = None
+
+
Deepseek_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1429,6 +1462,7 @@ class DeepseekModel(DeepseekPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
+ gate_load = ()
next_decoder_cache = None
for decoder_layer in self.layers:
@@ -1463,6 +1497,8 @@ class DeepseekModel(DeepseekPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
+ gate_load += (layer_outputs[-1],)
+
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
@@ -1482,14 +1518,20 @@ class DeepseekModel(DeepseekPreTrainedModel):
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
- return BaseModelOutputWithPast(
+ return BaseMoEModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
+ gate_load=gate_load,
)
+@dataclass
+class MoECausalLMOutputWithPast(CausalLMOutputWithPast):
+ gate_load: Optional[torch.Tensor] = None
+
+
class DeepseekForCausalLM(DeepseekPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
@@ -1620,12 +1662,13 @@ class DeepseekForCausalLM(DeepseekPreTrainedModel):
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
- return CausalLMOutputWithPast(
+ return MoECausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
+ gate_load=outputs.gate_load,
)
def prepare_inputs_for_generation(
diff --git a/src/utils/config.py b/src/utils/config.py
index 3ea5283..d4060d9 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -6,6 +6,7 @@ import torch
import transformers
from src.utils.io import load_json
+from src.data import get_uniform_sampling_ratio
@dataclass
@@ -33,7 +34,9 @@ class ModelArguments:
)
attn_impl: str = field(
default="flash_attention_2",
- metadata={"help": "attention implementation, choice from [eager, flash_attention_2, sdpa] (default: `flash_attention_2`)"}
+ metadata={
+ "help": "attention implementation, choice from [eager, flash_attention_2, sdpa] (default: `flash_attention_2`)"
+ },
)
def __post_init__(self):
@@ -56,6 +59,18 @@ class DataArguments:
default="data/merged",
metadata={"help": "Path to dataset directory or a single jsonl file"},
)
+ prob_map: str = field(
+ default=None,
+ metadata={"help": "Path to the probability map file"},
+ )
+
+ def __post_init__(self):
+ if self.prob_map is not None:
+ if not pathlib.Path(self.prob_map).exists():
+ raise ValueError(f"Probability map file {self.prob_map} not found")
+ self.prob_map = load_json(self.prob_map)
+ else:
+ self.prob_map = get_uniform_sampling_ratio(self.dataset_dir_or_path)
@dataclass
@@ -70,9 +85,7 @@ class TrainingArguments(transformers.TrainingArguments):
)
max_eval_steps_per_type: int = field(
default=10,
- metadata={
- "help": "Maximum number of steps to perform during evaluation."
- },
+ metadata={"help": "Maximum number of steps to perform during evaluation."},
)
dynamic_sampling_sim_type: Literal["cos", "l2"] = field(
default="l2",
@@ -88,7 +101,5 @@ class TrainingArguments(transformers.TrainingArguments):
)
freeze_gate: bool = field(
default=False,
- metadata={
- "help": "Whether to freeze the gate during training."
- },
+ metadata={"help": "Whether to freeze the gate during training."},
)
diff --git a/src/utils/visualization.py b/src/utils/visualization.py
index 794f6c8..02bd236 100644
--- a/src/utils/visualization.py
+++ b/src/utils/visualization.py
@@ -180,6 +180,86 @@ def gate_load_stats(model_dir, data_dir, result_dir, update_strategy: str = "cos
)
+def sampling_info_stats(filepath: str, data_type: str, output_dir: str):
+ from pathlib import Path
+ import numpy as np
+ from src.utils.io import load_jsonlines
+
+ Path(output_dir).mkdir(exist_ok=True, parents=True)
+
+ data = load_jsonlines(filepath)
+ step2data = {ins["step"]: ins for ins in data}
+
+ data_types = sorted(data[0]["old_prob_map"].keys())
+ data_type_idx = data_types.index(data_type)
+
+ probs = []
+ loads = []
+ sims = []
+ steps = sorted(step2data.keys())
+ for step in steps:
+ ins = step2data[step]
+ probs.append(ins["old_prob_map"][data_type])
+ loads.append(ins["name2load"][data_type])
+ sims.append(ins["sim"][data_type_idx])
+
+ # probs
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.plot(steps, probs)
+ ax.set_title(f"Sampling Probability of {data_type}")
+ ax.set_xlabel("step")
+ fig.savefig(f"{output_dir}/prob-{data_type}.png")
+
+ # loads
+ def cv_square(data):
+ return np.var(data, axis=1) / (np.mean(data, axis=1)**2 + 1e-10)
+
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.plot(steps, cv_square(loads))
+ ax.set_title(f"cv(load)^2 of {data_type}")
+ ax.set_xlabel("step")
+ fig.savefig(f"{output_dir}/load_cv-{data_type}.png")
+
+ # sims
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ ax.plot(steps, np.mean(sims, axis=1))
+ ax.set_title(f"Mean Similarities with {data_type}")
+ ax.set_xlabel("step")
+ fig.savefig(f"{output_dir}/sim-{data_type}.png")
+
+
+def test_sampling_convergence():
+ from collections import defaultdict
+ from src.callbacks import AdaptiveSamplingCallback
+
+ # freeze gate
+ name2load = {"code": [0.1359794776119403, 0.1333115671641791, 0.12858208955223882, 0.10330223880597016, 0.12544776119402984, 0.12625932835820897, 0.12761194029850748, 0.11950559701492537], "orca": [0.1509941502743006, 0.11721425756978752, 0.1232988815809414, 0.12714439426545024, 0.11256554420634679, 0.14008274482465977, 0.11819552632376563, 0.11050450095474797], "math": [0.15956486572028086, 0.10727138452881943, 0.11506675888262392, 0.10958069091633744, 0.11805010139847842, 0.11915200393871546, 0.13648938539627462, 0.13482480921846976], "sharegpt": [0.15337086599959998, 0.11428233411553493, 0.12873151621889287, 0.1177436980734424, 0.11538123789498336, 0.13793986642403783, 0.12419686111124664, 0.10835362016226212]} # fmt: skip
+ # # dynamic
+ # name2load = {"code": [0.14031716417910448, 0.1310634328358209, 0.12651119402985075, 0.10993470149253731, 0.12196828358208955, 0.12552238805970148, 0.12791977611940297, 0.11676305970149255], "orca": [0.15106234655836084, 0.11803640166095838, 0.12349968175067437, 0.12884551268450883, 0.11344072985178673, 0.1383778377231534, 0.11733170672566907, 0.1094057830448883], "math": [0.16001617686708006, 0.10756444371505268, 0.11391210568886491, 0.114803005615014, 0.11676650216277679, 0.1177863481308685, 0.13630182751708533, 0.13284959030325763], "sharegpt": [0.15440024978412215, 0.113654214863131, 0.12914741653941664, 0.12104040941178769, 0.11470799162832905, 0.13593110446537907, 0.12316259873058931, 0.10795601457724527]} # fmt: skip
+ names = sorted(name2load.keys())
+ callback = AdaptiveSamplingCallback()
+ callback.prob_map = {"code": 0.25, "math": 0.25, "orca": 0.25, "sharegpt": 0.25}
+ name2probs = defaultdict(list)
+ for _ in range(100):
+ for name in names:
+ name2probs[name].append(callback.prob_map[name])
+ new_name2prob, _ = callback._update_prob_map(name2load)
+ callback.prob_map = new_name2prob
+ print(f"final prob_map: {callback.prob_map}")
+
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ for name in names:
+ ax.plot(name2probs[name], label=name)
+ ax.legend()
+ ax.set_title("Sampling Probability")
+ ax.set_xlabel("step")
+ fig.savefig("results/sampling_convergence.png")
+
+
if __name__ == "__main__":
# gate_load_stats(
# "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
@@ -195,12 +275,12 @@ if __name__ == "__main__":
# "results/gate_load_vis_llama_moe_2_8_orca_4clusters",
# )
- gate_load_stats(
- "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
- "data/four_types_mix/dev",
- "results/debug",
- update_strategy="l2",
- )
+ # gate_load_stats(
+ # "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
+ # "data/four_types_mix/dev",
+ # "results/debug",
+ # update_strategy="l2",
+ # )
# gate_load_stats(
# "/mnt/petrelfs/zhutong/llama-moe-models/LLaMA-MoE-v1-3_5B-2_8-new",
@@ -227,3 +307,29 @@ if __name__ == "__main__":
# "results/gate_load_vis_llama_moe_2_8_four_types_mix_l2",
# update_strategy="l2"
# )
+
+ # sampling_info_stats(
+ # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_freeze_gate/moe_sft-2491632/sampling_info/data.jsonl",
+ # "code",
+ # "results/sampling_info/llama_moe_four_mix_wo_pad_freeze_gate/code",
+ # )
+
+ # sampling_info_stats(
+ # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad/moe_sft-2491633/sampling_info/data.jsonl",
+ # "code",
+ # "results/sampling_info/llama_moe_four_mix_wo_pad/code",
+ # )
+
+ # sampling_info_stats(
+ # "/mnt/petrelfs/zhutong/adaptive-sft-for-moe/outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_freeze_gate_wo_gate_noise/moe_sft-2493315/sampling_info/data.jsonl",
+ # "code",
+ # "results/sampling_info/llama_moe_four_mix_wo_pad_freeze_gate_wo_gate_noise/code",
+ # )
+
+ # sampling_info_stats(
+ # "outputs/len2048_dynamic_remove_padding_tokens/llama_moe_four_mix_wo_pad_wo_gate_noise/moe_sft-2492650/sampling_info/data.jsonl",
+ # "code",
+ # "results/sampling_info/llama_moe_four_mix_wo_pad_wo_gate_noise/code",
+ # )
+
+ test_sampling_convergence()