{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "\n",
    "from self_rewarding_lm_pytorch import (\n",
    "    SelfRewardingTrainer,\n",
    "    create_mock_dataset\n",
    ")\n",
    "\n"
   ]
  },
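  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build a tiny decoder-only transformer over 256 byte-level tokens, two mock datasets (SFT pairs and self-reward prompts), and a character-level codec for `tokenizer_encode` / `tokenizer_decode`. `SelfRewardingTrainer` then runs the stages in order, as the progress bars below show: supervised fine-tuning, generating a DPO preference dataset by having the model reward its own responses, then DPO fine-tuning."
   ]
  },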
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "sft fine-tuning: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 9/9 [00:03<00:00,  2.91it/s]\n",
      "generating dpo dataset with self-rewarding: 0it [00:00, ?it/s]"
     ]
    }
   ],
   "source": [
    "from x_transformers import TransformerWrapper, Decoder\n",
    "transformer = TransformerWrapper(\n",
    "    num_tokens = 256,\n",
    "    max_seq_len = 1024,\n",
    "    attn_layers = Decoder(\n",
    "        dim = 512,\n",
    "        depth = 1,\n",
    "        heads = 8\n",
    "    )\n",
    ")\n",
    "\n",
    "sft_dataset = create_mock_dataset(100, lambda: (torch.randint(0, 256, (256,)), torch.tensor(1))) # length, output(callable function) -> return class instance\n",
    "prompt_dataset = create_mock_dataset(100, lambda: 'mock prompt')\n",
    "\n",
    "def decode_tokens(tokens: Tensor) -> str:\n",
    "    decode_token = lambda token: str(chr(max(32, token))) # chr(i) : return ASCII code correspoding to i\n",
    "    return ''.join(list(map(decode_token, tokens)))\n",
    "\n",
    "def encode_str(seq_str: str) -> Tensor:\n",
    "    return Tensor(list(map(ord, seq_str))) # ord('c') : return the ASCII code of 'c'\n",
    "\n",
    "trainer = SelfRewardingTrainer(\n",
    "    transformer,\n",
    "    finetune_configs = dict(\n",
    "        train_sft_dataset = sft_dataset,\n",
    "        self_reward_prompt_dataset = prompt_dataset,\n",
    "        dpo_num_train_steps = 1000\n",
    "    ),\n",
    "    tokenizer_decode = decode_tokens,\n",
    "    tokenizer_encode = encode_str,\n",
    "    accelerate_kwargs = dict(\n",
    "        cpu = True\n",
    "    )\n",
    ")\n",
    "trainer(overwrite_checkpoints = True)\n",
    "\n",
    "\n",
    "# checkpoints after each finetuning stage will be saved to ./checkpoints"
   ]
  },
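  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the mock tokenizer, using only the helpers defined above (no extra library API is assumed). `Tensor(...)` defaults to float32, so the tokens are cast to integers before decoding; code points below 32 would be clamped to spaces, so printable input survives the round trip unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# round-trip through the mock char-level codec defined above\n",
    "tokens = encode_str('mock prompt')   # float Tensor of code points\n",
    "text = decode_tokens(tokens.long())  # cast to integer tokens before chr()\n",
    "assert text == 'mock prompt'\n",
    "print(text)"
   ]
  }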
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}