{ "cells": [ { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "dataset = load_dataset(\"json\", data_files=\"data-flattened.json\", split=\"train\")\n", "\n", "labels = [\"datetime\", \"description\", \"location\"]\n", "dataset = dataset.train_test_split(test_size=0.1)\n" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at google-t5/t5-small and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7f612c075ba5465b85b56fa25e5c8e91", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/69 [00:00 28\u001b[0m tokenized_data_set \u001b[38;5;241m=\u001b[39m \u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpreprocess_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 30\u001b[0m \u001b[38;5;66;03m# Training setup (assuming you have data in optimal JSON format)\u001b[39;00m\n\u001b[1;32m 31\u001b[0m training_args \u001b[38;5;241m=\u001b[39m TrainingArguments(\n\u001b[1;32m 32\u001b[0m output_dir\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcalendar_model\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 33\u001b[0m evaluation_strategy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mepoch\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# push_to_hub=True,\u001b[39;00m\n\u001b[1;32m 43\u001b[0m )\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/dataset_dict.py:869\u001b[0m, in \u001b[0;36mDatasetDict.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)\u001b[0m\n\u001b[1;32m 865\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache_file_names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 866\u001b[0m cache_file_names \u001b[38;5;241m=\u001b[39m {k: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m}\n\u001b[1;32m 867\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m DatasetDict(\n\u001b[1;32m 868\u001b[0m {\n\u001b[0;32m--> 869\u001b[0m k: \u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 870\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 871\u001b[0m \u001b[43m \u001b[49m\u001b[43mwith_indices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwith_indices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 872\u001b[0m \u001b[43m \u001b[49m\u001b[43mwith_rank\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwith_rank\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 873\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mdrop_last_batch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdrop_last_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43mremove_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_memory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_memory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m \u001b[49m\u001b[43mload_from_cache_file\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mload_from_cache_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_file_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_file_names\u001b[49m\u001b[43m[\u001b[49m\u001b[43mk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 881\u001b[0m \u001b[43m \u001b[49m\u001b[43mwriter_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwriter_batch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 882\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 883\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisable_nullable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_nullable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 884\u001b[0m \u001b[43m \u001b[49m\u001b[43mfn_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 885\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 886\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 887\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 888\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, dataset \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 889\u001b[0m }\n\u001b[1;32m 890\u001b[0m )\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:593\u001b[0m, in \u001b[0;36mtransmit_tasks..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28mself\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mself\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 592\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 593\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 594\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 595\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataset \u001b[38;5;129;01min\u001b[39;00m datasets:\n\u001b[1;32m 596\u001b[0m \u001b[38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:558\u001b[0m, in \u001b[0;36mtransmit_format..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 551\u001b[0m self_format \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 552\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_type,\n\u001b[1;32m 553\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mformat_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_kwargs,\n\u001b[1;32m 554\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_columns,\n\u001b[1;32m 555\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_all_columns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_output_all_columns,\n\u001b[1;32m 556\u001b[0m }\n\u001b[1;32m 557\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 558\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 559\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 560\u001b[0m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3105\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m 3099\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m transformed_dataset \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3100\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(\n\u001b[1;32m 3101\u001b[0m unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m examples\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3102\u001b[0m total\u001b[38;5;241m=\u001b[39mpbar_total,\n\u001b[1;32m 3103\u001b[0m desc\u001b[38;5;241m=\u001b[39mdesc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMap\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3104\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[0;32m-> 3105\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mDataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_map_single\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 3106\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 3107\u001b[0m \u001b[43m \u001b[49m\u001b[43mshards_done\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3482\u001b[0m, in \u001b[0;36mDataset._map_single\u001b[0;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)\u001b[0m\n\u001b[1;32m 3478\u001b[0m indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 3479\u001b[0m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m*\u001b[39m(\u001b[38;5;28mslice\u001b[39m(i, i \u001b[38;5;241m+\u001b[39m batch_size)\u001b[38;5;241m.\u001b[39mindices(shard\u001b[38;5;241m.\u001b[39mnum_rows)))\n\u001b[1;32m 3480\u001b[0m ) \u001b[38;5;66;03m# Something simpler?\u001b[39;00m\n\u001b[1;32m 3481\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3482\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[43mapply_function_on_filtered_inputs\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3483\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3484\u001b[0m \u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3485\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_same_num_examples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlist_indexes\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3486\u001b[0m \u001b[43m \u001b[49m\u001b[43moffset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3487\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3488\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m NumExamplesMismatchError:\n\u001b[1;32m 3489\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetTransformationNotAllowedError(\n\u001b[1;32m 3490\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3491\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3361\u001b[0m, in \u001b[0;36mDataset._map_single..apply_function_on_filtered_inputs\u001b[0;34m(pa_inputs, indices, check_same_num_examples, offset)\u001b[0m\n\u001b[1;32m 3359\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m with_rank:\n\u001b[1;32m 3360\u001b[0m additional_args \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (rank,)\n\u001b[0;32m-> 3361\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43madditional_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3362\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(processed_inputs, LazyDict):\n\u001b[1;32m 3363\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3364\u001b[0m k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mkeys_to_format\n\u001b[1;32m 3365\u001b[0m }\n", "Cell \u001b[0;32mIn[60], line 23\u001b[0m, in \u001b[0;36mpreprocess_function\u001b[0;34m(examples)\u001b[0m\n\u001b[1;32m 20\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [doc \u001b[38;5;28;01mfor\u001b[39;00m doc \u001b[38;5;129;01min\u001b[39;00m examples[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmessage\u001b[39m\u001b[38;5;124m\"\u001b[39m]]\n\u001b[1;32m 21\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m tokenizer(inputs, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1024\u001b[39m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n\u001b[0;32m---> 23\u001b[0m labels \u001b[38;5;241m=\u001b[39m tokenizer(text_target\u001b[38;5;241m=\u001b[39m\u001b[43mexamples\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msummary\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 25\u001b[0m model_inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m labels[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model_inputs\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/formatting/formatting.py:270\u001b[0m, in \u001b[0;36mLazyDict.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key):\n\u001b[0;32m--> 270\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m key \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mkeys_to_format:\n\u001b[1;32m 272\u001b[0m value \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mformat(key)\n", "\u001b[0;31mKeyError\u001b[0m: 'summary'" ] } ], "source": [ "from transformers import (\n", " AutoModelForSequenceClassification,\n", " AutoTokenizer,\n", " Trainer,\n", " TextClassificationPipeline,\n", " TrainingArguments,\n", ")\n", "\n", "# Model and tokenizer selection\n", "checkpoint = \"google-t5/t5-small\" # Ensure correct model name\n", "\n", "\n", "# Configure model for multi-label classification\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", " checkpoint, num_labels=len(labels)\n", ")\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "def preprocess_function(examples):\n", " inputs = [doc for doc in examples[\"message\"]]\n", " model_inputs = tokenizer(inputs, max_length=1024, truncation=True, padding=\"max_length\")\n", "\n", " labels = tokenizer(text_target=examples[\"summary\"], max_length=128, truncation=True, padding=\"max_length\")\n", "\n", " model_inputs[\"labels\"] = labels[\"input_ids\"]\n", " return model_inputs\n", "\n", "tokenized_data_set = dataset.map(preprocess_function, batched=True)\n", "\n", "# Training setup (assuming you have data in optimal JSON format)\n", "training_args = TrainingArguments(\n", " output_dir=\"calendar_model\",\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=5e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " weight_decay=0.01,\n", " save_total_limit=3,\n", " num_train_epochs=1,\n", " use_mps_device=True,\n", " # fp16=True,\n", " # push_to_hub=True,\n", ")\n", "\n", "# Train the model\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=dataset[\"train\"],\n", " eval_dataset=dataset[\"test\"],\n", ")\n", "trainer.train()\n", "\n", "# Create pipeline for multi-label prediction\n", "pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, labels=labels)\n", "\n", "# Example usage for multi-label prediction\n", "text = \"Meeting with John at 2 pm tomorrow in the conference room\"\n", "calendar_entry = pipe(text)\n", "\n", "print(calendar_entry) # Output will be a list of dictionaries, one per label\n", "\n", "# Example: Accessing scores for the \"datetime\" label\n", "datetime_predictions = calendar_entry[0]\n", "print(datetime_predictions[\"score\"]) # List of prediction scores for \"datetime\"\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "506a9ad72c324024a186fda4e1fd7156", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/69 [00:00 1\u001b[0m tokenized_data_set \u001b[38;5;241m=\u001b[39m \u001b[43mdata_set\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpreprocess_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/dataset_dict.py:869\u001b[0m, in \u001b[0;36mDatasetDict.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_names, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, desc)\u001b[0m\n\u001b[1;32m 865\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m cache_file_names \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 866\u001b[0m cache_file_names \u001b[38;5;241m=\u001b[39m {k: \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m}\n\u001b[1;32m 867\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m DatasetDict(\n\u001b[1;32m 868\u001b[0m {\n\u001b[0;32m--> 869\u001b[0m k: \u001b[43mdataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 870\u001b[0m \u001b[43m \u001b[49m\u001b[43mfunction\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfunction\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 871\u001b[0m \u001b[43m \u001b[49m\u001b[43mwith_indices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwith_indices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 872\u001b[0m \u001b[43m \u001b[49m\u001b[43mwith_rank\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwith_rank\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 873\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 874\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatched\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 875\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 876\u001b[0m \u001b[43m \u001b[49m\u001b[43mdrop_last_batch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdrop_last_batch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 877\u001b[0m \u001b[43m \u001b[49m\u001b[43mremove_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mremove_columns\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 878\u001b[0m \u001b[43m \u001b[49m\u001b[43mkeep_in_memory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mkeep_in_memory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 879\u001b[0m \u001b[43m \u001b[49m\u001b[43mload_from_cache_file\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mload_from_cache_file\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 880\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_file_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_file_names\u001b[49m\u001b[43m[\u001b[49m\u001b[43mk\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 881\u001b[0m \u001b[43m \u001b[49m\u001b[43mwriter_batch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mwriter_batch_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 882\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 883\u001b[0m \u001b[43m \u001b[49m\u001b[43mdisable_nullable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdisable_nullable\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 884\u001b[0m \u001b[43m \u001b[49m\u001b[43mfn_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 885\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_proc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_proc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 886\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 887\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 888\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, dataset \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mitems()\n\u001b[1;32m 889\u001b[0m }\n\u001b[1;32m 890\u001b[0m )\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:593\u001b[0m, in \u001b[0;36mtransmit_tasks..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 591\u001b[0m \u001b[38;5;28mself\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mself\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 592\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 593\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 594\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 595\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataset \u001b[38;5;129;01min\u001b[39;00m datasets:\n\u001b[1;32m 596\u001b[0m \u001b[38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:558\u001b[0m, in \u001b[0;36mtransmit_format..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 551\u001b[0m self_format \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 552\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_type,\n\u001b[1;32m 553\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mformat_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_kwargs,\n\u001b[1;32m 554\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_columns,\n\u001b[1;32m 555\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_all_columns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_output_all_columns,\n\u001b[1;32m 556\u001b[0m }\n\u001b[1;32m 557\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 558\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 559\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 560\u001b[0m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3105\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m 3099\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m transformed_dataset \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3100\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(\n\u001b[1;32m 3101\u001b[0m unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m examples\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3102\u001b[0m total\u001b[38;5;241m=\u001b[39mpbar_total,\n\u001b[1;32m 3103\u001b[0m desc\u001b[38;5;241m=\u001b[39mdesc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMap\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3104\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[0;32m-> 3105\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrank\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontent\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mDataset\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_map_single\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mdataset_kwargs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 3106\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 3107\u001b[0m \u001b[43m \u001b[49m\u001b[43mshards_done\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3482\u001b[0m, in \u001b[0;36mDataset._map_single\u001b[0;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)\u001b[0m\n\u001b[1;32m 3478\u001b[0m indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 3479\u001b[0m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m*\u001b[39m(\u001b[38;5;28mslice\u001b[39m(i, i \u001b[38;5;241m+\u001b[39m batch_size)\u001b[38;5;241m.\u001b[39mindices(shard\u001b[38;5;241m.\u001b[39mnum_rows)))\n\u001b[1;32m 3480\u001b[0m ) \u001b[38;5;66;03m# Something simpler?\u001b[39;00m\n\u001b[1;32m 3481\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3482\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[43mapply_function_on_filtered_inputs\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3483\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3484\u001b[0m \u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3485\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_same_num_examples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlist_indexes\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3486\u001b[0m \u001b[43m \u001b[49m\u001b[43moffset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3487\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3488\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m NumExamplesMismatchError:\n\u001b[1;32m 3489\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetTransformationNotAllowedError(\n\u001b[1;32m 3490\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3491\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/datasets/arrow_dataset.py:3361\u001b[0m, in \u001b[0;36mDataset._map_single..apply_function_on_filtered_inputs\u001b[0;34m(pa_inputs, indices, check_same_num_examples, offset)\u001b[0m\n\u001b[1;32m 3359\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m with_rank:\n\u001b[1;32m 3360\u001b[0m additional_args \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (rank,)\n\u001b[0;32m-> 3361\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43madditional_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3362\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(processed_inputs, LazyDict):\n\u001b[1;32m 3363\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3364\u001b[0m k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mkeys_to_format\n\u001b[1;32m 3365\u001b[0m }\n", "Cell \u001b[0;32mIn[5], line 14\u001b[0m, in \u001b[0;36mpreprocess_function\u001b[0;34m(examples)\u001b[0m\n\u001b[1;32m 11\u001b[0m inputs \u001b[38;5;241m=\u001b[39m [doc \u001b[38;5;28;01mfor\u001b[39;00m doc \u001b[38;5;129;01min\u001b[39;00m examples[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmessage\u001b[39m\u001b[38;5;124m\"\u001b[39m]]\n\u001b[1;32m 12\u001b[0m model_inputs \u001b[38;5;241m=\u001b[39m tokenizer(inputs, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m128\u001b[39m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 14\u001b[0m labels \u001b[38;5;241m=\u001b[39m \u001b[43mtokenizer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexamples\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlabels\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m128\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmax_length\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 16\u001b[0m model_inputs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m labels[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minput_ids\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 17\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model_inputs\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2829\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2827\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_in_target_context_manager:\n\u001b[1;32m 2828\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_input_mode()\n\u001b[0;32m-> 2829\u001b[0m encodings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_one\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext_pair\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext_pair\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mall_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2830\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m text_target \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 2831\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_target_mode()\n", "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/transformers/tokenization_utils_base.py:2887\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase._call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 2884\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m 2886\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_valid_text_input(text):\n\u001b[0;32m-> 2887\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 2888\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2889\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mor `List[List[str]]` (batch of pretokenized examples).\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2890\u001b[0m )\n\u001b[1;32m 2892\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m text_pair \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m _is_valid_text_input(text_pair):\n\u001b[1;32m 2893\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 2894\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2895\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mor `List[List[str]]` (batch of pretokenized examples).\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 2896\u001b[0m )\n", "\u001b[0;31mValueError\u001b[0m: text input must be of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples)." ] } ], "source": [ "tokenized_data_set = data_set.map(preprocess_function, batched=True)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from transformers import DataCollatorForSeq2Seq" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import evaluate" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "rouge = evaluate.load(\"rouge\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n", " labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n", " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", "\n", " result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)\n", "\n", " prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]\n", " result[\"gen_len\"] = np.mean(prediction_lens)\n", "\n", " return {k: round(v, 4) for k, v in result.items()}\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer\n", "model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model moved to MPS device\n" ] } ], "source": [ "import torch\n", "\n", "# Check that MPS is available\n", "if not torch.backends.mps.is_available():\n", " if not torch.backends.mps.is_built():\n", " print(\"MPS not available because the current PyTorch install was not \"\n", " \"built with MPS enabled.\")\n", " else:\n", " print(\"MPS not available because the current MacOS version is not 12.3+ \"\n", " \"and/or you do not have an MPS-enabled device on this machine.\")\n", "\n", "else:\n", " mps_device = torch.device(\"mps\")\n", " model.to(mps_device)\n", " print(\"Model moved to MPS device\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/transformers/training_args.py:1951: UserWarning: `use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers. `mps` device will be used by default if available similar to the way `cuda` device is used.Therefore, no action from user is required. \n", " warnings.warn(\n" ] } ], "source": [ "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"calendar_model\",\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " weight_decay=0.01,\n", " save_total_limit=3,\n", " num_train_epochs=3,\n", " predict_with_generate=True,\n", " use_mps_device=True,\n", " # fp16=True,\n", " # push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DatasetDict({\n", " train: Dataset({\n", " features: ['details', 'message'],\n", " num_rows: 69\n", " })\n", " test: Dataset({\n", " features: ['details', 'message'],\n", " num_rows: 8\n", " })\n", "})\n" ] } ], "source": [ "print(data_set)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "trainer = Seq2SeqTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_data_set[\"train\"],\n", " eval_dataset=tokenized_data_set[\"test\"],\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "9452caa67e26493eb4c189fd55a68c32", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/15 [00:00