---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---
# transformer-lm-japanese-0.1b
This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).
## Update Log
- 2024/05/20 Added JGLUE 4-task benchmark scores.
- 2024/05/13 FlaxAutoModelForCausalLM is now supported with custom model code added.
## Source Code

We've modified Flax's 'lm1b' example to train on a Japanese dataset. You can find the code on GitHub.
Our Blog Post
## Model Details

| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|---|---|---|---|---|---|---|---|---|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | wiki40b/ja | 2.19GB | 1.5 days | 35.22 |
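As a rough sanity check on the 0.1B figure, the parameter count can be estimated from the table above. The MLP width and vocabulary size below are assumptions (4× the model dim and a 30k SentencePiece vocabulary), not values taken from the training config.

```python
# Back-of-envelope parameter estimate from the table: 12 layers, dim 768, 12 heads.
# mlp_dim (4 * dim) and vocab_size (30k) are assumptions, not confirmed config values.
dim, layers, mlp_dim, vocab_size = 768, 12, 4 * 768, 30_000

attention = 4 * dim * dim        # Q, K, V and output projections per layer
mlp = 2 * dim * mlp_dim          # two dense layers per feed-forward block
embeddings = vocab_size * dim    # token embedding table

total = layers * (attention + mlp) + embeddings
print(f"{total / 1e9:.2f}B parameters")  # ~0.11B, consistent with "0.1B"
```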
## Benchmarking

### JGLUE 4-task (2024/05/22)
- We used the Stability-AI/lm-evaluation-harness library for evaluation.
- We modified the harness to work with the FlaxAutoModel for evaluating JAX/Flax models. See the code here.
- We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.
- All evaluations used version 0.3 (Alpaca) of the prompt template in a zero-shot setting.
- The revision of the custom model used for evaluation is available here.
| Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
|---|---|---|---|---|---|
| transformer-lm-japanese-0.1b | 41.41 | 35.21 | 43.59 | 78.63 | 8.24 |
| Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |
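For context on what these scores measure: the harness evaluates multiple-choice tasks such as JCommonsenseQA in a zero-shot setting by ranking the answer candidates by the log-likelihood the model assigns to them. The sketch below illustrates that idea with the public FlaxAutoModel API; the prompt string and choices are illustrative only and do not reproduce the harness's 0.3 (Alpaca) template.

```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

def candidate_logprob(context: str, candidate: str) -> float:
    """Sum of log-probabilities the model assigns to `candidate` given `context`."""
    ctx_ids = tokenizer.encode(context, add_special_tokens=False)
    cand_ids = tokenizer.encode(candidate, add_special_tokens=False)
    input_ids = jnp.array([ctx_ids + cand_ids])
    logprobs = jax.nn.log_softmax(model(input_ids).logits, axis=-1)
    # The logits at position p predict the token at position p + 1.
    offset = len(ctx_ids)
    return float(sum(logprobs[0, offset + i - 1, tok] for i, tok in enumerate(cand_ids)))

# Illustrative question and choices -- not the harness's actual prompt format.
context = "質問: 日本の首都はどこですか。答え: "
choices = ["東京", "大阪", "京都"]
print(max(choices, key=lambda c: candidate_logprob(context, c)))
```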
## Usage: FlaxAutoModel

Requirements:

```sh
pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99
```
```sh
# For CPU
pip install -U "jax[cpu]==0.4.31"
# For GPU
pip install -U "jax[cuda12]==0.4.31"
```
Note: Set `trust_remote_code=True` to load our custom model.
```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

output_ids = model.generate(
    token_ids,
    do_sample=True,
    temperature=0.6,
    top_k=20,
    max_new_tokens=100
)

# generate() returns an output object whose first field holds the generated sequences.
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
```
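If you need reproducible output, you can switch off sampling; this is standard transformers `generate()` usage rather than anything specific to this model:

```python
# Greedy (deterministic) decoding: the same prompt always yields the same continuation.
output_ids = model.generate(
    token_ids,
    do_sample=False,
    max_new_tokens=100
)
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
```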
We tested text generation in a Python 3.10 environment on GCP with the following setup:
- Machine Type: c2-standard-4 (4 CPUs, 16GB Memory)
- Disk: 100GB
- OS: Ubuntu 22.04 LTS x86/64
## Dataset
- wiki40b/ja (2.19GB)
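The wiki40b/ja corpus is distributed through TensorFlow Datasets. A minimal loading sketch (assuming the `tensorflow-datasets` package is installed) follows; it shows the standard TFDS API, not the exact preprocessing used for training this model.

```python
import tensorflow_datasets as tfds

# Downloads and prepares the Japanese Wiki-40B training split on first use.
ds = tfds.load("wiki40b/ja", split="train")

for example in ds.take(1):
    # Each record contains the article text plus structural markers
    # such as _START_ARTICLE_ and _START_PARAGRAPH_.
    print(example["text"].numpy().decode("utf-8")[:200])
```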