---
language:
- ja
license: apache-2.0
tags:
- ja
- japanese
- text-generation
- lm
- jax
- flax
- lm1b
datasets:
- wiki40b
---
# transformer-lm-japanese-0.1b

This is a JAX/Flax-based transformer language model trained on a Japanese dataset. It is based on the official Flax example code ([lm1b](https://github.com/google/flax/tree/main/examples/lm1b)).

## Update Log
* 2024/05/20 Added JGLUE 4-task benchmark scores.
* 2024/05/13 Added custom model code; the model can now be loaded with FlaxAutoModelForCausalLM.

## Source Code

We modified Flax's official 'lm1b' example to train on a Japanese dataset. The code is available on GitHub.

* [transformer-lm-japanese](https://github.com/FookieMonster/transformer-lm-japanese)

## Our Blog Post

* [Building an LLM from 0.1B: A Transformer Language Model in JAX/Flax](https://zenn.dev/fukugawa/articles/4446573ec0f697) (in Japanese)

## Model Details

| Model | Params | Layers | Dim | Heads | Dataset | Dataset size | Training time | PPL |
|-|-|-|-|-|-|-|-|-|
| transformer-lm-japanese-0.1b | 0.1B | 12 | 768 | 12 | wiki40b/ja | 2.19GB | 1.5 days | 35.22 |
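
As a sanity check, the hyperparameters in the table roughly account for the 0.1B parameter count. The sketch below assumes a SentencePiece vocabulary of about 30,000 tokens, a feed-forward width of 4x the embedding dimension, and tied input/output embeddings; these are assumptions for illustration, not values read from the released config.

~~~~python
# Rough parameter-count estimate from the table's hyperparameters.
# Assumed (not from the released config): vocab_size, mlp_dim, tied embeddings;
# biases and layer-norm parameters are ignored.
vocab_size = 30_000        # assumed SentencePiece vocabulary size
emb_dim = 768              # "Dim" column
num_layers = 12            # "Layers" column
mlp_dim = 4 * emb_dim      # assumed feed-forward width

embedding = vocab_size * emb_dim             # token embedding (shared with the output head)
attn_per_layer = 4 * emb_dim * emb_dim       # Q, K, V and output projections
mlp_per_layer = 2 * emb_dim * mlp_dim        # the two feed-forward projections
total = embedding + num_layers * (attn_per_layer + mlp_per_layer)

print(f"~{total / 1e6:.0f}M parameters")     # ~108M, i.e. roughly 0.1B
~~~~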

## Benchmarking

* **JGLUE 4-task (2024/05/22)**

    - *We used the [Stability-AI/lm-evaluation-harness](https://github.com/Stability-AI/lm-evaluation-harness) library for evaluation.*
    - *We modified the harness to work with the FlaxAutoModel for evaluating JAX/Flax models. See the code [here](https://github.com/FookieMonster/lm-evaluation-harness).*
    - *We evaluated four tasks: JCommonsenseQA-1.1, JNLI-1.3, MARC-ja-1.1, and JSQuAD-1.1.*
    - *All evaluations used version 0.3 (Alpaca) of the prompt template in a zero-shot setting.*
    - *The revision of the custom model used for evaluation: [here](https://huggingface.co/fukugawa/transformer-lm-japanese-0.1b/commit/fe82d0f1366af71df8f8b383bf8de9ab6b0030be).*
   
    | Model | Average | JCommonsenseQA | JNLI | MARC-ja | JSQuAD |
    | :-- | :-- | :-- | :-- | :-- | :-- |
    | transformer-lm-japanese-0.1b | 41.41 | 35.21 | 43.59 | 78.63 | 8.24 |
    | Reference: rinna/japanese-gpt-neox-small | 40.75 | 40.39 | 29.13 | 85.48 | 8.02 |

## Usage: FlaxAutoModel

#### Requirements:

```sh
pip install "transformers>=4.39.0"
pip install jax==0.4.31
pip install flax==0.8.3
pip install sentencepiece==0.1.99

# For CPU
pip install -U "jax[cpu]==0.4.31"

# For GPU
pip install -U "jax[cuda12]==0.4.31"
```

Note: Set **trust_remote_code=True** to load our custom model.

~~~~python
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)
model = FlaxAutoModelForCausalLM.from_pretrained("fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True)

text = "日本の首都は、"
token_ids = tokenizer.encode(text, return_tensors="jax", add_special_tokens=False)

output_ids = model.generate(
  token_ids,
  do_sample=True,
  temperature=0.6,
  top_k=20,
  max_new_tokens=100
)

output = tokenizer.decode(output_ids[0][0], skip_special_tokens=True)
print(output)
~~~~

We tested text generation in a Python 3.10 environment on GCP with the following configuration:

* Machine Type: c2-standard-4 (4 vCPUs, 16 GB memory)
* Disk: 100GB
* OS: Ubuntu 22.04 LTS x86/64

## Dataset

* [wiki40b/ja](https://www.tensorflow.org/datasets/catalog/wiki40b?hl=ja#wiki40bja) (2.19GB)
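
If you want to inspect the training data, the wiki40b/ja subset can be loaded with TensorFlow Datasets. This is a minimal sketch for browsing examples, not the exact preprocessing pipeline used for training.

~~~~python
import tensorflow_datasets as tfds

# Load the Japanese subset of wiki40b (downloaded/prepared on first use).
ds = tfds.load("wiki40b/ja", split="train")

# Each example carries the article text as a byte string in the "text" field.
for example in ds.take(1):
    print(example["text"].numpy().decode("utf-8")[:200])
~~~~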
 
## Tokenization

* [sentencepiece](https://github.com/google/sentencepiece)
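
The tokenizer shipped with this model wraps a SentencePiece model and is loaded through the same AutoTokenizer call shown above. A short sketch of round-tripping a sentence (the exact subword pieces depend on the trained vocabulary):

~~~~python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "fukugawa/transformer-lm-japanese-0.1b", trust_remote_code=True
)

text = "日本の首都は、"
pieces = tokenizer.tokenize(text)                      # SentencePiece subword pieces
ids = tokenizer.encode(text, add_special_tokens=False)

print(pieces)
print(ids)
print(tokenizer.decode(ids))                           # should reproduce the original text
~~~~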

## Author

[Ryoichi Fukugawa](https://zenn.dev/fukugawa)