---
license: other
---


# xLSTM-7B
xLSTM-7B was pre-trained on DCLM and selected high-quality data for a total of approximately 2.3T tokens using the `xlstm-jax` framework.


## How to use it
First, install `xlstm`, which now uses the `mlstm_kernels` package for its Triton kernels:

```bash
pip install xlstm
pip install mlstm_kernels
```
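A quick, optional sanity check that both packages are importable:

```python
# verify the install: both packages should import without errors
import mlstm_kernels
import xlstm
```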

For now, install the `transformers` repository fork from NX-AI (until it is merged upstream):
```bash
pip install 'transformers @ git+ssh://git@github.com/NX-AI/transformers.git@integrate_xlstm'
```
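The SSH URL requires GitHub SSH keys to be configured; if the fork is publicly accessible, the same install should also work over HTTPS (an assumption, not from the original instructions):

```bash
pip install 'transformers @ git+https://github.com/NX-AI/transformers.git@integrate_xlstm'
```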

Use the model as follows:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# load the weights; device_map="auto" places the model on the available GPU(s)
xlstm = AutoModelForCausalLM.from_pretrained("NX-AI/xLSTM-7b", device_map="auto")

# the tokenizer is a fork of EleutherAI/gpt-neox-20b
tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

tokens = tokenizer("Hello xLSTM, how are you doing?", return_tensors="pt")["input_ids"].to(device="cuda")

out = xlstm.generate(tokens, max_new_tokens=20)
print(tokenizer.decode(out[0]))
```
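For interactive use, output can also be streamed token by token; a minimal sketch using `transformers`' `TextStreamer` with the model and tokens loaded above:

```python
from transformers import TextStreamer

# print decoded tokens to stdout as they are generated, skipping the prompt
streamer = TextStreamer(tokenizer, skip_prompt=True)
xlstm.generate(tokens, max_new_tokens=50, streamer=streamer)
```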

## Speed results
Generation speed with `torch.cuda.graph` and `torch.compile` optimizations on a single NVIDIA H100:
![generation speed](plot_tokens_per_sec.svg)
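As a rough illustration of how such a measurement can be taken (a minimal sketch, not the benchmark script behind the plot; assumes the model and tokens from above and PyTorch >= 2.0):

```python
import time

import torch

# optionally compile the forward pass; generate() will then use the compiled version
xlstm.forward = torch.compile(xlstm.forward)

# warm-up run so compilation time is excluded from the measurement
xlstm.generate(tokens, max_new_tokens=8)

torch.cuda.synchronize()
start = time.perf_counter()
out = xlstm.generate(tokens, max_new_tokens=128)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
print(f"{(out.shape[1] - tokens.shape[1]) / elapsed:.1f} tokens/s")
```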

## Performance
![mmlu_train_token](MMLUvsTrainToken.svg)

Using `lm_eval` (EleutherAI's LM Evaluation Harness) in the Open LLM Leaderboard v2 settings:

| BBH   | MMLU-Pro | MATH  | MUSR  | GPQA  | IFEval |
|-------|----------|-------|-------|-------|--------|
| 0.381 | 0.242    | 0.036 | 0.379 | 0.280 | 0.244  |
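To reproduce, something along these lines should work with the harness's CLI (a sketch; the `leaderboard` task group is assumed to be available in your `lm_eval` version):

```bash
lm_eval --model hf \
    --model_args pretrained=NX-AI/xLSTM-7b,dtype=bfloat16 \
    --tasks leaderboard \
    --batch_size auto
```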

Using Hugging Face's `lighteval` in the Leaderboard-v1 settings:

| ARC-Challenge (25-shot) | MMLU (5-shot) | HellaSwag (10-shot) | Winogrande (5-shot) | TruthfulQA (0-shot) | GSM8K (5-shot) | OpenBookQA (5-shot) | PIQA (5-shot) |
|-------------------------|---------------|---------------------|---------------------|---------------------|----------------|---------------------|---------------|
| 0.584                   | 0.589         | 0.710               | 0.742               | 0.420               | 0.004          | 0.443               | 0.817         |

## License
NXAI Community License (see the `LICENSE` file).