{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "2c3bb18a", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:05.484451Z", "iopub.status.busy": "2022-04-18T01:49:05.482966Z", "iopub.status.idle": "2022-04-18T01:49:22.249321Z", "shell.execute_reply": "2022-04-18T01:49:22.248692Z", "shell.execute_reply.started": "2022-04-16T12:16:29.630467Z" }, "papermill": { "duration": 16.788107, "end_time": "2022-04-18T01:49:22.249468", "exception": false, "start_time": "2022-04-18T01:49:05.461361", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'Poetry'...\n", "remote: Enumerating objects: 135, done.\u001b[K\n", "remote: Total 135 (delta 0), reused 0 (delta 0), pack-reused 135\u001b[K\n", "Receiving objects: 100% (135/135), 123.55 MiB | 12.33 MiB/s, done.\n", "Resolving deltas: 100% (77/77), done.\n", "Updating files: 100% (39/39), done.\n" ] } ], "source": [
 "# Fetch the Werneror/Poetry corpus. The directory test skips the clone on re-runs, where\n",
 "# `git clone` would otherwise fail because Poetry/ already exists (Restart & Run All safe).\n",
 "![ -d Poetry ] || git clone https://github.com/Werneror/Poetry.git"
 ] }, { "cell_type": "code", "execution_count": 2, "id": "d76b15a8", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:22.322907Z", "iopub.status.busy": "2022-04-18T01:49:22.322113Z", "iopub.status.idle": "2022-04-18T01:49:28.965795Z", "shell.execute_reply": "2022-04-18T01:49:28.965246Z", "shell.execute_reply.started": "2022-04-16T12:16:41.322744Z" }, "papermill": { "duration": 6.678735, "end_time": "2022-04-18T01:49:28.965944", "exception": false, "start_time": "2022-04-18T01:49:22.287209", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
 "# Every dependency of the notebook, imported once up front (stdlib, then third-party).\n",
 "import json\n",
 "import os\n",
 "import re\n",
 "\n",
 "import pandas as pd\n",
 "import torch\n",
 "from sklearn.model_selection import train_test_split\n",
 "from transformers import (\n",
 "    EarlyStoppingCallback,\n",
 "    GPT2Config,\n",
 "    GPT2LMHeadModel,\n",
 "    Trainer,\n",
 "    TrainingArguments,\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 3, "id": "de8c9caa", "metadata": { 
"execution": { "iopub.execute_input": "2022-04-18T01:49:29.041708Z", "iopub.status.busy": "2022-04-18T01:49:29.040006Z", "iopub.status.idle": "2022-04-18T01:49:34.168753Z", "shell.execute_reply": "2022-04-18T01:49:34.169341Z", "shell.execute_reply.started": "2022-04-16T12:16:48.16115Z" }, "papermill": { "duration": 5.16885, "end_time": "2022-04-18T01:49:34.169515", "exception": false, "start_time": "2022-04-18T01:49:29.000665", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
题目朝代作者内容
0彭生行何景明岷峨山根江水坼,万里波涛混吴越。倾湖倒海不可量,仰看一线青天上。郁蓝秀色盘三巴,间产锦石兼丹...
1黄河篇何景明黄河昆崙源,九曲与天通。银汉贯箕尾,左盘日月宫。奔流下龙门,喷薄沙海风。三山万里倚穷发,鳖极...
2三清山人歌何景明山人佩剑冠远游,腰间鞶囊垂虎头,七星照耀金银钩。东行策杖指卢霍,逝将沧海寻丹丘。三清西南龙虎...
3昔游篇何景明三星烂夜河汉流,觞行瑟作中堂幽。李君勿叹息,薛君且停讴。英英孟夫子,听我当筵歌昔游。昔游少年...
4赠商三何景明去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。
\n", "
" ], "text/plain": [ " 题目 朝代 作者 内容\n", "0 彭生行 明 何景明 岷峨山根江水坼,万里波涛混吴越。倾湖倒海不可量,仰看一线青天上。郁蓝秀色盘三巴,间产锦石兼丹...\n", "1 黄河篇 明 何景明 黄河昆崙源,九曲与天通。银汉贯箕尾,左盘日月宫。奔流下龙门,喷薄沙海风。三山万里倚穷发,鳖极...\n", "2 三清山人歌 明 何景明 山人佩剑冠远游,腰间鞶囊垂虎头,七星照耀金银钩。东行策杖指卢霍,逝将沧海寻丹丘。三清西南龙虎...\n", "3 昔游篇 明 何景明 三星烂夜河汉流,觞行瑟作中堂幽。李君勿叹息,薛君且停讴。英英孟夫子,听我当筵歌昔游。昔游少年...\n", "4 赠商三 明 何景明 去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [
 "# Load every CSV in the cloned corpus. Paths are built from os.walk's dirpath (not a\n",
 "# hard-coded \"Poetry/\" prefix) so CSVs in sub-directories load too; sorting fixes the\n",
 "# row order across filesystems, which keeps the seeded train/val split reproducible;\n",
 "# and the frames are concatenated once instead of being re-copied inside the loop.\n",
 "csv_paths = sorted(\n",
 "    os.path.join(dirpath, filename)\n",
 "    for dirpath, _dirnames, filenames in os.walk(\"Poetry\")\n",
 "    for filename in filenames\n",
 "    if filename.endswith(\".csv\")\n",
 ")\n",
 "assert csv_paths, \"no CSV files found under Poetry/ - did the clone succeed?\"\n",
 "data = pd.concat([pd.read_csv(path) for path in csv_paths], ignore_index=True)\n",
 "data.head()"
 ] }, { "cell_type": "code", "execution_count": 4, "id": "40c84fbf", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:34.242596Z", "iopub.status.busy": "2022-04-18T01:49:34.241754Z", "iopub.status.idle": "2022-04-18T01:49:34.244196Z", "shell.execute_reply": "2022-04-18T01:49:34.243782Z", "shell.execute_reply.started": "2022-04-16T12:16:53.639047Z" }, "papermill": { "duration": 0.041531, "end_time": "2022-04-18T01:49:34.244315", "exception": false, "start_time": "2022-04-18T01:49:34.202784", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
 "import re\n",
 "\n",
 "\n",
 "def verse_length(verses):\n",
 "    \"\"\"Return the number of characters in the first verse of `verses`.\n",
 "\n",
 "    Applied only to poems already filtered by `is_equal_length`, so this is the length\n",
 "    every verse shares. Raises IndexError for an empty list.\n",
 "    \"\"\"\n",
 "    return len(verses[0])\n",
 "\n",
 "\n",
 "def verse_heads(verses):\n",
 "    \"\"\"Concatenate the first character of every verse (the acrostic part of the prompt).\"\"\"\n",
 "    return \"\".join(verse[0] for verse in verses)\n",
 "\n",
 "\n",
 "def split_poem(poem):\n",
 "    \"\"\"Split `poem` into verses on `,` and `。`, dropping empty pieces.\"\"\"\n",
 "    return [verse for verse in re.split(\",|。\", poem) if verse]\n",
 "\n",
 "\n",
 "def is_correct_length(poem, max_length, min_length):\n",
 "    \"\"\"True when min_length < len(poem) < max_length (both bounds exclusive).\"\"\"\n",
 "    return min_length < len(poem) < max_length\n",
 "\n",
 "\n",
 "def is_equal_length(verses):\n",
 "    \"\"\"True when all verses share one length; an empty list counts as equal.\"\"\"\n",
 "    return len({len(verse) for verse in verses}) <= 1"
 ] }, { "cell_type": "code", "execution_count": 5, "id": "4fd4df65", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:34.407430Z", "iopub.status.busy": "2022-04-18T01:49:34.406391Z", "iopub.status.idle": "2022-04-18T01:49:47.517219Z", "shell.execute_reply": "2022-04-18T01:49:47.516725Z", "shell.execute_reply.started": "2022-04-16T12:16:53.648455Z" }, "papermill": { "duration": 13.240486, "end_time": "2022-04-18T01:49:47.517350", "exception": false, "start_time": "2022-04-18T01:49:34.276864", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " \n", "/opt/conda/lib/python3.7/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", " import sys\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Number of valid poems: 617674\n" ] } ], "source": [
 "# Filter bounds. Poem bounds are exclusive; verse bounds are inclusive and match the\n",
 "# \"3\"..\"9\" length tokens CharTokenizer reserves, so every prompt's length digit is known.\n",
 "MIN_POEM_CHARS, MAX_POEM_CHARS = 20, 100\n",
 "MIN_VERSE_CHARS, MAX_VERSE_CHARS = 3, 9\n",
 "\n",
 "# `.copy()` makes each filtered frame independent, so the column assignments below no\n",
 "# longer raise SettingWithCopyWarning (or risk writing into a temporary view).\n",
 "data = data[data[\"内容\"].notna()].copy()\n",
 "data['verses'] = [split_poem(poem) for poem in data['内容']]\n",
 "data['equal_verse_lengths'] = [is_equal_length(verses) for verses in data['verses']]\n",
 "data['meet_length_requirements'] = [is_correct_length(poem, MAX_POEM_CHARS, MIN_POEM_CHARS) for poem in data['内容']]\n",
 "valid_poems = data[data['equal_verse_lengths'] & data['meet_length_requirements']].copy()\n",
 "valid_poems['verse_lengths'] = [verse_length(verses) for verses in valid_poems['verses']]\n",
 "valid_poems['verse_heads'] = [verse_heads(verses) for verses in valid_poems['verses']]\n",
 "valid_poems = valid_poems[valid_poems['verse_lengths'].between(MIN_VERSE_CHARS, MAX_VERSE_CHARS)]\n",
 "print(f\"Number of valid poems: {len(valid_poems)}\")"
 ] }, { "cell_type": "code", "execution_count": 6, "id": "f86c5f9c", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:47.589515Z", "iopub.status.busy": "2022-04-18T01:49:47.588746Z", "iopub.status.idle": "2022-04-18T01:49:47.601086Z", "shell.execute_reply": "2022-04-18T01:49:47.600657Z", "shell.execute_reply.started": "2022-04-16T12:17:06.029609Z" }, "papermill": { "duration": 0.049888, "end_time": "2022-04-18T01:49:47.601199", "exception": false, "start_time": "2022-04-18T01:49:47.551311", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
题目朝代作者内容versesequal_verse_lengthsmeet_length_requirementsverse_lengthsverse_heads
4赠商三何景明去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。[去冬雪雨留蓟门, 开筵谑浪倒金樽, 今春灯火到长安, 过门不肯回银鞍, 燕山花隔平山柳, ...TrueTrue7去开今过燕马
14送叶生还闽中兼怀郑继之何景明叶生行吟燕中市,葛巾麻鞋岁将晚。两都为客今始归,五岳寻仙不辞远。江南画舸春柳低,海上茅堂白云...[叶生行吟燕中市, 葛巾麻鞋岁将晚, 两都为客今始归, 五岳寻仙不辞远, 江南画舸春柳低, ...TrueTrue7叶葛两五江海谷为
15送林利正同知之潮阳何景明忆在成均共携手,泉山门下相知久。万里恩情若父兄,十年道义惭师友。君才岂孤一第名,佩刀今作岭南...[忆在成均共携手, 泉山门下相知久, 万里恩情若父兄, 十年道义惭师友, 君才岂孤一第名, ...TrueTrue7忆泉万十君佩挂伐燕相过道
16金陵歌送李先生何景明李公为舅有吕甥,甥舅四海皆知名。吕君关西昨日去,公自金陵来复行。金陵江水无断绝,金陵之山高巀...[李公为舅有吕甥, 甥舅四海皆知名, 吕君关西昨日去, 公自金陵来复行, 金陵江水无断绝, ...TrueTrue7李甥吕公金金龙星白清燕
21延津歌送韩令何景明延津寇过馀少男,延津县令莫停骖。双凫直向黄河北,一雁先飞清卫南。黄河岸边不种麦,浊浪滔天多贾...[延津寇过馀少男, 延津县令莫停骖, 双凫直向黄河北, 一雁先飞清卫南, 黄河岸边不种麦, ...TrueTrue7延延双一黄浊城县
\n", "
" ], "text/plain": [ " 题目 朝代 作者 内容 \\\n", "4 赠商三 明 何景明 去冬雪雨留蓟门,开筵谑浪倒金樽。今春灯火到长安,过门不肯回银鞍。燕山花隔平山柳,马上东风几回首。 \n", "14 送叶生还闽中兼怀郑继之 明 何景明 叶生行吟燕中市,葛巾麻鞋岁将晚。两都为客今始归,五岳寻仙不辞远。江南画舸春柳低,海上茅堂白云... \n", "15 送林利正同知之潮阳 明 何景明 忆在成均共携手,泉山门下相知久。万里恩情若父兄,十年道义惭师友。君才岂孤一第名,佩刀今作岭南... \n", "16 金陵歌送李先生 明 何景明 李公为舅有吕甥,甥舅四海皆知名。吕君关西昨日去,公自金陵来复行。金陵江水无断绝,金陵之山高巀... \n", "21 延津歌送韩令 明 何景明 延津寇过馀少男,延津县令莫停骖。双凫直向黄河北,一雁先飞清卫南。黄河岸边不种麦,浊浪滔天多贾... \n", "\n", " verses equal_verse_lengths \\\n", "4 [去冬雪雨留蓟门, 开筵谑浪倒金樽, 今春灯火到长安, 过门不肯回银鞍, 燕山花隔平山柳, ... True \n", "14 [叶生行吟燕中市, 葛巾麻鞋岁将晚, 两都为客今始归, 五岳寻仙不辞远, 江南画舸春柳低, ... True \n", "15 [忆在成均共携手, 泉山门下相知久, 万里恩情若父兄, 十年道义惭师友, 君才岂孤一第名, ... True \n", "16 [李公为舅有吕甥, 甥舅四海皆知名, 吕君关西昨日去, 公自金陵来复行, 金陵江水无断绝, ... True \n", "21 [延津寇过馀少男, 延津县令莫停骖, 双凫直向黄河北, 一雁先飞清卫南, 黄河岸边不种麦, ... True \n", "\n", " meet_length_requirements verse_lengths verse_heads \n", "4 True 7 去开今过燕马 \n", "14 True 7 叶葛两五江海谷为 \n", "15 True 7 忆泉万十君佩挂伐燕相过道 \n", "16 True 7 李甥吕公金金龙星白清燕 \n", "21 True 7 延延双一黄浊城县 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "valid_poems.head()" ] }, { "cell_type": "code", "execution_count": 7, "id": "33140481", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:47.695680Z", "iopub.status.busy": "2022-04-18T01:49:47.694684Z", "iopub.status.idle": "2022-04-18T01:49:47.696696Z", "shell.execute_reply": "2022-04-18T01:49:47.697169Z", "shell.execute_reply.started": "2022-04-16T12:23:28.401922Z" }, "papermill": { "duration": 0.06126, "end_time": "2022-04-18T01:49:47.697307", "exception": false, "start_time": "2022-04-18T01:49:47.636047", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
 "import json\n",
 "\n",
 "import torch\n",
 "\n",
 "\n",
 "class CharTokenizer:\n",
 "    \"\"\"Character-level tokenizer: each character of the corpus is one token.\n",
 "\n",
 "    The vocabulary (token -> id) starts with a special token at id 0, which also serves\n",
 "    as the id for unknown characters, then the verse-length digits \"3\"..\"9\" at ids 1-7,\n",
 "    then every corpus character in first-seen order, and finally two more special\n",
 "    tokens: the prompt/text separator and the end-of-sequence marker.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, corpus=None, vocab=None):\n",
 "        \"\"\"Use `vocab` if given, otherwise build one from `corpus` (an iterable of str).\n",
 "\n",
 "        Raises:\n",
 "            ValueError: if neither `corpus` nor `vocab` is supplied.\n",
 "        \"\"\"\n",
 "        if vocab is not None:\n",
 "            self.vocab = vocab\n",
 "        elif corpus is not None:\n",
 "            self.vocab = self._build_vocab(corpus)\n",
 "        else:\n",
 "            # ValueError subclasses Exception, so existing `except Exception` callers still work.\n",
 "            raise ValueError(\"Either corpus or vocab has to be supplied\")\n",
 "        # id -> token lookup; assumes the ids are exactly 0..len(vocab)-1.\n",
 "        self.id2vocab = sorted(self.vocab, key=self.vocab.get)\n",
 "\n",
 "    def _tokenize(self, text):\n",
 "        \"\"\"Split `text` into single characters.\"\"\"\n",
 "        return list(text)\n",
 "\n",
 "    def _encode(self, text):\n",
 "        \"\"\"Map each character of `text` to its id; unknown characters become id 0.\"\"\"\n",
 "        return [self.vocab.get(char, 0) for char in self._tokenize(text)]\n",
 "\n",
 "    def __call__(self, prompt, text=None, add_eos_token=False):\n",
 "        \"\"\"Encode one example as a batch of size 1.\n",
 "\n",
 "        Layout: prompt ids, then (if `text` is given) the separator id followed by the\n",
 "        text ids, then the end-of-sequence id if `add_eos_token` is True.\n",
 "\n",
 "        Returns:\n",
 "            dict with \"input_ids\", a LongTensor of shape (1, seq_len), and an all-ones\n",
 "            \"attention_mask\" of the same shape.\n",
 "        \"\"\"\n",
 "        token_ids = self._encode(prompt)\n",
 "        if text is not None:\n",
 "            token_ids = token_ids + [self.vocab[\"\"]] + self._encode(text)\n",
 "        if add_eos_token:\n",
 "            token_ids = token_ids + [self.vocab[\"\"]]\n",
 "        input_ids_tensor = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)\n",
 "        return {\"input_ids\": input_ids_tensor, \"attention_mask\": torch.ones_like(input_ids_tensor)}\n",
 "\n",
 "    def _build_vocab(self, corpus):\n",
 "        \"\"\"Build the token -> id mapping laid out in the class docstring.\"\"\"\n",
 "        # NOTE(review): the three special-token keys below must be distinct strings that\n",
 "        # never occur as corpus characters. If they are equal they collapse into a single\n",
 "        # entry, id 0 loses its token and `id2vocab` shifts by one - confirm the literals.\n",
 "        vocab = {\"\": 0}\n",
 "        for length in range(3, 10):\n",
 "            vocab[str(length)] = len(vocab)\n",
 "        for doc in corpus:\n",
 "            for char in self._tokenize(doc):\n",
 "                if char not in vocab:\n",
 "                    vocab[char] = len(vocab)\n",
 "        vocab[\"\"] = len(vocab)\n",
 "        vocab[\"\"] = len(vocab)\n",
 "        return vocab\n",
 "\n",
 "    def decode(self, token_ids):\n",
 "        \"\"\"Convert ids back to text, dropping the special tokens.\n",
 "\n",
 "        `token_ids` may be a tensor of any shape or a (nested, rectangular) list of ints.\n",
 "        \"\"\"\n",
 "        flat_ids = torch.as_tensor(token_ids).flatten().tolist()\n",
 "        chars = [self.id2vocab[token_id] for token_id in flat_ids]\n",
 "        return \"\".join(char for char in chars if char not in [\"\", \"\", \"\"])\n",
 "\n",
 "    def save(self, filepath):\n",
 "        \"\"\"Write the vocabulary to `filepath` as JSON.\"\"\"\n",
 "        with open(filepath, \"w\", encoding=\"utf-8\") as f:\n",
 "            json.dump(self.vocab, f)\n",
 "\n",
 "    @classmethod\n",
 "    def load(cls, filepath):\n",
 "        \"\"\"Create a tokenizer from a vocabulary file written by `save`.\"\"\"\n",
 "        with open(filepath, encoding=\"utf-8\") as f:\n",
 "            vocab = json.load(f)\n",
 "        return cls(vocab=vocab)"
 ] }, { "cell_type": "code", "execution_count": 8, "id": "73f55174", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:47.806040Z", "iopub.status.busy": "2022-04-18T01:49:47.795805Z", "iopub.status.idle": "2022-04-18T01:49:51.506784Z", "shell.execute_reply": "2022-04-18T01:49:51.506307Z", "shell.execute_reply.started": "2022-04-16T12:23:28.57419Z" }, "papermill": { "duration": 3.770368, "end_time": "2022-04-18T01:49:51.506972", "exception": false, "start_time": "2022-04-18T01:49:47.736604", "status": "completed" }, "tags": [] }, "outputs": [], "source": [
 "# Build the vocabulary from the filtered poems and save it next to the other run\n",
 "# artifacts. The path is relative (like the Trainer's \"results\" output_dir) instead of a\n",
 "# hard-coded /kaggle/working path; on Kaggle the working directory is presumably\n",
 "# /kaggle/working, so the file still lands in the same place there.\n",
 "TOKENIZER_PATH = \"tokenizer.json\"\n",
 "tokenizer = CharTokenizer(valid_poems['内容'])\n",
 "tokenizer.save(TOKENIZER_PATH)"
 ] }, { "cell_type": "code", "execution_count": 9, "id": "2d0c4b52", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:49:51.587046Z", "iopub.status.busy": "2022-04-18T01:49:51.578743Z", "iopub.status.idle": "2022-04-18T01:50:13.120701Z", "shell.execute_reply": "2022-04-18T01:50:13.121126Z", "shell.execute_reply.started": "2022-04-16T12:35:45.273336Z" }, "papermill": { "duration": 21.579069, "end_time": "2022-04-18T01:50:13.121274", "exception": false, "start_time": "2022-04-18T01:49:51.542205", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "123\n" ] } ], "source": [
 "# Prompt = verse-length digit + first character of every verse; the tokenizer appends\n",
 "# the separator, the poem itself, and the end-of-sequence token.\n",
 "tokenized_dataset = [\n",
 "    tokenizer(prompt=str(length) + heads, text=poem, add_eos_token=True)\n",
 "    for poem, length, heads in zip(\n",
 "        valid_poems['内容'], valid_poems['verse_lengths'], valid_poems['verse_heads']\n",
 "    )\n",
 "]\n",
 "train_dataset, val_dataset = train_test_split(tokenized_dataset, test_size=0.02, random_state=1234)\n",
 "# Longest encoded example; the model config uses it as the context size (n_positions).\n",
 "max_lengths = max(tokenized[\"input_ids\"].size(1) for tokenized in tokenized_dataset)\n",
 "print(max_lengths)"
 ] }, { "cell_type": "code", "execution_count": 10, "id": "a4e831ab", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:50:13.232157Z", "iopub.status.busy": "2022-04-18T01:50:13.231258Z", "iopub.status.idle": "2022-04-18T01:50:13.233058Z", "shell.execute_reply": "2022-04-18T01:50:13.233434Z", "shell.execute_reply.started": "2022-04-16T12:24:19.850932Z" }, "papermill": { "duration": 0.075455, "end_time": "2022-04-18T01:50:13.233582", "exception": false, "start_time": "2022-04-18T01:50:13.158127", "status": "completed" }, "tags": [] }, "outputs": 
[], "source": [
 "PAD_TOKEN_ID = 0\n",
 "\n",
 "\n",
 "def collate_fn(batch_inputs, pad_token_id=PAD_TOKEN_ID):\n",
 "    \"\"\"Right-pad a list of tokenizer outputs into one causal-LM training batch.\n",
 "\n",
 "    Args:\n",
 "        batch_inputs: non-empty list of dicts whose \"input_ids\" is a LongTensor of\n",
 "            shape (1, seq_len) (as returned by CharTokenizer) or (seq_len,).\n",
 "        pad_token_id: id written into the padded positions.\n",
 "\n",
 "    Returns:\n",
 "        dict of (batch, max_seq_len) LongTensors: \"input_ids\", \"attention_mask\"\n",
 "        (1 for real tokens, 0 for padding) and \"labels\" (input_ids with padded\n",
 "        positions set to -100 so the loss ignores them).\n",
 "    \"\"\"\n",
 "    sequences = [inputs[\"input_ids\"].reshape(-1) for inputs in batch_inputs]\n",
 "    max_length = max(len(seq) for seq in sequences)\n",
 "    input_ids = torch.full((len(sequences), max_length), pad_token_id, dtype=torch.long)\n",
 "    attention_mask = torch.zeros((len(sequences), max_length), dtype=torch.long)\n",
 "    for row, seq in enumerate(sequences):\n",
 "        input_ids[row, :len(seq)] = seq\n",
 "        attention_mask[row, :len(seq)] = 1\n",
 "    # Exclude padding by position (attention_mask) rather than by token value, so a real\n",
 "    # token that shares the pad id (id 0 doubles as CharTokenizer's unknown-character id)\n",
 "    # is not silently dropped from the loss.\n",
 "    labels = input_ids.masked_fill(attention_mask == 0, -100)\n",
 "    return {\"input_ids\": input_ids, \"attention_mask\": attention_mask, \"labels\": labels}"
 ] }, { "cell_type": "code", "execution_count": 11, "id": "193e7672", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:50:13.312349Z", "iopub.status.busy": "2022-04-18T01:50:13.308720Z", "iopub.status.idle": "2022-04-18T01:50:16.181794Z", "shell.execute_reply": "2022-04-18T01:50:16.182874Z", "shell.execute_reply.started": "2022-04-16T12:33:23.688559Z" }, "papermill": { "duration": 2.914467, "end_time": "2022-04-18T01:50:16.183073", "exception": false, "start_time": "2022-04-18T01:50:13.268606", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of trainable parameters: 50873088\n" ] } ], "source": [
 "# A 6-layer GPT-2 whose context window (n_positions) is the longest training example.\n",
 "config = GPT2Config(\n",
 "    vocab_size=len(tokenizer.vocab),\n",
 "    n_positions=max_lengths,\n",
 "    n_embd=768,\n",
 "    n_layer=6,\n",
 "    n_head=12,\n",
 "    eos_token_id=tokenizer.vocab[\"\"],\n",
 "    bos_token_id=tokenizer.vocab[\"\"],\n",
 "    # Record the padding id collate_fn uses, so generation utilities can find it.\n",
 "    pad_token_id=PAD_TOKEN_ID,\n",
 ")\n",
 "model = GPT2LMHeadModel(config)\n",
 "num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
 "print(f\"Number of trainable parameters: {num_parameters}\")"
 ] }, { "cell_type": "code", "execution_count": 12, "id": "484c0fc2", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:50:16.302344Z", "iopub.status.busy": "2022-04-18T01:50:16.301561Z", "iopub.status.idle": "2022-04-18T01:50:21.013819Z", "shell.execute_reply": "2022-04-18T01:50:21.014253Z", "shell.execute_reply.started": "2022-04-16T12:24:46.722086Z" }, "papermill": { "duration": 4.776549, "end_time": "2022-04-18T01:50:21.014420", "exception": false, "start_time": "2022-04-18T01:50:16.237871", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using amp half precision backend\n" ] } ], "source": [
 "from transformers import EarlyStoppingCallback\n",
 "\n",
 "# With patience 1 the recorded run stopped at step 58000 of 151336, right after eval loss\n",
 "# ticked up by 3e-5 (3.197654 -> 3.197687) while training loss was still falling.\n",
 "# Tolerate a few non-improving evaluations before stopping.\n",
 "EARLY_STOPPING_PATIENCE = 3\n",
 "\n",
 "training_args = TrainingArguments(\n",
 "    output_dir=\"results\",\n",
 "    eval_steps=2000,\n",
 "    save_steps=2000,  # keep aligned with eval_steps: load_best_model_at_end needs a checkpoint per eval\n",
 "    evaluation_strategy=\"steps\",\n",
 "    learning_rate=3e-4,\n",
 "    per_device_train_batch_size=32,\n",
 "    per_device_eval_batch_size=64,\n",
 "    save_total_limit=2,\n",
 "    num_train_epochs=8,\n",
 "    fp16=True,\n",
 "    optim=\"adamw_torch\",  # torch.optim.AdamW instead of the deprecated transformers AdamW\n",
 "    report_to=\"none\",\n",
 "    dataloader_num_workers=2,\n",
 "    # group_by_length stays off: the Trainer measures len(example[\"input_ids\"]), which is 1\n",
 "    # for the (1, seq_len) tensors CharTokenizer returns, so it never grouped by length.\n",
 "    group_by_length=False,\n",
 "    metric_for_best_model='loss',\n",
 "    load_best_model_at_end=True,\n",
 ")\n",
 "\n",
 "trainer = Trainer(\n",
 "    model=model,\n",
 "    args=training_args,\n",
 "    train_dataset=train_dataset,\n",
 "    eval_dataset=val_dataset,\n",
 "    data_collator=collate_fn,\n",
 "    callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)],\n",
 ")"
 ] }, { "cell_type": "code", "execution_count": 13, "id": "fbc93ddf", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T01:50:21.089679Z", "iopub.status.busy": "2022-04-18T01:50:21.089153Z", "iopub.status.idle": "2022-04-18T05:43:12.456180Z", "shell.execute_reply": "2022-04-18T05:43:12.455654Z", "shell.execute_reply.started": "2022-04-16T12:25:06.616641Z" }, "papermill": { "duration": 13971.40658, "end_time": "2022-04-18T05:43:12.456310", "exception": false, "start_time": "2022-04-18T01:50:21.049730", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.7/site-packages/transformers/optimization.py:309: 
FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " FutureWarning,\n", "***** Running training *****\n", " Num examples = 605320\n", " Num Epochs = 8\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. parallel, distributed & accumulation) = 32\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 151336\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [ 58000/151336 3:52:48 < 6:14:39, 4.15 it/s, Epoch 3/8]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
StepTraining LossValidation Loss
20004.3677004.235631
40003.9533003.883913
60003.7907003.730361
80003.6995003.639758
100003.6265003.581570
120003.5758003.529477
140003.5395003.490788
160003.5061003.457211
180003.4711003.427910
200003.4117003.404946
220003.3885003.384355
240003.3845003.362393
260003.3639003.345612
280003.3506003.330873
300003.3393003.316820
320003.3206003.303108
340003.3166003.286899
360003.3129003.277738
380003.2725003.271317
400003.2281003.260200
420003.2320003.252335
440003.2205003.247865
460003.2197003.236358
480003.2180003.228396
500003.2149003.219474
520003.2071003.213028
540003.2068003.206626
560003.1962003.197654
580003.1250003.197687

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-2000\n", "Configuration saved in results/checkpoint-2000/config.json\n", "Model weights saved in results/checkpoint-2000/pytorch_model.bin\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-4000\n", "Configuration saved in results/checkpoint-4000/config.json\n", "Model weights saved in results/checkpoint-4000/pytorch_model.bin\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-6000\n", "Configuration saved in results/checkpoint-6000/config.json\n", "Model weights saved in results/checkpoint-6000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-2000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-8000\n", "Configuration saved in results/checkpoint-8000/config.json\n", "Model weights saved in results/checkpoint-8000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-4000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-10000\n", "Configuration saved in results/checkpoint-10000/config.json\n", "Model weights saved in results/checkpoint-10000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-6000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. 
Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-12000\n", "Configuration saved in results/checkpoint-12000/config.json\n", "Model weights saved in results/checkpoint-12000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-8000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-14000\n", "Configuration saved in results/checkpoint-14000/config.json\n", "Model weights saved in results/checkpoint-14000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-10000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-16000\n", "Configuration saved in results/checkpoint-16000/config.json\n", "Model weights saved in results/checkpoint-16000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-12000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-18000\n", "Configuration saved in results/checkpoint-18000/config.json\n", "Model weights saved in results/checkpoint-18000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-14000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-20000\n", "Configuration saved in results/checkpoint-20000/config.json\n", "Model weights saved in results/checkpoint-20000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-16000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-22000\n", "Configuration saved in results/checkpoint-22000/config.json\n", "Model weights saved in results/checkpoint-22000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-18000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-24000\n", "Configuration saved in results/checkpoint-24000/config.json\n", "Model weights saved in results/checkpoint-24000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-20000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-26000\n", "Configuration saved in results/checkpoint-26000/config.json\n", "Model weights saved in results/checkpoint-26000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-22000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-28000\n", "Configuration saved in results/checkpoint-28000/config.json\n", "Model weights saved in results/checkpoint-28000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-24000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-30000\n", "Configuration saved in results/checkpoint-30000/config.json\n", "Model weights saved in results/checkpoint-30000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-26000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-32000\n", "Configuration saved in results/checkpoint-32000/config.json\n", "Model weights saved in results/checkpoint-32000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-28000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-34000\n", "Configuration saved in results/checkpoint-34000/config.json\n", "Model weights saved in results/checkpoint-34000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-30000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-36000\n", "Configuration saved in results/checkpoint-36000/config.json\n", "Model weights saved in results/checkpoint-36000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-32000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-38000\n", "Configuration saved in results/checkpoint-38000/config.json\n", "Model weights saved in results/checkpoint-38000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-34000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-40000\n", "Configuration saved in results/checkpoint-40000/config.json\n", "Model weights saved in results/checkpoint-40000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-36000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-42000\n", "Configuration saved in results/checkpoint-42000/config.json\n", "Model weights saved in results/checkpoint-42000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-38000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-44000\n", "Configuration saved in results/checkpoint-44000/config.json\n", "Model weights saved in results/checkpoint-44000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-40000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-46000\n", "Configuration saved in results/checkpoint-46000/config.json\n", "Model weights saved in results/checkpoint-46000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-42000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-48000\n", "Configuration saved in results/checkpoint-48000/config.json\n", "Model weights saved in results/checkpoint-48000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-44000] due to args.save_total_limit\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-50000\n", "Configuration saved in results/checkpoint-50000/config.json\n", "Model weights saved in results/checkpoint-50000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-46000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-52000\n", "Configuration saved in results/checkpoint-52000/config.json\n", "Model weights saved in results/checkpoint-52000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-48000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-54000\n", "Configuration saved in results/checkpoint-54000/config.json\n", "Model weights saved in results/checkpoint-54000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-50000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-56000\n", "Configuration saved in results/checkpoint-56000/config.json\n", "Model weights saved in results/checkpoint-56000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-52000] due to args.save_total_limit\n", "/opt/conda/lib/python3.7/site-packages/transformers/trainer.py:1410: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. 
At that point, setting error_if_nonfinite=false will be required to retain the old behavior.\n", " args.max_grad_norm,\n", "***** Running Evaluation *****\n", " Num examples = 12354\n", " Batch size = 64\n", "Saving model checkpoint to results/checkpoint-58000\n", "Configuration saved in results/checkpoint-58000/config.json\n", "Model weights saved in results/checkpoint-58000/pytorch_model.bin\n", "Deleting older checkpoint [results/checkpoint-54000] due to args.save_total_limit\n", "\n", "\n", "Training completed. Do not forget to share your model on huggingface.co/models =)\n", "\n", "\n", "Loading best model from results/checkpoint-56000 (score: 3.1976535320281982).\n" ] }, { "data": { "text/plain": [ "TrainOutput(global_step=58000, training_loss=3.448922660038389, metrics={'train_runtime': 13970.1599, 'train_samples_per_second': 346.636, 'train_steps_per_second': 10.833, 'total_flos': 5.124009885990912e+16, 'train_loss': 3.448922660038389, 'epoch': 3.07})" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# n_embd = 768, n_layer = 12, n_head = 12, 58k steps, 93.4 M parameters, train loss 3.150600, val loss 3.163932\n", "# n_embd = 768, n_layer = 6, n_head = 12, steps, 50.9 M parameters, train loss , val loss \n", "# n_embd = 256, n_layer = 4, n_head = 8, steps, 5.94M parameters, train loss 3.374200, val loss 3.339147\n", "# n_embd = 128, n_layer = 2, n_head = 4, 54k steps, 1.78M parameters, train loss 3.819500, val loss 3.694196\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": 14, "id": "127bea6d", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T05:43:12.684274Z", "iopub.status.busy": "2022-04-18T05:43:12.683525Z", "iopub.status.idle": "2022-04-18T05:43:12.685531Z", "shell.execute_reply": "2022-04-18T05:43:12.685926Z", "shell.execute_reply.started": "2022-04-16T12:29:27.832584Z" }, "papermill": { "duration": 0.122187, "end_time": "2022-04-18T05:43:12.686065", "exception": false, 
"start_time": "2022-04-18T05:43:12.563878", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def generation(prompt, length):\n", " tokens = tokenizer(prompt=str(length) + prompt)\n", " output_ids = model.generate(tokens['input_ids'].to(\"cuda\"),\n", " do_sample=True, \n", " top_k=50,\n", " top_p=0.95,\n", " max_length=100)\n", " decoded_verse = tokenizer.decode(output_ids)[5:]\n", " return decoded_verse" ] }, { "cell_type": "code", "execution_count": 15, "id": "e7f22169", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T05:43:12.909172Z", "iopub.status.busy": "2022-04-18T05:43:12.908333Z", "iopub.status.idle": "2022-04-18T05:43:13.116636Z", "shell.execute_reply": "2022-04-18T05:43:13.117086Z", "shell.execute_reply.started": "2022-04-16T12:30:03.02288Z" }, "papermill": { "duration": 0.325253, "end_time": "2022-04-18T05:43:13.117240", "exception": false, "start_time": "2022-04-18T05:43:12.791987", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n" ] }, { "data": { "text/plain": [ "'花明水在溪,好在波上得。月光忽在溪,圆明了不蚀。'" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generation(\"花好月圆\", length=5)" ] }, { "cell_type": "code", "execution_count": 16, "id": "536bd1dd", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T05:43:13.336560Z", "iopub.status.busy": "2022-04-18T05:43:13.335672Z", "iopub.status.idle": "2022-04-18T05:43:13.521122Z", "shell.execute_reply": "2022-04-18T05:43:13.521536Z", "shell.execute_reply.started": "2022-04-16T12:29:42.949166Z" }, "papermill": { "duration": 0.298044, "end_time": "2022-04-18T05:43:13.521677", "exception": false, "start_time": "2022-04-18T05:43:13.223633", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end 
generation.\n" ] }, { "data": { "text/plain": [ "'下山来访小园中,楼阁清幽景物同。吃吃僧斋分数宿,饭松茶灶有馀功。'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generation(\"下楼吃饭\", length=7)" ] }, { "cell_type": "code", "execution_count": 17, "id": "dd75f0be", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T05:43:13.745410Z", "iopub.status.busy": "2022-04-18T05:43:13.744513Z", "iopub.status.idle": "2022-04-18T05:43:14.123442Z", "shell.execute_reply": "2022-04-18T05:43:14.123883Z", "shell.execute_reply.started": "2022-04-16T12:29:44.683058Z" }, "papermill": { "duration": 0.490314, "end_time": "2022-04-18T05:43:14.124043", "exception": false, "start_time": "2022-04-18T05:43:13.633729", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n" ] }, { "data": { "text/plain": [ "'大深无坐今夕分明是别年,晚陪花下醉清眠。加餐我自能高咏,班列君应似谪仙。大地星河连太皞,深宵星斗下华躔。无言独向閒庭静,坐对西南又一天。'" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "generation(\"今晚加班\", length=7)" ] }, { "cell_type": "code", "execution_count": 18, "id": "393331e4", "metadata": { "execution": { "iopub.execute_input": "2022-04-18T05:43:14.346788Z", "iopub.status.busy": "2022-04-18T05:43:14.345916Z", "iopub.status.idle": "2022-04-18T05:43:14.539457Z", "shell.execute_reply": "2022-04-18T05:43:14.539890Z", "shell.execute_reply.started": "2022-04-16T12:29:56.371973Z" }, "papermill": { "duration": 0.307929, "end_time": "2022-04-18T05:43:14.540041", "exception": false, "start_time": "2022-04-18T05:43:14.232112", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting `pad_token_id` to `eos_token_id`:10741 for open-end generation.\n" ] }, { "data": { "text/plain": [ "'加餐未暇望天颜,班列群仙戏綵幡。内史赐花频赐宴,卷帘先为看朝元。'" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], 
"source": [ "generation(\"加班内卷\", length=7)" ] }, { "cell_type": "code", "execution_count": 19, "id": "ea886add", "metadata": { "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", "execution": { "iopub.execute_input": "2022-04-18T05:43:14.760813Z", "iopub.status.busy": "2022-04-18T05:43:14.759955Z", "iopub.status.idle": "2022-04-18T05:43:14.761716Z", "shell.execute_reply": "2022-04-18T05:43:14.762174Z" }, "papermill": { "duration": 0.113971, "end_time": "2022-04-18T05:43:14.762305", "exception": false, "start_time": "2022-04-18T05:43:14.648334", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "# # This Python 3 environment comes with many helpful analytics libraries installed\n", "# # It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python\n", "# # For example, here's several helpful packages to load\n", "\n", "# import numpy as np # linear algebra\n", "# import pandas as pd # data processing, CSV file I/O (e.g. 
pd.read_csv)\n", "\n", "# # Input data files are available in the read-only \"../input/\" directory\n", "# # For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory\n", "\n", "# import os\n", "# for dirname, _, filenames in os.walk('/kaggle/input'):\n", "# for filename in filenames:\n", "# print(os.path.join(dirname, filename))\n", "\n", "# # You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using \"Save & Run All\" \n", "# # You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.10" }, "papermill": { "default_parameters": {}, "duration": 14060.414143, "end_time": "2022-04-18T05:43:17.806051", "environment_variables": {}, "exception": null, "input_path": "__notebook__.ipynb", "output_path": "__notebook__.ipynb", "parameters": {}, "start_time": "2022-04-18T01:48:57.391908", "version": "2.3.3" } }, "nbformat": 4, "nbformat_minor": 5 }