---
license: apache-2.0
datasets:
- tiiuae/falcon-refinedweb
pipeline_tag: text-generation
library_name: openlm
tags:
- linear
- mistral
language:
- en
model-index:
- name: mistral-supra
  results:
  - task:
      type: text-generation
    dataset:
      type: MMLU
      name: MMLU
    metrics:
    - name: accuracy
      type: accuracy
      value: 34.2
      verified: false
  - task:
      type: text-generation
    dataset:
      type: HellaSwag
      name: HellaSwag
    metrics:
    - name: accuracy
      type: accuracy
      value: 77.1
      verified: false
  - task:
      type: text-generation
    dataset:
      type: PIQA
      name: PIQA
    metrics:
    - name: accuracy
      type: accuracy
      value: 80.4
      verified: false
  - task:
      type: text-generation
    dataset:
      type: Winogrande
      name: Winogrande
    metrics:
    - name: accuracy
      type: accuracy
      value: 70.3
      verified: false
  - task:
      type: text-generation
    dataset:
      type: ai2_arc
      name: ARC-E
    metrics:
    - name: accuracy
      type: accuracy
      value: 75.9
      verified: false
  - task:
      type: text-generation
    dataset:
      type: ai2_arc
      name: ARC-C
    metrics:
    - name: accuracy
      type: accuracy
      value: 45.8
      verified: false
---

# Mistral-SUPRA
This model was initialized from the weights of the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) transformer model and uptrained into a linear RNN.

This model accompanies our paper [Linearizing Large Language Models](https://arxiv.org/abs/2405.06640), where we detail our process for converting a softmax transformer into a linear transformer that, at inference time, can function as both a transformer and a recurrent model.
Our linear attention code is available at https://github.com/TRI-ML/linear_open_lm/.
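
The sketch below illustrates the idea in pure Python: causal linear attention computed once in a parallel, transformer-style form and once as a recurrence over a fixed-size state. It is a generic single-head example, not the SUPRA layer itself; the `elu(x) + 1` feature map (from the linear transformer of Katharopoulos et al.) and all function names are assumptions chosen for illustration.

```python
import math


def phi(x):
    """Positive feature map elu(x) + 1 (an assumed choice, not SUPRA's exact one)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def linear_attention_parallel(qs, ks, vs):
    """Transformer-style form: position t attends to every position s <= t."""
    out = []
    for t, q in enumerate(qs):
        fq = phi(q)
        weights = [dot(fq, phi(ks[s])) for s in range(t + 1)]
        norm = sum(weights)
        out.append([sum(w * vs[s][j] for s, w in enumerate(weights)) / norm
                    for j in range(len(vs[0]))])
    return out


def linear_attention_recurrent(qs, ks, vs):
    """RNN-style form: carry a fixed-size state (S, z) from one step to the next."""
    d_k, d_v = len(ks[0]), len(vs[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k_s) v_s^T
    z = [0.0] * d_k                        # running sum of phi(k_s)
    out = []
    for q, k, v in zip(qs, ks, vs):
        fk = phi(k)
        for i in range(d_k):
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[j]
        fq = phi(q)
        norm = dot(fq, z)
        out.append([sum(fq[i] * S[i][j] for i in range(d_k)) / norm
                    for j in range(d_v)])
    return out


# Toy sequence of 3 tokens with d_k = 2 and d_v = 3.
qs = [[0.1, -0.2], [0.3, 0.5], [-0.4, 0.2]]
ks = [[0.2, 0.1], [-0.3, 0.4], [0.5, -0.1]]
vs = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0], [0.0, 1.0, 1.0]]

par = linear_attention_parallel(qs, ks, vs)
rec = linear_attention_recurrent(qs, ks, vs)
print(par)
print(rec)
```

Both functions return the same outputs up to floating-point rounding. The recurrent version only keeps `S` and `z` between steps, so its memory does not grow with sequence length.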

We uptrain Mistral-7B on 100B tokens of RefinedWeb.


## Model Details
- **Developed by**: [Toyota Research Institute](https://www.tri.global/our-work/robotics)
- **Model Type**: This is an auto-regressive language model initialized from [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) and uptrained into a linear model based on the [SUPRA](https://arxiv.org/abs/2405.06640) architecture.
- **Dataset**: Initialized from [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1). Uptrained on 100B tokens of [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb).
- **Tokenizer**: `mistralai/Mistral-7B-v0.1`
- **Library**: [OpenLM](https://github.com/mlfoundations/open_lm/) (we use a [fork](https://github.com/TRI-ML/linear_open_lm/) of OpenLM that supports linear attention)
- **License**: This model is licensed under [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0).
 
| Parameters | Hidden Size | Layers | Vocab Size | Sequence Length | 
|------------|-------------|--------| ---------- | --------------- |
| 7B         | 4096        | 32     | 32000      | 2048            |

## Training Details
- Mistral-SUPRA was trained using AWS SageMaker on 128 H100 80GB GPUs.
- Training on 100B tokens finished in 1.5 days.

| **Hyperparameter** | **Value**  | 
|--------------------|------------|
| Precision          | `bfloat16` |
| Optimizer          | AdamW      |
| Learning rate      | 3e-5       |
| LR cooldown end    | 1e-5       |
| Warmup steps       | 1000       |
| Batch size         | 2M         |
| QK norm            | False      |
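
A quick back-of-the-envelope check of these numbers, in Python. It assumes that "2M" in the batch-size row means roughly 2,000,000 tokens per optimizer step and that the learning rate follows a linear warmup and then a cosine decay to the cooldown-end value; both are our reading of the table, and `lr_at` is a hypothetical helper rather than OpenLM's scheduler code.

```python
import math

# Figures from the training details above.
# Treating "2M" as 2,000,000 tokens per batch is an assumption.
TOTAL_TOKENS = 100e9
BATCH_TOKENS = 2e6
NUM_GPUS = 128
WALL_CLOCK_S = 1.5 * 24 * 3600  # 1.5 days in seconds

total_steps = int(TOTAL_TOKENS / BATCH_TOKENS)   # number of optimizer steps
tokens_per_s = TOTAL_TOKENS / WALL_CLOCK_S        # cluster-wide throughput
tokens_per_s_per_gpu = tokens_per_s / NUM_GPUS


def lr_at(step, base_lr=3e-5, end_lr=1e-5, warmup=1000, total=total_steps):
    """Hypothetical schedule: linear warmup to base_lr, then cosine decay to end_lr."""
    if step < warmup:
        return base_lr * (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return end_lr + 0.5 * (1 + math.cos(math.pi * progress)) * (base_lr - end_lr)


print(f"optimizer steps: {total_steps}")
print(f"throughput: {tokens_per_s:,.0f} tokens/s total, "
      f"{tokens_per_s_per_gpu:,.0f} tokens/s per GPU")
for step in (0, 999, 25_500, total_steps):
    print(step, f"{lr_at(step):.2e}")
```

Under these assumptions the run takes about 50,000 optimizer steps and processes roughly 6,000 tokens per second per GPU.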


## Usage
This model was trained using [OpenLM](https://github.com/mlfoundations/open_lm/). The weights have been converted to the Hugging Face format.

To use the model, first install our fork of OpenLM with pip:
```bash
pip install git+https://github.com/tri-ml/linear_open_lm.git
```

Then import the OpenLM classes:

```python
from open_lm.open_lm_hf import *
```

The model can then be loaded normally using `AutoTokenizer` and `AutoModelForCausalLM` as follows:

```python
from open_lm.open_lm_hf import *
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("tri-ml/mistral-supra")
model = AutoModelForCausalLM.from_pretrained("tri-ml/mistral-supra")

inputs = tokenizer(["Machine learning is"], return_tensors="pt")
gen_kwargs = {"max_new_tokens": 50, "top_p": 0.8, "temperature": 0.8, "do_sample": True, "repetition_penalty": 1.1}
output = model.generate(inputs['input_ids'], **gen_kwargs)
output = tokenizer.decode(output[0].tolist(), skip_special_tokens=True)
print(output)
# Machine learning is a branch of artificial intelligence (AI) that enables computers to learn from experience without being explicitly programmed. Machine learning is used in a wide range of applications, including spam filtering, image recognition, speech recognition, and computer-based medical diagnosis
```
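
For readers unfamiliar with the sampling arguments in `gen_kwargs`, the following pure-Python sketch shows what repetition penalty, temperature, and top-p (nucleus) filtering do to a vector of next-token logits. It is a simplified illustration of the general technique under our own assumptions (including the order in which the steps are applied), not the code that `transformers` runs inside `generate`.

```python
import math
import random


def apply_repetition_penalty(logits, generated_ids, penalty):
    """CTRL-style penalty: make already-generated tokens less likely."""
    out = list(logits)
    for tok in set(generated_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def softmax(logits, temperature):
    """Temperature < 1 sharpens the distribution; > 1 flattens it."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of most likely tokens whose total probability reaches top_p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}  # renormalized over kept tokens


def sample_next(logits, generated_ids, temperature=0.8, top_p=0.8,
                repetition_penalty=1.1, rng=random):
    logits = apply_repetition_penalty(logits, generated_ids, repetition_penalty)
    candidates = top_p_filter(softmax(logits, temperature), top_p)
    tokens, weights = zip(*candidates.items())
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy 5-token vocabulary; token 0 was already generated.
logits = [2.0, 1.5, 0.3, -1.0, -2.0]
print(top_p_filter(softmax(logits, 0.8), 0.8))
print(sample_next(logits, generated_ids=[0], rng=random.Random(0)))
```

With `top_p=0.8` only the two most likely tokens survive in this toy example, so sampling never picks the long tail.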

The Mistral-SUPRA model can be used in both parallel mode and recurrent mode. Passing `use_cache=False` to `model.generate(...)` selects parallel mode; otherwise, recurrent mode is used.
Recurrent mode relies on `xformers` and requires both the model and the inputs to be on the GPU.

```python
# Move the model and inputs to the GPU (required for recurrent mode)
model = model.to('cuda')
input_ids = inputs['input_ids'].to('cuda')

# Recurrent mode
output = model.generate(input_ids, use_cache=True, **gen_kwargs)

# Parallel mode
output = model.generate(input_ids, use_cache=False, **gen_kwargs)
```


## Performance Evaluation
We ran our evaluations with the [EleutherAI LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

Below we report the performance of Mistral-SUPRA compared to other similarly sized models.

<div class="evalTable">

|                   | HellaSwag     | PIQA     | Winogrande     | ARC-E     | ARC-C     | MMLU (5-shot)    |
| ----------------- | ------------- | -------- | -------------- | --------- | --------- | ---------------- |
| Llama2-7B         | 76.0          | 79.1     | 69.1           | 76.3      | 46.3      | 45.9             |
| Gemma-7B          | 80.7          | 81.9     | 73.7           | 81.1      | 53.2      | 62.9             |
| Mistral-7B        | 81.0          | 82.1     | 74.0           | 80.9      | 53.8      | 62.4             |
| RWKV5-1.7T-7B     | 73.0          | 78.6     | 72.9           | 75.8      | 45.6      | 34.9             |
| Mamba-7B          | 77.9          | 81.0     | 71.8           | 77.5      | 46.7      | 33.3             |
| **Mistral-SUPRA** | 77.1          | 80.4     | 70.3           | 75.9      | 45.8      | 34.2             |

</div>
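
As a rough summary, the snippet below averages the five non-MMLU scores for each model. This is plain arithmetic on the numbers in the table; the unweighted average is our own choice and not a metric reported in the paper.

```python
# Scores copied from the table above: HellaSwag, PIQA, Winogrande, ARC-E, ARC-C.
SCORES = {
    "Llama2-7B":     [76.0, 79.1, 69.1, 76.3, 46.3],
    "Gemma-7B":      [80.7, 81.9, 73.7, 81.1, 53.2],
    "Mistral-7B":    [81.0, 82.1, 74.0, 80.9, 53.8],
    "RWKV5-1.7T-7B": [73.0, 78.6, 72.9, 75.8, 45.6],
    "Mamba-7B":      [77.9, 81.0, 71.8, 77.5, 46.7],
    "Mistral-SUPRA": [77.1, 80.4, 70.3, 75.9, 45.8],
}

# Unweighted mean over the five tasks, rounded to two decimals.
averages = {name: round(sum(s) / len(s), 2) for name, s in SCORES.items()}

for name, avg in sorted(averages.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<15} {avg:.2f}")

gap = round(averages["Mistral-7B"] - averages["Mistral-SUPRA"], 2)
print(f"Mistral-SUPRA trails Mistral-7B by {gap:.2f} points")
```

By this measure Mistral-SUPRA averages 69.9, about 4.5 points below Mistral-7B (74.36) and close to Llama2-7B (69.36) and RWKV5-1.7T-7B (69.18).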


## How to Cite
If you use this model, please cite our paper on Linearizing Large Language Models.
```
@article{Mercat2024Linearizing,
  title={Linearizing Large Language Models},
  author={Jean Mercat and Igor Vasiljevic and Sedrick Keh and Kushal Arora and Achal Dave and Adrien Gaidon and Thomas Kollar},
  year={2024},
  journal={arXiv preprint arXiv:2405.06640},
}
```

## Citations
OpenLM
```
@misc{open_lm,
  author = {Gururangan, Suchin and Wortsman, Mitchell and Gadre, Samir Yitzhak and Dave, Achal and Kilian, Maciej and Shi, Weijia and Mercat, Jean and Smyrnis, Georgios and Ilharco, Gabriel and Jordan, Matt and Heckel, Reinhard and Dimakis, Alex and Farhadi, Ali and Shankar, Vaishaal and Schmidt, Ludwig},
  title = {{open_lm}:  a minimal but performative language modeling (LM) repository},
  year = {2023},
  note = {GitHub repository},
  url = {https://github.com/mlfoundations/open_lm/}
}
```