metadata
license: mit
tags:
  - generated_from_trainer
base_model: gpt2
model-index:
  - name: deep-haiku-gpt-2
    results: []

deep-haiku-gpt-2

This model is a fine-tuned version of gpt2 on the haiku dataset.

Model description

The model is a fine-tuned version of GPT-2 for generating haikus. The model, data, and training procedure are inspired by a blog post by Robert A. Gonsalves. Instead of an 8-bit version of GPT-J 6B, we used vanilla GPT-2. From what we saw, its performance is comparable, and it is much easier to fine-tune.

We used the same multitask training approach as in the post, but significantly extended the dataset (almost double the size of the original one). A prepared version of the dataset can be found here.

Intended uses & limitations

The model is intended to generate haikus. To do so, it was trained using a multitask learning approach (see Caruana 1997) with the following four tasks:

  • topic2graphemes (keywords = text)
  • topic2phonemes <keyword_phonemes = text_phonemes>
  • graphemes2phonemes [text = text_phonemes]
  • phonemes2graphemes {text_phonemes = text}

To use the model, provide an appropriate prompt such as "(dog rain =" and let the model generate a haiku for the given keywords.
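The prompt convention can be sketched in Python as follows. The bracket pairs come from the task list; the helper names (`build_prompt`, `extract_completion`, `generate_haiku`), the sampling settings, and the idea of cutting the output at the closing bracket are assumptions for illustration and are not taken from this card. The repository namespace is not stated here, so `model_id` has to be supplied by the caller (a Hub id or a local path).

```python
# Bracket pairs per task, as listed in the model card.
TASK_BRACKETS = {
    "topic2graphemes": ("(", ")"),
    "topic2phonemes": ("<", ">"),
    "graphemes2phonemes": ("[", "]"),
    "phonemes2graphemes": ("{", "}"),
}


def build_prompt(source: str, task: str = "topic2graphemes") -> str:
    """Return an open prompt such as '(dog rain =' for the given task."""
    opening, _ = TASK_BRACKETS[task]
    return f"{opening}{source.strip()} ="


def extract_completion(generated: str, task: str = "topic2graphemes") -> str:
    """Return the text between the first '=' and the task's closing bracket.

    Assumption: the model emits the closing bracket when it is done.
    If no closing bracket is present, everything after '=' is returned.
    """
    _, closing = TASK_BRACKETS[task]
    _, _, tail = generated.partition("=")
    return tail.split(closing, 1)[0].strip()


def generate_haiku(model_id: str, keywords: str, max_new_tokens: int = 40) -> str:
    """Sample one haiku for the given keywords."""
    # Imported lazily so the helpers above work without transformers installed.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    inputs = tokenizer(build_prompt(keywords), return_tensors="pt")
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # illustrative sampling settings, not from the card
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
    )
    text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return extract_completion(text)
```

Calling `generate_haiku(model_id, "dog rain")` builds the prompt `(dog rain =`, samples a continuation, and returns the part before the first `)`. Sampling is random, so repeated calls give different haikus.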

Training and evaluation data

We used a collection of existing haikus for training. All haikus were used in both a grapheme version and a phoneme version. In addition, we extracted keywords for all haikus using KeyBERT and filtered out haikus with low text quality according to the GRUEN score.
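To make the data layout concrete, the sketch below formats one haiku into the four task strings listed under intended uses. It assumes the phoneme transcriptions are produced elsewhere (this card does not name the phonemizer) and that each example ends with the task's closing bracket, which is inferred from the task list rather than documented. The function names are hypothetical, and the GRUEN filtering step is not shown.

```python
def make_training_examples(text, keywords, text_phonemes, keyword_phonemes):
    """Format one haiku as four multitask training strings."""
    return [
        f"({keywords} = {text})",                   # topic2graphemes
        f"<{keyword_phonemes} = {text_phonemes}>",  # topic2phonemes
        f"[{text} = {text_phonemes}]",              # graphemes2phonemes
        f"{{{text_phonemes} = {text}}}",            # phonemes2graphemes
    ]


def keywords_for(text, top_n=2):
    """Extract keywords with KeyBERT and join them with spaces.

    The import is lazy because KeyBERT downloads an embedding model on
    first use; top_n=2 is an illustrative choice, not from the card.
    """
    from keybert import KeyBERT

    kw_model = KeyBERT()
    # extract_keywords returns a list of (keyword, score) pairs.
    pairs = kw_model.extract_keywords(text, top_n=top_n)
    return " ".join(keyword for keyword, _ in pairs)
```

Each haiku therefore contributes four lines to the training file, one per task, so the model learns to map keywords to text as well as to convert between graphemes and phonemes.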

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 5e-05
  • train_batch_size: 2
  • eval_batch_size: 2
  • seed: 42
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 100
  • num_epochs: 10
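
As a rough guide for reproducing this setup, the values above map onto `transformers.TrainingArguments` as sketched below. Treating the batch sizes as per-device values assumes single-device training, which the card does not state. `output_dir` and `total_steps` are hypothetical, since the total step count depends on the dataset size. `linear_lr` is a plain-Python illustration of a linear schedule with warmup, not code from the original training run.

```python
# Hyperparameters from the list above, keyed by TrainingArguments names.
TRAINING_KWARGS = {
    "learning_rate": 5e-5,
    "per_device_train_batch_size": 2,  # assumes a single device
    "per_device_eval_batch_size": 2,
    "seed": 42,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_epsilon": 1e-8,
    "lr_scheduler_type": "linear",
    "warmup_steps": 100,
    "num_train_epochs": 10,
}


def make_training_args(output_dir):
    """Build TrainingArguments; output_dir is a hypothetical local path."""
    from transformers import TrainingArguments  # lazy import

    return TrainingArguments(output_dir=output_dir, **TRAINING_KWARGS)


def linear_lr(step, total_steps, base_lr=5e-5, warmup_steps=100):
    """Learning rate at `step`: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)
```

With 100 warmup steps, the learning rate rises from 0 to 5e-05 over the first 100 optimizer steps and then falls linearly, reaching 0 at the final step.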

Training results

Framework versions

  • Transformers 4.19.2
  • PyTorch 1.11.0+cu102
  • Datasets 2.2.1
  • Tokenizers 0.12.1