Text Generation
Transformers
PyTorch
Safetensors
Japanese
gpt_neox
lm
nlp
text-generation-inference
Edit model card

japanese-gpt-neox-3.6b-instruction-ppo

rinna-icon

Overview

This repository provides a Japanese GPT-NeoX model of 3.6 billion parameters. The model is based on rinna/japanese-gpt-neox-3.6b-instruction-sft-v2 and has been aligned to serve as an instruction-following conversational agent.

Limitations

  • We found this verison of PPO model tends to generate repeated text more often than its SFT counterpart, and thus we set repetition_penalty=1.1 for better generation performance. (The same generation hyper-parameters are applied to the SFT model in aforementioned evaluation experiments.) You can also explore other hyperparameter combinations that yield higher generation randomness/diversity for better generation quality, e.g. temperature=0.9, repetition_penalty=1.0.

I/O Format

A special format has been adopted to construct inputs.

  • An input prompt is formatted as a conversation between ใƒฆใƒผใ‚ถใƒผ and ใ‚ทใ‚นใƒ†ใƒ .
  • Each input utterance consists of (1) its speaker ("ใƒฆใƒผใ‚ถใƒผ" or "ใ‚ทใ‚นใƒ†ใƒ "), (2) a colon (":"), (3) a whitespace (" "), and (4) utterance text (e.g. "ไธ–็•Œใงไธ€็•ช้ซ˜ใ„ๅฑฑใฏ๏ผŸ").
  • The input prompt should be ended with "ใ‚ทใ‚นใƒ†ใƒ : " to acknowledge the model to generate a response.
  • Since the model's tokenizer does not recognize "\n", a special newline symbol "<NL>" is used instead.
  • All the newlines in input and output utterances should be replaced with "<NL>".
  • All the utterances in the input prompt should be separated by "<NL>".

Following is an example to construct an input from a conversation.

prompt = [
    {
        "speaker": "ใƒฆใƒผใ‚ถใƒผ",
        "text": "ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚’ๆ…ฃใ‚Œใ‚‹ใซใฏใฉใ†ใ™ใ‚Œใฐใ‚ˆใ„ใงใ™ใ‹๏ผŸ"
    },
    {
        "speaker": "ใ‚ทใ‚นใƒ†ใƒ ",
        "text": "ใ“ใ‚Œใซใคใ„ใฆๅ…ทไฝ“็š„ใซ่ชฌๆ˜Žใ—ใฆใ„ใŸใ ใ‘ใพใ™ใ‹๏ผŸไฝ•ใŒ้›ฃใ—ใ„ใฎใงใ—ใ‚‡ใ†ใ‹๏ผŸ"
    },
    {
        "speaker": "ใƒฆใƒผใ‚ถใƒผ",
        "text": "็›ฎใŒ็—›ใ„ใฎใงใ™ใ€‚"
    },
    {
        "speaker": "ใ‚ทใ‚นใƒ†ใƒ ",
        "text": "ๅˆ†ใ‹ใ‚Šใพใ—ใŸใ€ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚’ใคใ‘ใ‚‹ใจ็›ฎใŒใ‹ใ‚†ใใชใ‚‹ใจใ„ใ†ใ“ใจใงใ™ใญใ€‚ๆ€ใฃใŸไปฅไธŠใซใƒฌใƒณใ‚บใ‚’ๅค–ใ™ๅฟ…่ฆใŒใ‚ใ‚‹ใงใ—ใ‚‡ใ†ใ‹๏ผŸ"
    },
    {
        "speaker": "ใƒฆใƒผใ‚ถใƒผ",
        "text": "ใ„ใˆใ€ใƒฌใƒณใ‚บใฏๅค–ใ—ใพใ›ใ‚“ใŒใ€็›ฎใŒ่ตคใใชใ‚‹ใ‚“ใงใ™ใ€‚"
    }
]
prompt = [
    f"{uttr['speaker']}: {uttr['text']}"
    for uttr in prompt
]
prompt = "<NL>".join(prompt)
prompt = (
    prompt
    + "<NL>"
    + "ใ‚ทใ‚นใƒ†ใƒ : "
)
print(prompt)
# "ใƒฆใƒผใ‚ถใƒผ: ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚’ๆ…ฃใ‚Œใ‚‹ใซใฏใฉใ†ใ™ใ‚Œใฐใ‚ˆใ„ใงใ™ใ‹๏ผŸ<NL>ใ‚ทใ‚นใƒ†ใƒ : ใ“ใ‚Œใซใคใ„ใฆๅ…ทไฝ“็š„ใซ่ชฌๆ˜Žใ—ใฆใ„ใŸใ ใ‘ใพใ™ใ‹๏ผŸไฝ•ใŒ้›ฃใ—ใ„ใฎใงใ—ใ‚‡ใ†ใ‹๏ผŸ<NL>ใƒฆใƒผใ‚ถใƒผ: ็›ฎใŒ็—›ใ„ใฎใงใ™ใ€‚<NL>ใ‚ทใ‚นใƒ†ใƒ : ๅˆ†ใ‹ใ‚Šใพใ—ใŸใ€ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚’ใคใ‘ใ‚‹ใจ็›ฎใŒใ‹ใ‚†ใใชใ‚‹ใจใ„ใ†ใ“ใจใงใ™ใญใ€‚ๆ€ใฃใŸไปฅไธŠใซใƒฌใƒณใ‚บใ‚’ๅค–ใ™ๅฟ…่ฆใŒใ‚ใ‚‹ใงใ—ใ‚‡ใ†ใ‹๏ผŸ<NL>ใƒฆใƒผใ‚ถใƒผ: ใ„ใˆใ€ใƒฌใƒณใ‚บใฏๅค–ใ—ใพใ›ใ‚“ใŒใ€็›ฎใŒ่ตคใใชใ‚‹ใ‚“ใงใ™ใ€‚<NL>ใ‚ทใ‚นใƒ†ใƒ : "

How to use the model

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt-neox-3.6b-instruction-ppo")

if torch.cuda.is_available():
    model = model.to("cuda")

token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")

with torch.no_grad():
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        max_new_tokens=128,
        temperature=0.7,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

output = tokenizer.decode(output_ids.tolist()[0][token_ids.size(1):])
output = output.replace("<NL>", "\n")
print(output)
"""ใใ‚Œใฏใ€ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใŒ็›ฎใซๅˆใ‚ใชใ„ใŸใ‚ใซ่ตทใ“ใ‚‹ใ“ใจใŒใ‚ใ‚Šใพใ™ใ€‚ใƒฌใƒณใ‚บใŒ็›ฎใฎ่กจ้ขใซ้•ทๆ™‚้–“่งฆใ‚Œ็ถšใ‘ใ‚‹ใ“ใจใŒๅŽŸๅ› ใจใชใ‚‹ใ“ใจใŒใ‚ใ‚Šใพใ™ใ€‚ใพใŸใ€ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใŒๆฑšใ‚Œใฆใ„ใ‚‹ๅฏ่ƒฝๆ€งใ‚‚ใ‚ใ‚Šใพใ™ใ€‚ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚ฑใƒผใ‚นใ‚’ๅฎšๆœŸ็š„ใซๆด—ๆต„ใ—ใŸใ‚Šใ€ใ‚ณใƒณใ‚ฟใ‚ฏใƒˆใƒฌใƒณใ‚บใ‚’ๆญฃใ—ใใƒ•ใ‚ฃใƒƒใƒˆใ•ใ›ใ‚‹ใ‚ˆใ†ใซใ—ใŸใ‚Šใ™ใ‚‹ใ“ใจใŒๅฝน็ซ‹ใกใพใ™ใ€‚</s>"""

Tokenization

The model uses a sentencepiece-based tokenizer.

  • The tokenizer has a vocabulary size of 32,000.
  • It uses sentencepiece's byte fallback feature to decompose unknown text pieces into UTF-8 byte pieces and to avoid producing <UNK> tokens.
  • sentencepiece's --add_dummy_prefix option was turned off so that a leading whitespace will not be prepended automatically.
      print(tokenizer.tokenize("ๅพ่ผฉใฏ็Œซใงใ‚ใ‚‹"))
      # ['ๅพ', '่ผฉ', 'ใฏ', '็Œซ', 'ใงใ‚ใ‚‹']
      # instead of ['โ–', 'ๅพ', '่ผฉ', 'ใฏ', '็Œซ', 'ใงใ‚ใ‚‹'] as in rinna/japanese-gpt-1b
    
  • sentencepiece's --remove_extra_whitespaces option was turned off so that leading, trailing, and duplicate whitespaces are reserved.
      print(tokenizer.tokenize("  ๅพ่ผฉใฏ  ็Œซใงใ‚ใ‚‹   "))
      # ['โ–', 'โ–', 'ๅพ', '่ผฉ', 'ใฏ', 'โ–', 'โ–', '็Œซ', 'ใงใ‚ใ‚‹', 'โ–', 'โ–', 'โ–']
      # instead of ['โ–', 'ๅพ', '่ผฉ', 'ใฏ', 'โ–็Œซ', 'ใงใ‚ใ‚‹'] as in rinna/japanese-gpt-1b
    
  • Don't forget to set use_fast=False to make the above features function correctly.
      good_tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b", use_fast=False)
      bad_tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt-neox-3.6b")
    
      print(good_tokenizer.decode(good_tokenizer.encode("แƒ’แƒแƒ›แƒแƒ แƒฏแƒแƒ‘แƒ  ๅพ่ผฉใฏ  ็Œซใงใ‚ใ‚‹   ")))
      # 'แƒ’แƒแƒ›แƒแƒ แƒฏแƒแƒ‘แƒ  ๅพ่ผฉใฏ  ็Œซใงใ‚ใ‚‹   </s>'
      print(bad_tokenizer.decode(bad_tokenizer.encode("แƒ’แƒแƒ›แƒแƒ แƒฏแƒแƒ‘แƒ  ๅพ่ผฉใฏ  ็Œซใงใ‚ใ‚‹   ")))
      # 'แƒ’แƒแƒ›แƒแƒ [UNK]แƒแƒ‘แƒ ๅพ่ผฉใฏ ็Œซใงใ‚ใ‚‹ </s>'
    

How to cite

@misc{rinna-japanese-gpt-neox-3.6b-instruction-ppo,
    title = {rinna/japanese-gpt-neox-3.6b-instruction-ppo},
    author = {Zhao, Tianyu and Sawada, Kei},
    url = {https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-ppo}
}

@inproceedings{sawada2024release,
    title = {Release of Pre-Trained Models for the {J}apanese Language},
    author = {Sawada, Kei and Zhao, Tianyu and Shing, Makoto and Mitsui, Kentaro and Kaga, Akio and Hono, Yukiya and Wakatsuki, Toshiaki and Mitsuda, Koh},
    booktitle = {Proceedings of the 2024 Joint International Conference on Computational Linguistics, Language Resources and Evaluation (LREC-COLING 2024)},
    month = {5},
    year = {2024},
    pages = {13898--13905},
    url = {https://aclanthology.org/2024.lrec-main.1213},
    note = {\url{https://arxiv.org/abs/2404.01657}}
}

Licenese

The MIT license

Downloads last month
4,256
Safetensors
Model size
3.76B params
Tensor type
BF16
ยท
BOOL
ยท
Inference Examples
Inference API (serverless) has been turned off for this model.

Model tree for rinna/japanese-gpt-neox-3.6b-instruction-ppo

Finetuned
(5)
this model
Quantizations
1 model

Dataset used to train rinna/japanese-gpt-neox-3.6b-instruction-ppo

Space using rinna/japanese-gpt-neox-3.6b-instruction-ppo 1

Collection including rinna/japanese-gpt-neox-3.6b-instruction-ppo