
transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code (lm1b).

Update Log

  • 2024/05/13: FlaxAutoModelForCausalLM is now supported via custom model code.

Source Code

We've modified Flax's lm1b example to train on Japanese datasets. You can find the code on GitHub.

Our Blog Post

Model Details

| Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
|---|---|---|---|---|---|---|---|
| lm1b-default | 0.05B | 6 | 512 | 8 | 22.67 | lm1b | 0.5 days |
| transformer-lm-japanese-default | 0.05B | 6 | 512 | 8 | 66.38 | cc100/ja | 0.5 days |
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |

[TensorBoard training logs]
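
In the table, PPL is perplexity, the exponential of the average per-token cross-entropy loss (lower is better). The architecture columns correspond to hyperparameters in the Flax lm1b example's config. The sketch below is a hypothetical reconstruction of the 0.1b settings: num_layers, emb_dim, and num_heads come from the table, while qkv_dim and mlp_dim are our assumptions, since the table lists only layers, dim, and heads.

```python
# Hypothetical sketch of the 0.1b hyperparameters, using the field names from
# the Flax lm1b example's default config.
from ml_collections import config_dict

config = config_dict.ConfigDict()
config.num_layers = 12   # Layers (from the table)
config.emb_dim = 768     # Dim (from the table)
config.num_heads = 12    # Heads (from the table)
config.qkv_dim = 768     # assumption: equal to emb_dim
config.mlp_dim = 3072    # assumption: 4 * emb_dim, the common convention
```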

Usage: FlaxAutoModel

Requirements:

```
pip install "transformers>=4.39.0"
pip install jax==0.4.13
pip install flax==0.6.11
pip install sentencepiece==0.1.99

# For CPU
pip install "jax[cpu]==0.4.13"

# For GPU
pip install --upgrade "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

(The version specifiers are quoted so the shell does not interpret `>=` or the brackets.)
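
After installing, a quick sanity check (not specific to this model) confirms the JAX version and which backend it picked up:

```python
import jax

print(jax.__version__)  # expect 0.4.13
print(jax.devices())    # e.g. [CpuDevice(id=0)] for the CPU install
```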

Note: Set trust_remote_code=True to load our custom model.

```python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# trust_remote_code=True loads the custom model code shipped with this repo
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"  # "The capital of Japan is ..."
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

# Sample up to 100 new tokens with temperature and top-k filtering
output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

# generate() returns a ModelOutput; output_ids[0] is the (batch, length)
# sequences array, and the second [0] selects the first sequence
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
```
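
Sampling with temperature=0.6 and top_k=20 gives varied output on each run; for reproducible output, a minimal variation is greedy decoding:

```python
# Greedy decoding: always pick the highest-probability next token
output_ids = model.generate(token_ids, do_sample=False, max_new_tokens=100)
print(tokenizer.decode(output_ids[0][0], skip_special_tokens=True))
```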

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

  • Machine Type: c2-standard-4 (4 vCPUs, 16 GB memory)
  • Disk: 100GB (Standard Persistent Disk)
  • OS: Ubuntu 22.04 LTS x86/64

Dataset

  • wiki40b/ja
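
The Japanese subset of Wiki-40B is available through TensorFlow Datasets. A minimal loading sketch, assuming the tensorflow-datasets package is installed (it is not among the requirements above):

```python
import tensorflow_datasets as tfds

# Loads the Japanese Wiki-40B training split; each example has a "text" field
ds = tfds.load("wiki40b/ja", split="train")
for example in ds.take(1):
    print(example["text"].numpy().decode("utf-8")[:200])
```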

Tokenization
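
The model uses a SentencePiece tokenizer (hence the sentencepiece requirement above), loaded automatically through AutoTokenizer. A quick sketch of the encode/decode round trip, reusing the tokenizer from the usage example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

# Encode to SentencePiece token IDs, then decode back to the original string
ids = tokenizer.encode("日本の首都は、", add_special_tokens=False)
print(ids)
print(tokenizer.decode(ids))  # => 日本の首都は、
```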

Author

Ryoichi Fukugawa
