license: apache-2.0
tags:
  - distilgpt2
  - hearthstone
metrics:
  - bleu
  - dvitel/codebleu
  - exact_match
  - chrf
datasets:
  - dvitel/hearthstone
model-index:
  - name: h1
    results:
      - task:
          type: text-generation
          name: Python Code Synthesis
        dataset:
          type: dvitel/hearthstone
          name: HearthStone
          split: test
        metrics:
          - type: exact_match
            value: 0.21212121212121213
            name: Exact Match
          - type: bleu
            value: 0.9637468196180485
            name: BLEU
          - type: dvitel/codebleu
            value: 0.8884667222252154
            name: CodeBLEU
          - type: chrf
            value: 96.5942286007928
            name: chrF

h1

This model is a fine-tuned version of distilgpt2 on the hearthstone dataset (see the GitHub repo). It achieves the following results on the evaluation set:

  • Loss: 0.0890
  • Exact Match: 0.1970
  • BLEU: 0.9737
  • CodeBLEU: 0.9172
  • Ngram Match Score (CodeBLEU component): 0.8984
  • Weighted Ngram Match Score (CodeBLEU component): 0.8985
  • Syntax Match Score (CodeBLEU component): 0.9293
  • Dataflow Match Score (CodeBLEU component): 0.9429
  • chrF: 97.5313 (on a 0-100 scale)
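The headline numbers relate to each other in simple ways: CodeBLEU combines the four component scores, and Exact Match counts whole-string hits, so it moves in coarse steps. A minimal sketch of both, assuming CodeBLEU's default equal weights of 0.25 and an evaluation set of 66 examples (neither is stated in this card, though both agree with the reported numbers). The helpers codebleu and exact_match are illustrative and are not the metric implementations used during training:

```python
# Assumption: CodeBLEU uses its default equal weights over its four components.
WEIGHTS = (0.25, 0.25, 0.25, 0.25)


def codebleu(ngram: float, weighted_ngram: float, syntax: float, dataflow: float) -> float:
    """Weighted sum of the four CodeBLEU component scores."""
    components = (ngram, weighted_ngram, syntax, dataflow)
    return sum(w * c for w, c in zip(WEIGHTS, components))


def exact_match(predictions, references) -> float:
    """Share of predictions that are string-identical to their reference."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    return sum(p == r for p, r in zip(predictions, references)) / len(references)


# Component scores from the evaluation-set results above.
print(f"CodeBLEU ~ {codebleu(0.8984, 0.8985, 0.9293, 0.9429):.4f}")

# Assumption: 66 evaluation examples, so Exact Match moves in steps of 1/66.
NUM_EVAL_EXAMPLES = 66
hits = round(0.1970 * NUM_EVAL_EXAMPLES)
print(f"Exact Match 0.1970 ~ {hits}/{NUM_EVAL_EXAMPLES} = {hits / NUM_EVAL_EXAMPLES:.4f}")
```

With the rounded component scores, the weighted sum is about 0.9173, which matches the reported 0.9172 up to rounding of the components.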

Model description

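DistilGPT2 fine-tuned on the HearthStone dataset, with each target Python program preprocessed into a dumped AST (the positional format produced by Python's ast.dump). Example gold label and model prediction for the Innervate card: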

#gold labels
Module([ClassDef('Innervate', [Name('SpellCard', Load())], [], [FunctionDef('__init__', arguments([], [arg('self', None, None)], None, [], [], None, []), [Expr(Call(Attribute(Call(Name('super', Load()), [], []), '__init__', Load()), [Constant('Innervate', None), Constant(0, None), Attribute(Name('CHARACTER_CLASS', Load()), 'DRUID', Load()), Attribute(Name('CARD_RARITY', Load()), 'FREE', Load())], []))], [], None, None), FunctionDef('use', arguments([], [arg('self', None, None), arg('player', None, None), arg('game', None, None)], None, [], [], None, []), [Expr(Call(Attribute(Call(Name('super', Load()), [], []), 'use', Load()), [Name('player', Load()), Name('game', Load())], [])), If(Compare(Attribute(Name('player', Load()),'mana', Load()), [Lt()], [Constant(8, None)]), [AugAssign(Attribute(Name('player', Load()),'mana', Store()), Add(), Constant(2, None))], [Assign([Attribute(Name('player', Load()),'mana', Store())], Constant(10, None), None)])], [], None, None)], [])], [])
#wrong prediction (an error made by the fine-tuned model)
Module([ClassDef('Innervate', [Name('SpellCard', Load())], [], [FunctionDef('__init__', arguments([], [arg('self', None, None)], None, [], [], None, []), [Expr(Call(Attribute(Call(Name('super', Load()), [], []), '__init__', Load()), [Constant('Innervate', None), Constant(0, None), Attribute(Name('CHARACTER_CLASS', Load()), 'DRUID', Load()), Attribute(Name('CARD_RARITY', Load()), 'FREE', Load())], []))], [], None, None), FunctionDef('use', arguments([], [arg('self', None, None), arg('player', None, None), arg('game', None, None)], None, [], [], None, []), [Expr(Call(Attribute(Call(Name('super', Load()), [], []), 'use', Load()), [Name('player', Load()), Name('game', Load())], [])), For(Compare(Attribute(Name('player', Load()),'maxa', Load()), [Lt()], [Constant(10, None)]), [AugAssign(Attribute(Name('player', Load()),'mana', Store()), Add(), Constant(2, None))], Exign([Name(Name('player', Load()),'mana', Store())], Constant(None, None), None)],], [], None, None)], [])], [])
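The preprocessing can be sketched with Python's ast module. This assumes the targets were produced with ast.dump(..., annotate_fields=False), which matches the positional style above. Forms such as Constant(0, None) and arg('self', None, None) point to Python 3.8, while Python 3.9+ omits those None fields, so the exact strings differ by version. encode and decode are hypothetical helper names. Decoding evaluates the string, which also serves as a quick well-formedness check: the wrong prediction above does not decode, since it contains, for instance, the non-existent node type Exign.

```python
import ast

# Hypothetical helpers illustrating the "dumped AST" target format.
# Only AST node classes are exposed to eval(); this narrows what the string can
# reach but is NOT a security boundary, so only decode text you are willing to run.
AST_NAMES = {
    name: obj
    for name, obj in vars(ast).items()
    if isinstance(obj, type) and issubclass(obj, ast.AST)
}


def encode(source: str) -> str:
    """Parse Python source and dump its AST positionally (no field names)."""
    return ast.dump(ast.parse(source), annotate_fields=False)


def decode(dumped: str):
    """Rebuild source code from a dumped AST, or return None if it is malformed.

    Needs Python 3.9+ for ast.unparse.
    """
    try:
        tree = eval(dumped, {"__builtins__": {}, **AST_NAMES})
        ast.fix_missing_locations(tree)  # add line numbers that unparse may read
        return ast.unparse(tree)
    except Exception:  # e.g. NameError for an unknown node, SyntaxError, TypeError
        return None


# The gold label above corresponds to this ordinary Python source.
CARD_SOURCE = """
class Innervate(SpellCard):
    def __init__(self):
        super().__init__('Innervate', 0, CHARACTER_CLASS.DRUID, CARD_RARITY.FREE)

    def use(self, player, game):
        super().use(player, game)
        if player.mana < 8:
            player.mana += 2
        else:
            player.mana = 10
"""

dumped = encode(CARD_SOURCE)
print(dumped[:70])
print(decode(dumped))
```

Returning None instead of raising makes it easy to count how many generated strings are syntactically valid ASTs before computing any code metric on them.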

Intended uses & limitations

Python code synthesis for HearthStone cards: the model generates the implementation of a card class in dumped-AST form.

Training and evaluation data

See the splits of the dvitel/hearthstone dataset.

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2e-05
  • train_batch_size: 4
  • eval_batch_size: 4
  • seed: 17
  • optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: cosine
  • num_epochs: 200
  • mixed_precision_training: Native AMP
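The learning-rate schedule implied by these settings can be sketched in a few lines. The sketch assumes 533 training examples and no warmup or gradient accumulation (none of which is listed above); with a batch size of 4 that gives 134 optimizer steps per epoch, which agrees with the Epoch and Step columns of the results table (for example, step 1600 at epoch 11.94). The decay formula has the same shape as the Transformers cosine schedule without warmup, but this is an illustration, not the training code:

```python
import math

# Values from the hyperparameter list above.
LEARNING_RATE = 2e-5
TRAIN_BATCH_SIZE = 4
NUM_EPOCHS = 200

# Assumptions (not stated in the card): 533 training examples,
# no warmup, and no gradient accumulation.
NUM_TRAIN_EXAMPLES = 533

steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * NUM_EPOCHS


def cosine_lr(step: int) -> float:
    """Learning rate after `step` optimizer steps: half-cosine decay from LEARNING_RATE to 0."""
    progress = min(step, total_steps) / total_steps
    return LEARNING_RATE * 0.5 * (1.0 + math.cos(math.pi * progress))


# Logged steps from the results table, plus the midpoint and the end of training.
for step in (1600, 12800, total_steps // 2, 25600, total_steps):
    print(f"step {step:>5}: epoch {step / steps_per_epoch:6.2f}, lr {cosine_lr(step):.2e}")
```

Under these assumptions the run lasts 26,800 steps, so the last logged evaluation (step 25600, epoch 191.04) already happens with the learning rate close to zero.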

Training results

| Training Loss | Epoch | Step | Validation Loss | Exact Match | BLEU | CodeBLEU | Ngram Match Score | Weighted Ngram Match Score | Syntax Match Score | Dataflow Match Score | chrF |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0.3871 | 11.94 | 1600 | 0.1043 | 0.0152 | 0.9499 | 0.8549 | 0.8089 | 0.8089 | 0.8653 | 0.9366 | 95.4674 |
| 0.0752 | 23.88 | 3200 | 0.0784 | 0.1212 | 0.9640 | 0.8874 | 0.8525 | 0.8526 | 0.8929 | 0.9516 | 96.7978 |
| 0.0448 | 35.82 | 4800 | 0.0717 | 0.1364 | 0.9693 | 0.9077 | 0.8782 | 0.8782 | 0.9069 | 0.9674 | 97.2100 |
| 0.0308 | 47.76 | 6400 | 0.0752 | 0.1364 | 0.9702 | 0.9061 | 0.8808 | 0.8810 | 0.9070 | 0.9554 | 97.1896 |
| 0.0223 | 59.7 | 8000 | 0.0762 | 0.1364 | 0.9724 | 0.9050 | 0.8877 | 0.8881 | 0.9093 | 0.9348 | 97.4616 |
| 0.0166 | 71.64 | 9600 | 0.0762 | 0.1667 | 0.9733 | 0.9140 | 0.8948 | 0.8951 | 0.9197 | 0.9461 | 97.4945 |
| 0.0128 | 83.58 | 11200 | 0.0793 | 0.1515 | 0.9728 | 0.9085 | 0.8911 | 0.8918 | 0.9189 | 0.9321 | 97.4152 |
| 0.0104 | 95.52 | 12800 | 0.0822 | 0.1667 | 0.9732 | 0.9165 | 0.8946 | 0.8950 | 0.9222 | 0.9541 | 97.4887 |
| 0.0084 | 107.46 | 14400 | 0.0832 | 0.1667 | 0.9737 | 0.9167 | 0.8970 | 0.8972 | 0.9254 | 0.9471 | 97.5326 |
| 0.007 | 119.4 | 16000 | 0.0837 | 0.1818 | 0.9743 | 0.9160 | 0.8983 | 0.8986 | 0.9238 | 0.9434 | 97.6638 |
| 0.0058 | 131.34 | 17600 | 0.0858 | 0.1818 | 0.9739 | 0.9200 | 0.8977 | 0.8977 | 0.9267 | 0.9579 | 97.5583 |
| 0.005 | 143.28 | 19200 | 0.0878 | 0.1818 | 0.9743 | 0.9180 | 0.8993 | 0.9001 | 0.9301 | 0.9426 | 97.5819 |
| 0.0044 | 155.22 | 20800 | 0.0877 | 0.1667 | 0.9736 | 0.9156 | 0.8957 | 0.8960 | 0.9278 | 0.9429 | 97.5109 |
| 0.0042 | 167.16 | 22400 | 0.0890 | 0.1970 | 0.9736 | 0.9171 | 0.8984 | 0.8984 | 0.9293 | 0.9424 | 97.5617 |
| 0.0038 | 179.1 | 24000 | 0.0891 | 0.2121 | 0.9738 | 0.9174 | 0.8991 | 0.8991 | 0.9285 | 0.9429 | 97.5452 |
| 0.0037 | 191.04 | 25600 | 0.0890 | 0.1970 | 0.9737 | 0.9172 | 0.8984 | 0.8985 | 0.9293 | 0.9429 | 97.5313 |

Framework versions

  • Transformers 4.24.0
  • Pytorch 1.13.0
  • Datasets 2.6.1
  • Tokenizers 0.13.1
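To reproduce these results it may help to compare the local environment with the versions above. A minimal sketch using the standard library's importlib.metadata; version_mismatches is a hypothetical helper, and torch is the pip distribution name for PyTorch:

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed above; "torch" is the pip distribution name of PyTorch.
EXPECTED = {
    "transformers": "4.24.0",
    "torch": "1.13.0",
    "datasets": "2.6.1",
    "tokenizers": "0.13.1",
}


def version_mismatches(expected=EXPECTED):
    """Return {package: (expected, installed_or_None)} for packages that differ."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        # Compare the base version only, so local builds such as "1.13.0+cu117" still match.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, installed) in version_mismatches().items():
        print(f"{package}: card uses {wanted}, found {installed or 'not installed'}")
```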