Upload folder using huggingface_hub
Browse files- lm_eval/.gitignore +10 -0
- lm_eval/.python-version +1 -0
- lm_eval/README.md +57 -0
- lm_eval/arc_challenge.yaml +26 -0
- lm_eval/arc_easy_mi.yaml +26 -0
- lm_eval/eval.py +65 -0
- lm_eval/lambada_openai_norm.yaml +23 -0
- lm_eval/main.py +6 -0
- lm_eval/pyproject.toml +64 -0
- lm_eval/requirements.txt +49 -0
- lm_eval/uv.lock +0 -0
- modeling_cloverlm.py +15 -3
lm_eval/.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
lm_eval/.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11
|
lm_eval/README.md
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
### Environment Setup
|
| 2 |
+
|
| 3 |
+
Download this directory to a local machine and set up [`uv`](https://docs.astral.sh/uv/).
|
| 4 |
+
|
| 5 |
+
1. **Install `uv`** (if you haven't already):
|
| 6 |
+
```bash
|
| 7 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
2. **Sync the environment:**
|
| 11 |
+
```bash
|
| 12 |
+
uv sync
|
| 13 |
+
```
|
| 14 |
+
*(This automatically creates a virtual environment at `.venv` and strictly installs the dependencies locked in `uv.lock`.)*
|
| 15 |
+
|
| 16 |
+
3. **Activate the environment:**
|
| 17 |
+
```bash
|
| 18 |
+
source .venv/bin/activate
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
### Evaluation Script
|
| 22 |
+
|
| 23 |
+
Run:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
accelerate launch eval.py \
|
| 27 |
+
--model cloverlm \
|
| 28 |
+
--model_args "pretrained=/localhome/apanfero/models/CloverLM,dtype=bfloat16,quartet_2_impl=quartet2,attn_backend=pytorch" \
|
| 29 |
+
--tasks "arc_easy_mi,arc_challenge_mi,hellaswag,piqa" \
|
| 30 |
+
--num_fewshot 0 \
|
| 31 |
+
--include_path ./ \
|
| 32 |
+
--trust_remote_code \
|
| 33 |
+
--confirm_run_unsafe_code \
|
| 34 |
+
--batch_size auto
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Expected Evaluation Results
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
|
| 41 |
+
|----------------|------:|------|-----:|---------------|---|-----:|---|-----:|
|
| 42 |
+
|arc_challenge_mi| 1|none | 0|acc |↑ |0.4625|± |0.0146|
|
| 43 |
+
| | |none | 0|acc_mutual_info|↑ |0.5094|± |0.0146|
|
| 44 |
+
| | |none | 0|acc_norm |↑ |0.4923|± |0.0146|
|
| 45 |
+
|arc_easy_mi | 1|none | 0|acc |↑ |0.7997|± |0.0082|
|
| 46 |
+
| | |none | 0|acc_mutual_info|↑ |0.7239|± |0.0092|
|
| 47 |
+
| | |none | 0|acc_norm |↑ |0.7731|± |0.0086|
|
| 48 |
+
|hellaswag | 1|none | 0|acc |↑ |0.5392|± |0.0050|
|
| 49 |
+
| | |none | 0|acc_norm |↑ |0.7167|± |0.0045|
|
| 50 |
+
|piqa | 1|none | 0|acc |↑ |0.7922|± |0.0095|
|
| 51 |
+
| | |none | 0|acc_norm |↑ |0.8058|± |0.0092|
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### Alternative Backends
|
| 55 |
+
|
| 56 |
+
Replace `quartet_2_impl=quartet2` with `quartet_2_impl=pseudoquant` on non-Blackwell GPUs.
|
| 57 |
+
You can try `attn_backend=pytorch/flash2/flash3/flash4` if you have the corresponding libs installed.
|
lm_eval/arc_challenge.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tag:
|
| 2 |
+
- ai2_arc
|
| 3 |
+
task: arc_challenge_mi
|
| 4 |
+
dataset_path: allenai/ai2_arc
|
| 5 |
+
dataset_name: ARC-Challenge
|
| 6 |
+
output_type: multiple_choice
|
| 7 |
+
training_split: train
|
| 8 |
+
validation_split: validation
|
| 9 |
+
test_split: test
|
| 10 |
+
doc_to_text: "Question: {{question}}\nAnswer:"
|
| 11 |
+
doc_to_target: "{{choices.label.index(answerKey)}}"
|
| 12 |
+
doc_to_choice: "{{choices.text}}"
|
| 13 |
+
should_decontaminate: true
|
| 14 |
+
doc_to_decontamination_query: "Question: {{question}}\nAnswer:"
|
| 15 |
+
metric_list:
|
| 16 |
+
- metric: acc
|
| 17 |
+
aggregation: mean
|
| 18 |
+
higher_is_better: true
|
| 19 |
+
- metric: acc_mutual_info
|
| 20 |
+
aggregation: mean
|
| 21 |
+
higher_is_better: true
|
| 22 |
+
- metric: acc_norm
|
| 23 |
+
aggregation: mean
|
| 24 |
+
higher_is_better: true
|
| 25 |
+
metadata:
|
| 26 |
+
version: 1.0
|
lm_eval/arc_easy_mi.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tag:
|
| 2 |
+
- ai2_arc
|
| 3 |
+
task: arc_easy_mi
|
| 4 |
+
dataset_path: allenai/ai2_arc
|
| 5 |
+
dataset_name: ARC-Easy
|
| 6 |
+
output_type: multiple_choice
|
| 7 |
+
training_split: train
|
| 8 |
+
validation_split: validation
|
| 9 |
+
test_split: test
|
| 10 |
+
doc_to_text: "Question: {{question}}\nAnswer:"
|
| 11 |
+
doc_to_target: "{{choices.label.index(answerKey)}}"
|
| 12 |
+
doc_to_choice: "{{choices.text}}"
|
| 13 |
+
should_decontaminate: true
|
| 14 |
+
doc_to_decontamination_query: "Question: {{question}}\nAnswer:"
|
| 15 |
+
metric_list:
|
| 16 |
+
- metric: acc
|
| 17 |
+
aggregation: mean
|
| 18 |
+
higher_is_better: true
|
| 19 |
+
- metric: acc_mutual_info
|
| 20 |
+
aggregation: mean
|
| 21 |
+
higher_is_better: true
|
| 22 |
+
- metric: acc_norm
|
| 23 |
+
aggregation: mean
|
| 24 |
+
higher_is_better: true
|
| 25 |
+
metadata:
|
| 26 |
+
version: 1.0
|
lm_eval/eval.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
from lm_eval.api.registry import register_model
|
| 5 |
+
from lm_eval.models.huggingface import HFLM
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@register_model("cloverlm")
class CloverLMHFLM(HFLM):
    """lm-eval HF model wrapper for CloverLM.

    Pads every forward/generate call so the sequence length is a multiple of
    ``pad_multiple`` (presumably required by the CloverLM/Quartet kernels —
    TODO confirm against the model implementation), then strips the padding
    from the outputs so lm-eval sees the original lengths.
    """

    def __init__(self, pad_multiple=128, **kwargs):
        # pad_multiple: sequence lengths are rounded up to this multiple.
        # All other kwargs are forwarded unchanged to HFLM.
        super().__init__(**kwargs)
        self.pad_multiple = pad_multiple

    def _encode_pair(self, context, continuation):
        """Encode (context, continuation), repairing empty continuations.

        Falls back to re-encoding the concatenated string when the parent
        implementation yields an empty continuation for a non-empty string
        (can happen when the tokenizer merges the boundary tokens), so that
        loglikelihood scoring always has at least one continuation token.
        """
        context_enc, continuation_enc = super()._encode_pair(context, continuation)

        if not continuation_enc and continuation:
            whole_enc = self.tok_encode(context + continuation)
            if len(whole_enc) > 1:
                # Score only the final token as the continuation.
                continuation_enc = whole_enc[-1:]
                context_enc = whole_enc[:-1]
            elif whole_enc:
                # Single token total: treat it all as continuation,
                # use the prefix token as a stand-in context.
                continuation_enc = whole_enc
                context_enc = [self.prefix_token_id]
            else:
                # Degenerate case: nothing encoded at all.
                continuation_enc = [self.prefix_token_id]
                context_enc = [self.prefix_token_id]

        return context_enc, continuation_enc

    def _model_call(self, inps: torch.Tensor, attn_mask: torch.Tensor = None, **kwargs):
        """Forward pass with right-padding to a multiple of ``pad_multiple``.

        Right padding is safe for causal scoring: logits at positions before
        ``orig_len`` are unaffected by later pad tokens, and the padded tail
        is sliced off before returning.
        """
        orig_len = inps.shape[1]
        remainder = orig_len % self.pad_multiple

        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            # NOTE(review): if the tokenizer has no pad token this passes
            # value=None (zero-fill) — harmless for causal right-padding,
            # but verify the tokenizer defines pad_token_id.
            inps = F.pad(inps, (0, pad_len), value=self.tokenizer.pad_token_id)
            if attn_mask is not None:
                attn_mask = F.pad(attn_mask, (0, pad_len), value=0)

        logits = super()._model_call(inps, attn_mask=attn_mask, **kwargs)
        if remainder != 0:
            # Drop logits for the artificial pad positions.
            logits = logits[:, :orig_len, :]
        return logits

    def _model_generate(self, context, max_length, **kwargs):
        """Generation with left-padding to a multiple of ``pad_multiple``.

        Left padding keeps the real prompt adjacent to the generated tokens;
        the pad prefix is removed from the returned sequence.
        """
        orig_len = context.shape[1]
        remainder = orig_len % self.pad_multiple

        if remainder != 0:
            pad_len = self.pad_multiple - remainder
            context = F.pad(context, (pad_len, 0), value=self.tokenizer.pad_token_id)
            if "attention_mask" in kwargs and kwargs["attention_mask"] is not None:
                # Mask out the pad prefix so attention ignores it.
                kwargs["attention_mask"] = F.pad(kwargs["attention_mask"], (pad_len, 0), value=0)

        out = super()._model_generate(context, max_length, **kwargs)
        if remainder != 0:
            # Output includes the padded prompt; strip the pad prefix.
            out = out[:, pad_len:]

        return out
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
    # Delegate to lm-eval's standard CLI; the @register_model decorator above
    # has already made the "cloverlm" model type available to it.
    from lm_eval.__main__ import cli_evaluate
    cli_evaluate()
|
lm_eval/lambada_openai_norm.yaml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tag:
|
| 2 |
+
- lambada
|
| 3 |
+
task: lambada_openai_norm
|
| 4 |
+
dataset_path: EleutherAI/lambada_openai
|
| 5 |
+
dataset_name: default
|
| 6 |
+
output_type: loglikelihood
|
| 7 |
+
test_split: test
|
| 8 |
+
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
|
| 9 |
+
doc_to_target: "{{' '+text.split(' ')[-1]}}"
|
| 10 |
+
should_decontaminate: true
|
| 11 |
+
doc_to_decontamination_query: "{{text}}"
|
| 12 |
+
metric_list:
|
| 13 |
+
- metric: perplexity
|
| 14 |
+
aggregation: perplexity
|
| 15 |
+
higher_is_better: false
|
| 16 |
+
- metric: acc
|
| 17 |
+
aggregation: mean
|
| 18 |
+
higher_is_better: true
|
| 19 |
+
- metric: acc_norm
|
| 20 |
+
aggregation: mean
|
| 21 |
+
higher_is_better: true
|
| 22 |
+
metadata:
|
| 23 |
+
version: 1.0
|
lm_eval/main.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def main():
    """Print a greeting confirming the package entry point works."""
    greeting = "Hello from lm-eval!"
    print(greeting)


if __name__ == "__main__":
    main()
|
lm_eval/pyproject.toml
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "cloverlm-eval"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "lm-eval harness environment for evaluating CloverLM"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"accelerate>=1.13.0",
|
| 9 |
+
"apache-tvm-ffi==0.1.9",
|
| 10 |
+
"certifi==2026.2.25",
|
| 11 |
+
"charset-normalizer==3.4.6",
|
| 12 |
+
"click==8.3.1",
|
| 13 |
+
"cuda-bindings==13.0.3",
|
| 14 |
+
"cuda-pathfinder==1.4.3",
|
| 15 |
+
"cuda-python==13.0.3",
|
| 16 |
+
"einops==0.8.2",
|
| 17 |
+
"filelock==3.25.2",
|
| 18 |
+
"flashinfer-python==0.6.6",
|
| 19 |
+
"fsspec==2026.2.0",
|
| 20 |
+
"idna==3.11",
|
| 21 |
+
"jinja2==3.1.6",
|
| 22 |
+
"lm-eval>=0.4.11",
|
| 23 |
+
"markupsafe==3.0.3",
|
| 24 |
+
"mpmath==1.3.0",
|
| 25 |
+
"networkx==3.6.1",
|
| 26 |
+
"ninja==1.13.0",
|
| 27 |
+
"numpy==2.4.3",
|
| 28 |
+
"nvidia-cublas==13.1.0.3",
|
| 29 |
+
"nvidia-cuda-cupti==13.0.85",
|
| 30 |
+
"nvidia-cuda-nvrtc==13.0.88",
|
| 31 |
+
"nvidia-cuda-runtime==13.0.96",
|
| 32 |
+
"nvidia-cudnn-cu13==9.15.1.9",
|
| 33 |
+
"nvidia-cudnn-frontend==1.19.0",
|
| 34 |
+
"nvidia-cufft==12.0.0.61",
|
| 35 |
+
"nvidia-cufile==1.15.1.6",
|
| 36 |
+
"nvidia-curand==10.4.0.35",
|
| 37 |
+
"nvidia-cusolver==12.0.4.66",
|
| 38 |
+
"nvidia-cusparse==12.6.3.3",
|
| 39 |
+
"nvidia-cusparselt-cu13==0.8.0",
|
| 40 |
+
"nvidia-cutlass-dsl==4.4.2",
|
| 41 |
+
"nvidia-cutlass-dsl-libs-base==4.4.2",
|
| 42 |
+
"nvidia-ml-py==13.590.48",
|
| 43 |
+
"nvidia-nccl-cu13==2.28.9",
|
| 44 |
+
"nvidia-nvjitlink==13.0.88",
|
| 45 |
+
"nvidia-nvshmem-cu13==3.4.5",
|
| 46 |
+
"nvidia-nvtx==13.0.85",
|
| 47 |
+
"nvtx==0.2.15",
|
| 48 |
+
"packaging==26.0",
|
| 49 |
+
"quartet2",
|
| 50 |
+
"requests==2.32.5",
|
| 51 |
+
"scipy==1.17.1",
|
| 52 |
+
"sympy==1.14.0",
|
| 53 |
+
"tabulate==0.10.0",
|
| 54 |
+
"tokenmonster>=1.1.12",
|
| 55 |
+
"torch==2.10.0+cu130",
|
| 56 |
+
"tqdm==4.67.3",
|
| 57 |
+
"transformers>=5.3.0",
|
| 58 |
+
"triton==3.6.0",
|
| 59 |
+
"typing-extensions==4.15.0",
|
| 60 |
+
"urllib3==2.6.3",
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
[tool.uv.sources]
|
| 64 |
+
quartet2 = { git = "https://github.com/IST-DASLab/Quartet-II.git", subdirectory = "kernels", rev = "0a0d60c51602a78ae530944047e9e4973485bfef" }
|
lm_eval/requirements.txt
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
apache-tvm-ffi==0.1.9
|
| 2 |
+
certifi==2026.2.25
|
| 3 |
+
charset-normalizer==3.4.6
|
| 4 |
+
click==8.3.1
|
| 5 |
+
cuda-bindings==13.0.3
|
| 6 |
+
cuda-pathfinder==1.4.3
|
| 7 |
+
cuda-python==13.0.3
|
| 8 |
+
einops==0.8.2
|
| 9 |
+
filelock==3.25.2
|
| 10 |
+
flashinfer-python==0.6.6
|
| 11 |
+
fsspec==2026.2.0
|
| 12 |
+
idna==3.11
|
| 13 |
+
jinja2==3.1.6
|
| 14 |
+
markupsafe==3.0.3
|
| 15 |
+
mpmath==1.3.0
|
| 16 |
+
networkx==3.6.1
|
| 17 |
+
ninja==1.13.0
|
| 18 |
+
numpy==2.4.3
|
| 19 |
+
nvidia-cublas==13.1.0.3
|
| 20 |
+
nvidia-cuda-cupti==13.0.85
|
| 21 |
+
nvidia-cuda-nvrtc==13.0.88
|
| 22 |
+
nvidia-cuda-runtime==13.0.96
|
| 23 |
+
nvidia-cudnn-cu13==9.15.1.9
|
| 24 |
+
nvidia-cudnn-frontend==1.20.0
|
| 25 |
+
nvidia-cufft==12.0.0.61
|
| 26 |
+
nvidia-cufile==1.15.1.6
|
| 27 |
+
nvidia-curand==10.4.0.35
|
| 28 |
+
nvidia-cusolver==12.0.4.66
|
| 29 |
+
nvidia-cusparse==12.6.3.3
|
| 30 |
+
nvidia-cusparselt-cu13==0.8.0
|
| 31 |
+
nvidia-cutlass-dsl==4.4.2
|
| 32 |
+
nvidia-cutlass-dsl-libs-base==4.4.2
|
| 33 |
+
nvidia-ml-py==13.590.48
|
| 34 |
+
nvidia-nccl-cu13==2.28.9
|
| 35 |
+
nvidia-nvjitlink==13.0.88
|
| 36 |
+
nvidia-nvshmem-cu13==3.4.5
|
| 37 |
+
nvidia-nvtx==13.0.85
|
| 38 |
+
nvtx==0.2.15
|
| 39 |
+
packaging==26.0
|
| 40 |
+
quartet2 @ git+https://github.com/IST-DASLab/Quartet-II.git@0a0d60c51602a78ae530944047e9e4973485bfef#subdirectory=kernels
|
| 41 |
+
requests==2.32.5
|
| 42 |
+
scipy==1.17.1
|
| 43 |
+
sympy==1.14.0
|
| 44 |
+
tabulate==0.10.0
|
| 45 |
+
torch==2.10.0+cu130
|
| 46 |
+
tqdm==4.67.3
|
| 47 |
+
triton==3.6.0
|
| 48 |
+
typing-extensions==4.15.0
|
| 49 |
+
urllib3==2.6.3
|
lm_eval/uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_cloverlm.py
CHANGED
|
@@ -111,15 +111,27 @@ class MHSA(nn.Module):
|
|
| 111 |
|
| 112 |
dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
|
| 113 |
if attn_backend == "flash2":
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 116 |
elif attn_backend == "flash3":
|
| 117 |
import importlib
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 120 |
elif attn_backend == "flash4":
|
| 121 |
import importlib
|
| 122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
|
| 124 |
Y = Y.to(Q.dtype).flatten(-2, -1)
|
| 125 |
|
|
|
|
| 111 |
|
| 112 |
dtype = Q.dtype if Q.dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
|
| 113 |
if attn_backend == "flash2":
|
| 114 |
+
try:
|
| 115 |
+
import flash_attn
|
| 116 |
+
except ImportError as e:
|
| 117 |
+
e.add_note(f"Can't run `attn_backend=flash2` because can't import flash_attn")
|
| 118 |
+
raise e
|
| 119 |
Y = flash_attn.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 120 |
elif attn_backend == "flash3":
|
| 121 |
import importlib
|
| 122 |
+
try:
|
| 123 |
+
_fa3 = importlib.import_module("flash_attn_interface")
|
| 124 |
+
except ImportError as e:
|
| 125 |
+
e.add_note(f"Can't run `attn_backend=flash3` because can't import flash_attn_interface")
|
| 126 |
+
raise e
|
| 127 |
Y = _fa3.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)
|
| 128 |
elif attn_backend == "flash4":
|
| 129 |
import importlib
|
| 130 |
+
try:
|
| 131 |
+
_fa4 = importlib.import_module("flash_attn.cute")
|
| 132 |
+
except ImportError as e:
|
| 133 |
+
e.add_note(f"Can't run `attn_backend=flash4` because can't import flash_attn.cute")
|
| 134 |
+
raise e
|
| 135 |
Y = _fa4.flash_attn_func(Q.to(dtype), K.to(dtype), V.to(dtype), causal=True, softmax_scale=1.0)[0]
|
| 136 |
Y = Y.to(Q.dtype).flatten(-2, -1)
|
| 137 |
|