---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---

# transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).

## Update Log

* 2024/05/20 Added JGLUE 4-task benchmark scores.
* 2024/05/13 Added custom model code so that the model can be loaded with FlaxAutoModelForCausalLM.

## Source Code

We modified Flax's `lm1b` example to train on a Japanese dataset. The code is available on GitHub:

* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)

## Our Blog Post

* [【0.1Bから作るLLM】 JAX/Flaxで作るTransformer言語モデル](https://zenn.dev/fukugawa/articles/4446573ec0f697) (in Japanese; English title: "[Building an LLM from 0.1B] Building a Transformer Language Model with JAX/Flax")

## Model Details

| Model | Params | Layers | Dim | Heads | PPL | Dataset | Training time |
|-|-|-|-|-|-|-|-|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | 35.22 | wiki40b/ja | 1.5 days |

## Benchmarking

* **JGLUE 4-task (2024/05/22)**
    - *We used the [Stability-AI/lm-evaluation-harness](https://github.com/Stability-AI/lm-evaluation-harness) library for evaluation.*
    - *We modified the harness to support FlaxAutoModel so that it can evaluate JAX/Flax models.
      See the code [here](https://github.com/FookieMonster/lm-evaluation-harness).*
    - *We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.*
    - *All evaluations used version 0.3 (Alpaca) of the prompt template in a zero-shot setting.*
    - *Custom model revision used: [here](https://huggingface.co/fukugawa/transformer-lm-japanese-0.1b/commit/fe82d0f1366af71df8f8b383bf8de9ab6b0030be).*

| Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
| :-- | :-- | :-- | :-- | :-- | :-- |
| transformer-lm-japanese-0.1b | 41.41 | 35.21 | 43.59 | 78.63 | 8.24 |
| Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |

## Usage: FlaxAutoModel

#### Requirements:

```
# Quote version specifiers so the shell does not treat ">" as a redirect
# or "[...]" as a glob pattern.
pip install "transformers>=4.39.0"
pip install jax==0.4.13
pip install flax==0.6.11
pip install sentencepiece==0.1.99

# For CPU
pip install "jax[cpu]==0.4.13"

# For GPU
pip install -U "jax[cuda12_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Note: Set **trust_remote_code=True** to load our custom model.
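~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

# trust_remote_code=True is required because the model ships custom model code.
tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"  # "The capital of Japan is,"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

output_ids = model.generate(
    token_ids,
    do_sample=True,      # sample from the distribution instead of greedy decoding
    temperature=0.6,     # values below 1.0 sharpen the next-token distribution
    top_k=20,            # sample only among the 20 most likely next tokens
    max_new_tokens=100,  # upper bound on the number of generated tokens
)

# output_ids[0] holds the generated sequences (prompt included);
# the second [0] selects the first sequence in the batch.
output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
~~~~

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

* Machine type: c2-standard-4 (4 vCPUs, 16 GB memory)
* Disk: 100 GB (Standard Persistent Disk)
* OS: Ubuntu 22.04 LTS (x86-64)

## Dataset

* wiki40b/ja

## Tokenization

* [sentencepiece](https://github.com/google/sentencepiece)

## Author

[Ryoichi Fukugawa](https://zenn.dev/fukugawa)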