metadata
language:
  - ja
license: apache-2.0
tags:
  - ja
  - japanese
  - text-generation
  - lm
  - jax
  - flax
  - lm1b
datasets:
  - wiki40b

transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on Japanese text from the wiki40b/ja dataset. It is based on the lm1b example in the official Flax repository.

Update Log

  • 2024/05/20 Added JGLUE 4-task benchmark scores.
  • 2024/05/13 Added custom model code so the model can be loaded with FlaxAutoModelForCausalLM.

Source Code

We modified Flax's 'lm1b' example to train on a Japanese dataset. The code is available on GitHub.

Our Blog Post

Model Details

| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|---|---|---|---|---|---|---|---|---|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | wiki40b/ja | 2.19GB | 1.5 days | 35.22 |
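The PPL column reports perplexity. Perplexity is conventionally the exponential of the mean per-token negative log-likelihood on held-out text. The card does not state which split or tokenization produced 35.22, so the sketch below only illustrates the definition. It also converts the reported PPL back into the equivalent mean loss in nats per token. The helper name `perplexity` is illustrative and is not part of this repository.

```python
import math


def perplexity(token_log_probs):
    """Perplexity = exp(mean negative log-likelihood per token).

    token_log_probs: natural-log probabilities the model assigned to each
    ground-truth next token of the evaluation text.
    """
    if not token_log_probs:
        raise ValueError("need at least one token")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


# A model that always gives the correct token probability 0.5 has perplexity 2.
print(perplexity([math.log(0.5)] * 4))

# Inverting the definition: PPL 35.22 corresponds to a mean loss of ln(35.22) nats/token.
mean_loss = math.log(35.22)
print(f"{mean_loss:.3f}")  # 3.562
```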

Benchmarking

  • JGLUE 4-task (2024/05/22)

    • We used the Stability-AI/lm-evaluation-harness library for evaluation.
    • We modified the harness to support FlaxAutoModel so that it can evaluate JAX/Flax models. See the code here.
    • We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.
    • All evaluations used version 0.3 (Alpaca) of the prompt template in a zero-shot setting.
    • Custom model revision used for evaluation: here.
    | Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
    |---|---|---|---|---|---|
    | transformer-lm-japanese-0.1b | 41.41 | 35.21 | 43.59 | 78.63 | 8.24 |
    | Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |
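The card does not say how the Average column is computed. A natural reading is the unweighted mean of the four task scores. The sketch below checks that assumption against the JGLUE table. Recomputing from the per-task scores gives about 41.4175 and 40.755, which match the reported 41.41 and 40.75 to within 0.01. The names `SCORES` and `jglue_average` are illustrative.

```python
# Per-task scores copied from the JGLUE table.
# Assumption: "Average" is the unweighted mean of the four task scores.
SCORES = {
    "transformer-lm-japanese-0.1b": {
        "JCommonsenseQA": 35.21, "JNLI": 43.59, "MARC-ja": 78.63, "JSQuAD": 8.24,
    },
    "rinna/japanese-gpt-neox-small": {
        "JCommonsenseQA": 40.39, "JNLI": 29.13, "MARC-ja": 85.48, "JSQuAD": 8.02,
    },
}


def jglue_average(task_scores):
    """Unweighted mean over the tasks in task_scores."""
    return sum(task_scores.values()) / len(task_scores)


for model_name, task_scores in SCORES.items():
    print(f"{model_name}: {jglue_average(task_scores):.4f}")
```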

Usage: FlaxAutoModel

Requirements:

pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99

# For CPU
pip install -U "jax[cpu]==0.4.31"

# For GPU
pip install -U "jax[cuda12]==0.4.31"
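The requirements combine a minimum version (transformers) with exact pins (jax, flax, sentencepiece). You can confirm that an environment matches them by reading the installed versions with the standard library's `importlib.metadata`. The sketch below is a hypothetical helper and is not part of this repository. The functions `parse`, `check`, and `installed_versions` are illustrative names. `parse` deliberately ignores pre-release and local suffixes such as `.dev0` or `+cu121`, so it is only a rough check.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the install commands: transformers is a minimum, the rest are exact.
MIN_VERSIONS = {"transformers": "4.39.0"}
EXACT_VERSIONS = {"jax": "0.4.31", "flax": "0.8.3", "sentencepiece": "0.1.99"}


def parse(v):
    """'4.41.2.dev0' -> (4, 41, 2); pads short versions so '4.39' == '4.39.0'."""
    match = re.match(r"\d+(?:\.\d+)*", v)
    if match is None:
        raise ValueError(f"unrecognised version string: {v!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    return parts + (0,) * (3 - len(parts))


def check(installed):
    """installed maps package name -> version string, or None if missing.

    Returns a list of problems; an empty list means every pin is satisfied.
    """
    problems = []
    for name, minimum in MIN_VERSIONS.items():
        got = installed.get(name)
        if got is None:
            problems.append(f"{name} is not installed")
        elif parse(got) < parse(minimum):
            problems.append(f"{name} {got} < required {minimum}")
    for name, pinned in EXACT_VERSIONS.items():
        got = installed.get(name)
        if got is None:
            problems.append(f"{name} is not installed")
        elif parse(got) != parse(pinned):
            problems.append(f"{name} {got} != pinned {pinned}")
    return problems


def installed_versions():
    """Look up each pinned distribution in the current environment."""
    result = {}
    for name in list(MIN_VERSIONS) + list(EXACT_VERSIONS):
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for line in check(installed_versions()) or ["all pinned versions found"]:
        print(line)
```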

Note: Set trust_remote_code=True to load our custom model.

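from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# trust_remote_code=True is required because the model class is defined by custom code in the repository.
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"  # "The capital of Japan is,"
# Encode the prompt as a JAX array of shape (1, prompt_length), without special tokens.
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

# Sample up to 100 new tokens using temperature and top-k sampling.
outputs = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

# outputs.sequences has shape (batch, length) and contains the prompt followed by the generated tokens.
output = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
print(output)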
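The generate call samples with temperature=0.6 and top_k=20. The pure-Python sketch below shows roughly what these two settings do at each next-token step. It is an illustration only, not the transformers implementation, which applies equivalent logits warpers to JAX arrays inside `model.generate`. The function `sample_top_k` and the toy logits are hypothetical. Dividing by a temperature below 1 sharpens the distribution, and top-k then discards every token outside the k highest-scoring ones before sampling.

```python
import math
import random


def sample_top_k(logits, temperature=0.6, top_k=20, rng=random):
    """Draw one token id from `logits` (a list of floats indexed by token id).

    1. Divide logits by the temperature (< 1 makes the distribution sharper).
    2. Keep only the top_k highest-scoring token ids.
    3. Softmax over the survivors and sample one of them.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    k = min(top_k, len(scaled))
    top_ids = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)[:k]
    # Subtract the max before exponentiating for numerical stability.
    m = max(scaled[i] for i in top_ids)
    weights = [math.exp(scaled[i] - m) for i in top_ids]
    # random.choices accepts unnormalised weights.
    return rng.choices(top_ids, weights=weights, k=1)[0]


# Toy 5-token vocabulary: with top_k=2 only ids 0 and 1 can ever be drawn.
rng = random.Random(0)
logits = [2.0, 1.0, 0.5, -1.0, -3.0]
draws = [sample_top_k(logits, temperature=0.6, top_k=2, rng=rng) for _ in range(1000)]
print({i: draws.count(i) for i in sorted(set(draws))})
```

Setting top_k=1 reduces this to greedy decoding, because only the highest-scoring token survives the filter.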

We tested text generation with Python 3.10 on a GCP instance with the following configuration:

  • Machine Type: c2-standard-4 (4 CPUs, 16GB Memory)
  • Disk: 100GB
  • OS: Ubuntu 22.04 LTS x86/64

Dataset

Tokenization

Author

Ryoichi Fukugawa