---
language:
  - ja
  - en
license: mit
library_name: transformers
---

SpiralAI RetNet-3b-ja-base

We pre-trained a 3B-parameter model based on the RetNet architecture (https://arxiv.org/abs/2307.08621) from scratch on a mixed dataset of Japanese and English. This model is released primarily for basic research on the retention mechanism.
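For background, the retention operation from the RetNet paper comes in two mathematically equivalent forms: a parallel form (used for training) and a recurrent form with a constant-size state (used for inference). Below is a minimal single-head NumPy sketch with normalization and gating omitted for brevity; the function names and the fixed decay `gamma` are illustrative, not taken from this repository:

```python
import numpy as np

def parallel_retention(q, k, v, gamma=0.9):
    """Parallel form: (Q K^T * D) V, where D is a causal decay mask."""
    n = q.shape[0]
    i, j = np.arange(n)[:, None], np.arange(n)[None, :]
    # D[i, j] = gamma**(i - j) for j <= i, else 0 (causal decay).
    D = np.where(j <= i, float(gamma) ** (i - j), 0.0)
    return (q @ k.T * D) @ v

def recurrent_retention(q, k, v, gamma=0.9):
    """Recurrent form: constant-size state S_t = gamma * S_{t-1} + k_t v_t^T."""
    S = np.zeros((q.shape[1], v.shape[1]))
    out = np.empty((q.shape[0], v.shape[1]))
    for t in range(q.shape[0]):
        S = gamma * S + np.outer(k[t], v[t])
        out[t] = q[t] @ S
    return out
```

The recurrent form is what gives RetNet O(1) per-token inference cost, while the parallel form retains attention-like training parallelism.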

Model Description

  • Developed by: SpiralAI
  • Model type: The SpiralAI RetNet-3b-ja-base is a language model equipped with a retention mechanism. It uses the cyberagent/calm2-7b-chat tokenizer.
  • Languages: Japanese, English.
  • License: MIT
  • Training: Trained on 80B tokens.
  • Context Length: 2,048 tokens.

Installation

pip install transformers==4.38  # The top_k_top_p_filtering feature has been removed in later versions.
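The version pin matters presumably because the repository's generation code relies on `top_k_top_p_filtering`, which later versions of transformers removed. For reference, its behavior can be sketched in a few lines (an illustrative reimplementation for a 1-D logits array, not a drop-in replacement):

```python
import numpy as np

def top_k_top_p_filter(logits, top_k=0, top_p=1.0, filter_value=-np.inf):
    """Mask logits outside the top-k / top-p (nucleus) candidate set."""
    logits = logits.copy()
    if top_k > 0:
        kth = np.sort(logits)[-top_k]          # k-th largest logit
        logits[logits < kth] = filter_value
    if top_p < 1.0:
        order = np.argsort(logits)[::-1]       # descending by logit
        probs = np.exp(logits[order] - logits[order].max())
        probs /= probs.sum()
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        cutoff = np.searchsorted(np.cumsum(probs), top_p) + 1
        logits[order[cutoff:]] = filter_value
    return logits
```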

Clone the repository from https://github.com/syncdoth/RetNet and follow the Getting Started guide provided there.

Example:

git clone https://github.com/syncdoth/RetNet.git
cd RetNet
pip install torch transformers==4.38 timm

Usage

from transformers import AutoTokenizer
from retnet.modeling_retnet import RetNetForCausalLM

tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")
tokenizer.pad_token = tokenizer.eos_token

model = RetNetForCausalLM.from_pretrained(
    "Spiral-AI/RetNet-3b-ja-base", device_map="auto"
)
inputs = tokenizer("最近、秋葉原周辺で興味深い", return_tensors="pt")
input_ids = inputs["input_ids"].to(model.device)
generated = model.generate(
    input_ids,
    max_new_tokens=32,
    repetition_penalty=1.2,  # recommended for 3B-scale models to reduce repetition
)
print(tokenizer.decode(generated[0]))
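The `repetition_penalty` used above follows the CTRL-style penalty implemented by transformers' `RepetitionPenaltyLogitsProcessor`: at each step, the logit of every token already present in the sequence is divided by the penalty if positive and multiplied by it if negative. A minimal sketch (illustrative, operating on a 1-D logits array):

```python
import numpy as np

def apply_repetition_penalty(logits, generated_ids, penalty=1.2):
    """Downweight tokens already present in `generated_ids`."""
    logits = logits.copy()
    for tok in set(generated_ids):
        if logits[tok] > 0:
            logits[tok] /= penalty   # shrink positive logits
        else:
            logits[tok] *= penalty   # push negative logits further down
    return logits
```

With `penalty > 1`, previously generated tokens become less likely, which helps smaller models such as this 3B one avoid repetitive loops.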

Examples

input: 最近、秋葉原周辺で興味深い
output: お店がいくつかあります。
1. 神田カレー街「カレーハウスCoCo壱番屋」
2016年7月3日オープン
(English: "Recently, around Akihabara, there are some interesting shops. 1. Kanda Curry Street 'Curry House CoCo Ichibanya'. Opened July 3, 2016")

input: 近年、AI技術の進歩によって
output: 人間の仕事が奪われるのではないかという懸念がある。
しかしながら、AIは人間に取って代わるものではなく、「人間がコンピュータに仕事をさせる」という考え方
(English: "In recent years, advances in AI technology have raised concerns that human jobs may be taken away. However, AI is not a replacement for humans; rather, the idea that 'humans have computers do the work'")

input: When I was a child, I used to play with
output: 3-D glasses. They were so much fun!
I have been playing around in the world of video games for years now and it is amazing how

Basic study

Visualization of the retention mechanism

(Figure: visualization of retention weights)

This visualization shows the retention mechanism in action. The token being generated is represented by *, and the blue bars show how the preceding tokens are weighted during generation.

Using the mathematical equivalence between the "recurrent mode" and "parallel mode", we apply a visualization technique similar to the one used for the attention mechanism: the inner products between queries and keys are summed over all heads after taking their absolute values. We show the result for the last layer.
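The weighting described above can be sketched as follows (all names are illustrative; whether per-head decay enters the plotted quantity is an assumption, so it is omitted here):

```python
import numpy as np

def retention_weight_map(q_heads, k_heads):
    """Sum over heads of |q_i . k_j|, causally masked.
    q_heads, k_heads: (num_heads, seq_len, head_dim)."""
    n = q_heads.shape[1]
    mask = np.tril(np.ones((n, n)))      # only positions j <= i contribute
    total = np.zeros((n, n))
    for qh, kh in zip(q_heads, k_heads):
        total += np.abs(qh @ kh.T) * mask
    return total
```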

Test loss comparison

We compared the test loss of Spiral-AI/RetNet-3b-ja-base and cyberagent/open-calm-3b at different context lengths. The test dataset consists of the first 100 examples extracted from wikipedia-ja.

(Figure: test loss comparison across context lengths)

Key findings are:

  • The test loss of Spiral-AI/RetNet-3b-ja-base goes as low as that of cyberagent/open-calm-3b, showing the effectiveness of the retention mechanism.
  • For contexts longer than 2,048 tokens (the maximum context length of the training data; note that cyberagent/open-calm-3b was trained with the same context length), the explosion of test loss is suppressed in Spiral-AI/RetNet-3b-ja-base.
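A comparison like this can be reproduced by computing the per-position negative log-likelihood from any causal LM's logits and averaging it up to each context length. A model-agnostic helper (hypothetical names; the logits are assumed to come from a forward pass such as the one in Usage):

```python
import numpy as np

def per_position_nll(logits, input_ids):
    """Per-token negative log-likelihood of a causal LM.
    logits: (seq_len, vocab_size); input_ids: (seq_len,).
    The token at position t is predicted from the logits at position t-1."""
    shifted = logits[:-1]
    targets = np.asarray(input_ids)[1:]
    m = shifted.max(axis=-1, keepdims=True)          # for numerical stability
    log_z = (m + np.log(np.exp(shifted - m).sum(axis=-1, keepdims=True))).squeeze(-1)
    target_logits = shifted[np.arange(len(targets)), targets]
    return log_z - target_logits

def mean_loss_at(nll, context_length):
    """Average test loss over the first `context_length` predicted tokens."""
    return float(np.mean(nll[:context_length]))
```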

Training Datasets

Limitations

This model is designed for broad applicability, but it may not fully meet the specific needs or contexts of all uses. Pre-training data may contain inappropriate content, which could be reflected in the texts generated by the model. Therefore, when using this model, it is important to carefully review its output and avoid situations where it might cause discomfort or harm to individuals or groups.

There are no specific restrictions on commercial use, but users are responsible for addressing any ethical or legal issues that may arise in connection with the use of the model.