---
language:
- ja
- en
license: mit
library_name: transformers
---

![SpiralAI RetNet-3b-ja-base](logo.png)

# SpiralAI RetNet-3b-ja-base

We have pre-trained a 3B-parameter model with the RetNet architecture (https://arxiv.org/abs/2307.08621) from scratch on a mixed dataset of Japanese and English.
This model is released primarily for basic research on the retention mechanism.

# Model Description

- **Developed by:** [SpiralAI](https://go-spiral.ai/)
- **Model type:** `SpiralAI RetNet-3b-ja-base` is a language model equipped with a retention mechanism. It uses the `cyberagent/calm2-7b-chat` tokenizer.
- **Languages:** Japanese, English.
- **License:** MIT
- **Training:** Trained on 80B tokens.
- **Context length:** 2,048 tokens.

# Installation

```bash
pip install transformers==4.38  # The top_k_top_p_filtering function was removed in later versions.
```

Clone the repository from https://github.com/syncdoth/RetNet and follow the *Getting Started* guide provided there.

Example:

```bash
git clone https://github.com/syncdoth/RetNet.git
pip install torch transformers timm
cd RetNet
```

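The pin matters because the RetNet implementation depends on `top_k_top_p_filtering`, which later `transformers` releases removed. A quick way to check whether an installed version is compatible (the exact removal version is not verified here):

```python
# Available in transformers 4.38; raises ImportError on releases that have
# removed it, in which case generation code relying on it will not run.
from transformers import top_k_top_p_filtering
```
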
# Usage

```python
from transformers import AutoTokenizer

from retnet.modeling_retnet import RetNetForCausalLM

tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
tokenizer.pad_token = tokenizer.eos_token

model = RetNetForCausalLM.from_pretrained(
    "Spiral-AI/RetNet-3b-base-ja", device_map="auto"
)
inputs = tokenizer("最近、秋葉原周辺で興味深い", return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
generated = model.generate(
    input_ids,
    max_new_tokens=32,
    repetition_penalty=1.2,  # recommended for this 3B model to reduce repetition
)
print(tokenizer.decode(generated[0]))
```

## Examples

```
input: 最近、秋葉原周辺で興味深い
output: お店がいくつかあります。
1. 神田カレー街「カレーハウスCoCo壱番屋」
2016年7月3日オープン
```

```
input: 近年、AI技術の進歩によって
output: 人間の仕事が奪われるのではないかという懸念がある。
しかしながら、AIは人間に取って代わるものではなく、「人間がコンピュータに仕事をさせる」という考え方
```

```
input: When I was a child, I used to play with
output: 3-D glasses. They were so much fun!
I have been playing around in the world of video games for years now and it is amazing how
```

# Basic Study

## Visualization of the retention mechanism

![retention](retention.gif)

This visualization shows the retention mechanism in action. The token being generated is represented by `*`, and the blue bars show how strongly each preceding token is weighted during generation.

Using the mathematical equivalence between the "recurrent mode" and "parallel mode" of retention, we apply a visualization technique similar to the one used for attention: the inner products between queries and keys are summed over all heads after taking their absolute values. Here we show the result of the last layer.

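For concreteness, the weighting described above can be computed as in the following minimal sketch. It assumes per-head `queries` and `keys` from the last layer have already been captured (e.g., with a forward hook); the tensor shapes and the function name are illustrative, not part of the repository's API.

```python
import torch

def retention_weights(queries: torch.Tensor, keys: torch.Tensor) -> torch.Tensor:
    """Token weights for visualization, from (num_heads, seq_len, head_dim) tensors.

    Mirrors the attention-map recipe described above: inner products between
    queries and keys, absolute values taken, then summed over heads.
    """
    # (num_heads, seq_len, seq_len): q_i . k_j for every position pair
    scores = torch.einsum("hqd,hkd->hqk", queries, keys)
    # Absolute value first, then sum over heads.
    # Row t holds the weights over previous tokens when generating token t.
    return scores.abs().sum(dim=0)
```
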
## Test loss comparison

We compared the test loss of `Spiral-AI/RetNet-3b-ja-base` and `cyberagent/open-calm-3b` at different context lengths.
The first 100 examples of `wikipedia-ja` are used as the test dataset. A minimal sketch of this measurement is given after the findings below.

![test_loss](loss_comparison.png)

Key findings:

- The test loss of `Spiral-AI/RetNet-3b-ja-base` goes as low as that of `cyberagent/open-calm-3b`, showing the effectiveness of the retention mechanism.
- The explosion of test loss is suppressed in `Spiral-AI/RetNet-3b-ja-base` when the context grows beyond 2,048 tokens, the maximum context length of the training data (note that `cyberagent/open-calm-3b` was trained with the same context length).

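The sketch below shows one way to measure such a per-length loss curve. It assumes `model` and `tokenizer` are loaded as in the Usage section and that the model follows the Hugging Face convention of returning the mean next-token loss when `labels` are passed; the function name and length buckets are illustrative, not the exact evaluation script.

```python
import torch

def loss_at_length(model, tokenizer, text: str, length: int) -> float:
    """Mean next-token cross-entropy over the first `length` tokens of `text`."""
    ids = tokenizer(text, return_tensors="pt")["input_ids"][:, :length].to(model.device)
    with torch.no_grad():
        # With labels=input_ids, an HF-style causal LM shifts the targets
        # internally and returns the average prediction loss.
        out = model(input_ids=ids, labels=ids)
    return out.loss.item()

# Illustrative usage: one document evaluated at several context lengths.
# lengths = [512, 1024, 2048, 4096]
# curve = [loss_at_length(model, tokenizer, doc, n) for n in lengths]
```
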
# Training Datasets

- [izumi-lab/cc100-ja-filter-ja-normal](https://huggingface.co/datasets/izumi-lab/cc100-ja-filter-ja-normal) (Japanese)
- [izumi-lab/wikipedia-ja-20230720](https://huggingface.co/datasets/izumi-lab/wikipedia-ja-20230720) (Japanese)
- [wikipedia](https://huggingface.co/datasets/wikipedia/tree/main/data/20220301.en) (English)
- [uonlp/CulturaX](https://huggingface.co/datasets/uonlp/CulturaX) (English, Japanese)

# Limitations

This model is designed for broad applicability, but it may not fully meet the specific needs or context of every use case.
The pre-training data may contain inappropriate content, which could be reflected in the texts the model generates. When using this model, carefully review its output and avoid situations where it might cause discomfort or harm to individuals or groups.

There are no specific restrictions on commercial use, but users are responsible for addressing any ethical or legal issues that may arise from use of the model.