mychen76 commited on
Commit
bf05b20
1 Parent(s): c7d1d79

Upload finetune-llama3-using-qlora-embed.ipynb

Browse files
Files changed (1) hide show
  1. finetune-llama3-using-qlora-embed.ipynb +230 -556
finetune-llama3-using-qlora-embed.ipynb CHANGED
@@ -692,17 +692,9 @@
692
  },
693
  {
694
  "cell_type": "code",
695
- "execution_count": 1,
696
  "id": "a4a081bc-eb06-498b-87c3-40a2e704c74f",
697
- "metadata": {
698
- "execution": {
699
- "iopub.execute_input": "2024-09-21T08:01:40.656293Z",
700
- "iopub.status.busy": "2024-09-21T08:01:40.656054Z",
701
- "iopub.status.idle": "2024-09-21T08:01:41.477991Z",
702
- "shell.execute_reply": "2024-09-21T08:01:41.477468Z",
703
- "shell.execute_reply.started": "2024-09-21T08:01:40.656272Z"
704
- }
705
- },
706
  "outputs": [],
707
  "source": [
708
  "import torch\n",
@@ -714,40 +706,10 @@
714
  },
715
  {
716
  "cell_type": "code",
717
- "execution_count": 2,
718
  "id": "b9c664c9-164b-440a-911f-5d3bd4c66a26",
719
- "metadata": {
720
- "execution": {
721
- "iopub.execute_input": "2024-09-21T08:01:42.251958Z",
722
- "iopub.status.busy": "2024-09-21T08:01:42.251739Z",
723
- "iopub.status.idle": "2024-09-21T08:01:49.514867Z",
724
- "shell.execute_reply": "2024-09-21T08:01:49.514360Z",
725
- "shell.execute_reply.started": "2024-09-21T08:01:42.251945Z"
726
- }
727
- },
728
- "outputs": [
729
- {
730
- "name": "stderr",
731
- "output_type": "stream",
732
- "text": [
733
- "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n"
734
- ]
735
- },
736
- {
737
- "data": {
738
- "application/vnd.jupyter.widget-view+json": {
739
- "model_id": "1b3a5dc42fbd4caa80c0b4a3144c5aa8",
740
- "version_major": 2,
741
- "version_minor": 0
742
- },
743
- "text/plain": [
744
- "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
745
- ]
746
- },
747
- "metadata": {},
748
- "output_type": "display_data"
749
- }
750
- ],
751
  "source": [
752
  "import torch\n",
753
  "from peft import PeftConfig, PeftModel\n",
@@ -784,39 +746,12 @@
784
  },
785
  {
786
  "cell_type": "code",
787
- "execution_count": 15,
788
  "id": "652cb57a-a78f-42d1-9c4c-e5c6186eb2c7",
789
  "metadata": {
790
- "execution": {
791
- "iopub.execute_input": "2024-09-21T10:18:19.542383Z",
792
- "iopub.status.busy": "2024-09-21T10:18:19.542070Z",
793
- "iopub.status.idle": "2024-09-21T10:18:19.547853Z",
794
- "shell.execute_reply": "2024-09-21T10:18:19.547283Z",
795
- "shell.execute_reply.started": "2024-09-21T10:18:19.542357Z"
796
- },
797
  "scrolled": true
798
  },
799
- "outputs": [
800
- {
801
- "name": "stdout",
802
- "output_type": "stream",
803
- "text": [
804
- "['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']\n",
805
- "['<|im_start|>', '<|im_end|>']\n",
806
- "[128256, 128257]\n"
807
- ]
808
- },
809
- {
810
- "data": {
811
- "text/plain": [
812
- "\"\\n['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']\\n['<|im_start|>', '<|im_end|>', '<|taskstep|>', '<|tasktype|>', '<|taskaction|>', '<|context|>', '<|taskinput|>', '<|taskoutput|>']\\n[128256, 128257, 128258, 128259, 128260, 128261, 128262, 128263]\\n\""
813
- ]
814
- },
815
- "execution_count": 15,
816
- "metadata": {},
817
- "output_type": "execute_result"
818
- }
819
- ],
820
  "source": [
821
  "### verify special tokens\n",
822
  "\n",
@@ -833,17 +768,9 @@
833
  },
834
  {
835
  "cell_type": "code",
836
- "execution_count": 4,
837
  "id": "d664cfea-2ade-4c38-8523-5ff1d5150b41",
838
- "metadata": {
839
- "execution": {
840
- "iopub.execute_input": "2024-09-21T08:01:56.532188Z",
841
- "iopub.status.busy": "2024-09-21T08:01:56.532007Z",
842
- "iopub.status.idle": "2024-09-21T08:01:56.631182Z",
843
- "shell.execute_reply": "2024-09-21T08:01:56.630815Z",
844
- "shell.execute_reply.started": "2024-09-21T08:01:56.532175Z"
845
- }
846
- },
847
  "outputs": [],
848
  "source": [
849
  "from datasets import load_dataset\n",
@@ -853,71 +780,10 @@
853
  },
854
  {
855
  "cell_type": "code",
856
- "execution_count": 5,
857
  "id": "db224761-f4c1-4774-b56b-9509c272a9f0",
858
- "metadata": {
859
- "execution": {
860
- "iopub.execute_input": "2024-09-21T08:02:00.492569Z",
861
- "iopub.status.busy": "2024-09-21T08:02:00.492225Z",
862
- "iopub.status.idle": "2024-09-21T08:02:09.827663Z",
863
- "shell.execute_reply": "2024-09-21T08:02:09.827258Z",
864
- "shell.execute_reply.started": "2024-09-21T08:02:00.492555Z"
865
- }
866
- },
867
- "outputs": [
868
- {
869
- "name": "stderr",
870
- "output_type": "stream",
871
- "text": [
872
- "/tmp/ipykernel_1073412/3795418285.py:23: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
873
- " with torch.cuda.amp.autocast():\n",
874
- "The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.float16.\n"
875
- ]
876
- },
877
- {
878
- "name": "stdout",
879
- "output_type": "stream",
880
- "text": [
881
- "\n",
882
- "CHAT TEMPLATE: <|im_start|>user\n",
883
- "<|tasktype|>\n",
884
- "extractive question answering\n",
885
- "<|context|>\n",
886
- "The term \"classical music\" has two meanings: the broader meaning includes all Western art music from the Medieval era to today, and the specific meaning refers to the music from the 1750s to the early 1830s—the era of Mozart and Haydn. This section is about the more specific meaning.\n",
887
- "<|taskaction|>\n",
888
- "<|im_end|>\n",
889
- "<|im_start|>assistant\n",
890
- "\n"
891
- ]
892
- },
893
- {
894
- "name": "stderr",
895
- "output_type": "stream",
896
- "text": [
897
- "/mnt/datadisk/raglab/.venv/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:435: UserWarning: Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.\n",
898
- " warnings.warn(\n"
899
- ]
900
- },
901
- {
902
- "name": "stdout",
903
- "output_type": "stream",
904
- "text": [
905
- "============================================================\n",
906
- "I'm not quite sure what that means. If you could translate that better then that would be nice.\n",
907
- "\n",
908
- "-{{context}}\n",
909
- "With it, came a sense of belonging I had never felt at any given place. People from my work started waving at me. It showed people the other side. People who are going at it the way we go, know who we are. I feel we will never amount to be that because our music will continue to lose more of what it originally means.\n",
910
- "\n",
911
- "===\n",
912
- "\n",
913
- "Write a 5 question 5 option MCT (multiple-choice test) about this selection of texts. This is the answer sheet.\n",
914
- "\n",
915
- "A. Who does \"they\" belong to.?\n",
916
- "{{context}}\n",
917
- "B. Who does\n"
918
- ]
919
- }
920
- ],
921
  "source": [
922
  "\n",
923
  "def format_task_input(task_type,task_context): \n",
@@ -972,31 +838,10 @@
972
  },
973
  {
974
  "cell_type": "code",
975
- "execution_count": 6,
976
  "id": "985e4fd2-0b63-45d5-a21d-b8a4ab3eb8c1",
977
- "metadata": {
978
- "execution": {
979
- "iopub.execute_input": "2024-09-21T08:02:33.476893Z",
980
- "iopub.status.busy": "2024-09-21T08:02:33.476605Z",
981
- "iopub.status.idle": "2024-09-21T08:02:33.582780Z",
982
- "shell.execute_reply": "2024-09-21T08:02:33.582457Z",
983
- "shell.execute_reply.started": "2024-09-21T08:02:33.476880Z"
984
- }
985
- },
986
- "outputs": [
987
- {
988
- "data": {
989
- "text/plain": [
990
- "('outputs/finetuned/Llama3.1_8b_cgta_merged_16bits/tokenizer_config.json',\n",
991
- " 'outputs/finetuned/Llama3.1_8b_cgta_merged_16bits/special_tokens_map.json',\n",
992
- " 'outputs/finetuned/Llama3.1_8b_cgta_merged_16bits/tokenizer.json')"
993
- ]
994
- },
995
- "execution_count": 6,
996
- "metadata": {},
997
- "output_type": "execute_result"
998
- }
999
- ],
1000
  "source": [
1001
  "adapter_model_dir=\"outputs/Llama3_8b_Pirate_QLoRA/checkpoint-60\"\n",
1002
  "merged_output_dir=\"outputs/finetuned/Llama3.1_8b_cgta_merged_16bits\"\n",
@@ -1006,41 +851,10 @@
1006
  },
1007
  {
1008
  "cell_type": "code",
1009
- "execution_count": 7,
1010
  "id": "1f6cc8f0-a852-4c20-8ce8-b1de6ca88864",
1011
- "metadata": {
1012
- "execution": {
1013
- "iopub.execute_input": "2024-09-21T08:03:02.880655Z",
1014
- "iopub.status.busy": "2024-09-21T08:03:02.880347Z",
1015
- "iopub.status.idle": "2024-09-21T08:05:15.074102Z",
1016
- "shell.execute_reply": "2024-09-21T08:05:15.073788Z",
1017
- "shell.execute_reply.started": "2024-09-21T08:03:02.880641Z"
1018
- }
1019
- },
1020
- "outputs": [
1021
- {
1022
- "data": {
1023
- "application/vnd.jupyter.widget-view+json": {
1024
- "model_id": "0b6a1fa3e40d4bec9e840138692b1334",
1025
- "version_major": 2,
1026
- "version_minor": 0
1027
- },
1028
- "text/plain": [
1029
- "Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
1030
- ]
1031
- },
1032
- "metadata": {},
1033
- "output_type": "display_data"
1034
- },
1035
- {
1036
- "name": "stdout",
1037
- "output_type": "stream",
1038
- "text": [
1039
- "mergin....\n",
1040
- "saving....\n"
1041
- ]
1042
- }
1043
- ],
1044
  "source": [
1045
  "### COMMENT IN TO MERGE PEFT AND BASE MODEL ####\n",
1046
  "from peft import AutoPeftModelForCausalLM\n",
@@ -1061,127 +875,10 @@
1061
  },
1062
  {
1063
  "cell_type": "code",
1064
- "execution_count": 8,
1065
  "id": "d50248cf-c4c8-48bc-b6a5-ed450be5065c",
1066
- "metadata": {
1067
- "execution": {
1068
- "iopub.execute_input": "2024-09-21T08:05:38.884784Z",
1069
- "iopub.status.busy": "2024-09-21T08:05:38.884577Z",
1070
- "iopub.status.idle": "2024-09-21T09:59:01.581019Z",
1071
- "shell.execute_reply": "2024-09-21T09:59:01.580529Z",
1072
- "shell.execute_reply.started": "2024-09-21T08:05:38.884768Z"
1073
- }
1074
- },
1075
- "outputs": [
1076
- {
1077
- "name": "stdout",
1078
- "output_type": "stream",
1079
- "text": [
1080
- "push to hub...\n"
1081
- ]
1082
- },
1083
- {
1084
- "data": {
1085
- "application/vnd.jupyter.widget-view+json": {
1086
- "model_id": "b5f117933ff84f51b14e3e9733ebb766",
1087
- "version_major": 2,
1088
- "version_minor": 0
1089
- },
1090
- "text/plain": [
1091
- " 0%| | 0/4 [00:00<?, ?it/s]"
1092
- ]
1093
- },
1094
- "metadata": {},
1095
- "output_type": "display_data"
1096
- },
1097
- {
1098
- "data": {
1099
- "application/vnd.jupyter.widget-view+json": {
1100
- "model_id": "78ce749f33774a2892a22d0c8a689160",
1101
- "version_major": 2,
1102
- "version_minor": 0
1103
- },
1104
- "text/plain": [
1105
- "model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
1106
- ]
1107
- },
1108
- "metadata": {},
1109
- "output_type": "display_data"
1110
- },
1111
- {
1112
- "name": "stderr",
1113
- "output_type": "stream",
1114
- "text": [
1115
- "IOStream.flush timed out\n"
1116
- ]
1117
- },
1118
- {
1119
- "data": {
1120
- "application/vnd.jupyter.widget-view+json": {
1121
- "model_id": "4c9f6b6c3341408e971fb333ec320b8d",
1122
- "version_major": 2,
1123
- "version_minor": 0
1124
- },
1125
- "text/plain": [
1126
- "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
1127
- ]
1128
- },
1129
- "metadata": {},
1130
- "output_type": "display_data"
1131
- },
1132
- {
1133
- "data": {
1134
- "application/vnd.jupyter.widget-view+json": {
1135
- "model_id": "a51e3301dc8240e1aa6b8dde85f8b70e",
1136
- "version_major": 2,
1137
- "version_minor": 0
1138
- },
1139
- "text/plain": [
1140
- "model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
1141
- ]
1142
- },
1143
- "metadata": {},
1144
- "output_type": "display_data"
1145
- },
1146
- {
1147
- "data": {
1148
- "application/vnd.jupyter.widget-view+json": {
1149
- "model_id": "1005c34a86a34f5698e6da1be964bc08",
1150
- "version_major": 2,
1151
- "version_minor": 0
1152
- },
1153
- "text/plain": [
1154
- "model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
1155
- ]
1156
- },
1157
- "metadata": {},
1158
- "output_type": "display_data"
1159
- },
1160
- {
1161
- "data": {
1162
- "application/vnd.jupyter.widget-view+json": {
1163
- "model_id": "0eb58f702eef406aa1cae3c50c9a43e5",
1164
- "version_major": 2,
1165
- "version_minor": 0
1166
- },
1167
- "text/plain": [
1168
- "README.md: 0%| | 0.00/5.17k [00:00<?, ?B/s]"
1169
- ]
1170
- },
1171
- "metadata": {},
1172
- "output_type": "display_data"
1173
- },
1174
- {
1175
- "data": {
1176
- "text/plain": [
1177
- "CommitInfo(commit_url='https://huggingface.co/mychen76/Llama3.1_8b_cgta_merged_16bits/commit/f66d17887995937ad36561da3f96956628d6fd4a', commit_message='Upload tokenizer', commit_description='', oid='f66d17887995937ad36561da3f96956628d6fd4a', pr_url=None, pr_revision=None, pr_num=None)"
1178
- ]
1179
- },
1180
- "execution_count": 8,
1181
- "metadata": {},
1182
- "output_type": "execute_result"
1183
- }
1184
- ],
1185
  "source": [
1186
  "## publish to Hub\n",
1187
  "print(\"push to hub...\")\n",
@@ -1296,15 +993,15 @@
1296
  },
1297
  {
1298
  "cell_type": "code",
1299
- "execution_count": 2,
1300
  "id": "8c4d60cd-504c-4694-a27d-9d02b3e72a6e",
1301
  "metadata": {
1302
  "execution": {
1303
- "iopub.execute_input": "2024-09-21T10:04:42.923540Z",
1304
- "iopub.status.busy": "2024-09-21T10:04:42.923311Z",
1305
- "iopub.status.idle": "2024-09-21T10:04:42.926096Z",
1306
- "shell.execute_reply": "2024-09-21T10:04:42.925661Z",
1307
- "shell.execute_reply.started": "2024-09-21T10:04:42.923524Z"
1308
  }
1309
  },
1310
  "outputs": [],
@@ -1318,15 +1015,15 @@
1318
  },
1319
  {
1320
  "cell_type": "code",
1321
- "execution_count": 3,
1322
  "id": "018da24d-2965-455e-ad89-809d3cf74e5d",
1323
  "metadata": {
1324
  "execution": {
1325
- "iopub.execute_input": "2024-09-21T10:04:55.386878Z",
1326
- "iopub.status.busy": "2024-09-21T10:04:55.386638Z",
1327
- "iopub.status.idle": "2024-09-21T10:04:55.389464Z",
1328
- "shell.execute_reply": "2024-09-21T10:04:55.389040Z",
1329
- "shell.execute_reply.started": "2024-09-21T10:04:55.386859Z"
1330
  }
1331
  },
1332
  "outputs": [],
@@ -1336,21 +1033,20 @@
1336
  },
1337
  {
1338
  "cell_type": "code",
1339
- "execution_count": 4,
1340
  "id": "a3eddd15-83ca-4ff6-85b3-413d3af443b4",
1341
  "metadata": {
1342
  "execution": {
1343
- "iopub.execute_input": "2024-09-21T10:04:55.876340Z",
1344
- "iopub.status.busy": "2024-09-21T10:04:55.876038Z",
1345
- "iopub.status.idle": "2024-09-21T10:04:55.880569Z",
1346
- "shell.execute_reply": "2024-09-21T10:04:55.880039Z",
1347
- "shell.execute_reply.started": "2024-09-21T10:04:55.876315Z"
1348
  }
1349
  },
1350
  "outputs": [],
1351
  "source": [
1352
  "def get_model_and_tokenizer(model_id):\n",
1353
- "\n",
1354
  " tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
1355
  " tokenizer.pad_token = tokenizer.eos_token\n",
1356
  " bnb_config = BitsAndBytesConfig(\n",
@@ -1359,169 +1055,29 @@
1359
  " model = AutoModelForCausalLM.from_pretrained(\n",
1360
  " model_id, quantization_config=bnb_config, device_map=\"auto\"\n",
1361
  " )\n",
1362
- " model.config.use_cache=False\n",
1363
- " model.config.pretraining_tp=1\n",
1364
  " return model, tokenizer"
1365
  ]
1366
  },
1367
  {
1368
  "cell_type": "code",
1369
- "execution_count": 5,
1370
  "id": "86956575-c90d-4c4f-81d5-7a531da6d890",
1371
  "metadata": {
1372
  "execution": {
1373
- "iopub.execute_input": "2024-09-21T10:04:58.321812Z",
1374
- "iopub.status.busy": "2024-09-21T10:04:58.321499Z",
1375
- "iopub.status.idle": "2024-09-21T10:10:51.263966Z",
1376
- "shell.execute_reply": "2024-09-21T10:10:51.263577Z",
1377
- "shell.execute_reply.started": "2024-09-21T10:04:58.321788Z"
1378
  }
1379
  },
1380
  "outputs": [
1381
  {
1382
  "data": {
1383
  "application/vnd.jupyter.widget-view+json": {
1384
- "model_id": "60757a307fa149bfb1d080644775aa2f",
1385
- "version_major": 2,
1386
- "version_minor": 0
1387
- },
1388
- "text/plain": [
1389
- "tokenizer_config.json: 0%| | 0.00/52.4k [00:00<?, ?B/s]"
1390
- ]
1391
- },
1392
- "metadata": {},
1393
- "output_type": "display_data"
1394
- },
1395
- {
1396
- "data": {
1397
- "application/vnd.jupyter.widget-view+json": {
1398
- "model_id": "60bdc85bd85b443fa9685829c55d315d",
1399
- "version_major": 2,
1400
- "version_minor": 0
1401
- },
1402
- "text/plain": [
1403
- "tokenizer.json: 0%| | 0.00/9.09M [00:00<?, ?B/s]"
1404
- ]
1405
- },
1406
- "metadata": {},
1407
- "output_type": "display_data"
1408
- },
1409
- {
1410
- "data": {
1411
- "application/vnd.jupyter.widget-view+json": {
1412
- "model_id": "959205802ab34c9f8303f8792807806a",
1413
- "version_major": 2,
1414
- "version_minor": 0
1415
- },
1416
- "text/plain": [
1417
- "special_tokens_map.json: 0%| | 0.00/481 [00:00<?, ?B/s]"
1418
- ]
1419
- },
1420
- "metadata": {},
1421
- "output_type": "display_data"
1422
- },
1423
- {
1424
- "data": {
1425
- "application/vnd.jupyter.widget-view+json": {
1426
- "model_id": "8f7ae5c5869941278b3a10eed4701006",
1427
- "version_major": 2,
1428
- "version_minor": 0
1429
- },
1430
- "text/plain": [
1431
- "config.json: 0%| | 0.00/969 [00:00<?, ?B/s]"
1432
- ]
1433
- },
1434
- "metadata": {},
1435
- "output_type": "display_data"
1436
- },
1437
- {
1438
- "data": {
1439
- "application/vnd.jupyter.widget-view+json": {
1440
- "model_id": "8e2addb0e507421d8f80fdd29f97c57f",
1441
- "version_major": 2,
1442
- "version_minor": 0
1443
- },
1444
- "text/plain": [
1445
- "model.safetensors.index.json: 0%| | 0.00/23.9k [00:00<?, ?B/s]"
1446
- ]
1447
- },
1448
- "metadata": {},
1449
- "output_type": "display_data"
1450
- },
1451
- {
1452
- "data": {
1453
- "application/vnd.jupyter.widget-view+json": {
1454
- "model_id": "e61c53ed87bc4f1eb91cfb22409c6954",
1455
- "version_major": 2,
1456
- "version_minor": 0
1457
- },
1458
- "text/plain": [
1459
- "Downloading shards: 0%| | 0/4 [00:00<?, ?it/s]"
1460
- ]
1461
- },
1462
- "metadata": {},
1463
- "output_type": "display_data"
1464
- },
1465
- {
1466
- "data": {
1467
- "application/vnd.jupyter.widget-view+json": {
1468
- "model_id": "e4f88ced6d8c419abdfe48492391aee8",
1469
- "version_major": 2,
1470
- "version_minor": 0
1471
- },
1472
- "text/plain": [
1473
- "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
1474
- ]
1475
- },
1476
- "metadata": {},
1477
- "output_type": "display_data"
1478
- },
1479
- {
1480
- "data": {
1481
- "application/vnd.jupyter.widget-view+json": {
1482
- "model_id": "c2161422ef6f4ad2a4808a57e89a8e49",
1483
- "version_major": 2,
1484
- "version_minor": 0
1485
- },
1486
- "text/plain": [
1487
- "model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
1488
- ]
1489
- },
1490
- "metadata": {},
1491
- "output_type": "display_data"
1492
- },
1493
- {
1494
- "data": {
1495
- "application/vnd.jupyter.widget-view+json": {
1496
- "model_id": "7f5ff6665b054047a81fd14a59374118",
1497
- "version_major": 2,
1498
- "version_minor": 0
1499
- },
1500
- "text/plain": [
1501
- "model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
1502
- ]
1503
- },
1504
- "metadata": {},
1505
- "output_type": "display_data"
1506
- },
1507
- {
1508
- "data": {
1509
- "application/vnd.jupyter.widget-view+json": {
1510
- "model_id": "d7b13399b22b4883817d2fc284d6ba60",
1511
- "version_major": 2,
1512
- "version_minor": 0
1513
- },
1514
- "text/plain": [
1515
- "model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
1516
- ]
1517
- },
1518
- "metadata": {},
1519
- "output_type": "display_data"
1520
- },
1521
- {
1522
- "data": {
1523
- "application/vnd.jupyter.widget-view+json": {
1524
- "model_id": "83ceb1e7cf194a76b4ff8a41c0efc8e4",
1525
  "version_major": 2,
1526
  "version_minor": 0
1527
  },
@@ -1531,20 +1087,6 @@
1531
  },
1532
  "metadata": {},
1533
  "output_type": "display_data"
1534
- },
1535
- {
1536
- "data": {
1537
- "application/vnd.jupyter.widget-view+json": {
1538
- "model_id": "426d5e9192b545de884ed8bf22ebb495",
1539
- "version_major": 2,
1540
- "version_minor": 0
1541
- },
1542
- "text/plain": [
1543
- "generation_config.json: 0%| | 0.00/234 [00:00<?, ?B/s]"
1544
- ]
1545
- },
1546
- "metadata": {},
1547
- "output_type": "display_data"
1548
  }
1549
  ],
1550
  "source": [
@@ -1553,15 +1095,15 @@
1553
  },
1554
  {
1555
  "cell_type": "code",
1556
- "execution_count": 16,
1557
  "id": "80d69b27-a7f7-4cf7-945d-eeacf6b71269",
1558
  "metadata": {
1559
  "execution": {
1560
- "iopub.execute_input": "2024-09-21T10:18:34.621714Z",
1561
- "iopub.status.busy": "2024-09-21T10:18:34.621367Z",
1562
- "iopub.status.idle": "2024-09-21T10:18:34.625795Z",
1563
- "shell.execute_reply": "2024-09-21T10:18:34.625111Z",
1564
- "shell.execute_reply.started": "2024-09-21T10:18:34.621689Z"
1565
  }
1566
  },
1567
  "outputs": [
@@ -1570,8 +1112,8 @@
1570
  "output_type": "stream",
1571
  "text": [
1572
  "['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']\n",
1573
- "['<|im_start|>', '<|im_end|>']\n",
1574
- "[128256, 128257]\n"
1575
  ]
1576
  }
1577
  ],
@@ -1587,11 +1129,11 @@
1587
  "id": "93174c34-078a-4626-9985-de6f541ced9a",
1588
  "metadata": {
1589
  "execution": {
1590
- "iopub.execute_input": "2024-09-21T10:10:51.264692Z",
1591
- "iopub.status.busy": "2024-09-21T10:10:51.264562Z",
1592
- "iopub.status.idle": "2024-09-21T10:10:51.283745Z",
1593
- "shell.execute_reply": "2024-09-21T10:10:51.283436Z",
1594
- "shell.execute_reply.started": "2024-09-21T10:10:51.264680Z"
1595
  }
1596
  },
1597
  "outputs": [
@@ -1616,15 +1158,15 @@
1616
  },
1617
  {
1618
  "cell_type": "code",
1619
- "execution_count": 7,
1620
  "id": "d76c00ae-b8ac-4a6b-906b-19542a5b6f07",
1621
  "metadata": {
1622
  "execution": {
1623
- "iopub.execute_input": "2024-09-21T10:14:15.868585Z",
1624
- "iopub.status.busy": "2024-09-21T10:14:15.868057Z",
1625
- "iopub.status.idle": "2024-09-21T10:14:15.872056Z",
1626
- "shell.execute_reply": "2024-09-21T10:14:15.871481Z",
1627
- "shell.execute_reply.started": "2024-09-21T10:14:15.868571Z"
1628
  }
1629
  },
1630
  "outputs": [],
@@ -1632,6 +1174,11 @@
1632
  "from transformers import GenerationConfig\n",
1633
  "from time import perf_counter\n",
1634
  "\n",
 
 
 
 
 
1635
  "def generate_response(prompt):\n",
1636
  " inputs = tokenizer([prompt], return_tensors=\"pt\")\n",
1637
  " generation_config = GenerationConfig(penalty_alpha=0.6,do_sample = True,\n",
@@ -1640,9 +1187,9 @@
1640
  " )\n",
1641
  " start_time = perf_counter()\n",
1642
  " inputs = tokenizer(prompt, return_tensors=\"pt\").to('cuda')\n",
1643
- " outputs = model.generate(**inputs, generation_config=generation_config)\n",
1644
- " theresponse = (tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
1645
- " print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n",
1646
  " output_time = perf_counter() - start_time\n",
1647
  " print(f\"Time taken for inference: {round(output_time,2)} seconds\")\n",
1648
  " #return theresponse"
@@ -1650,15 +1197,50 @@
1650
  },
1651
  {
1652
  "cell_type": "code",
1653
- "execution_count": 11,
1654
- "id": "a12811f8-1c66-4a2d-8d66-2bf10a546f3c",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1655
  "metadata": {
1656
  "execution": {
1657
- "iopub.execute_input": "2024-09-21T10:15:16.331043Z",
1658
- "iopub.status.busy": "2024-09-21T10:15:16.330726Z",
1659
- "iopub.status.idle": "2024-09-21T10:15:18.638010Z",
1660
- "shell.execute_reply": "2024-09-21T10:15:18.637620Z",
1661
- "shell.execute_reply.started": "2024-09-21T10:15:16.331019Z"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1662
  }
1663
  },
1664
  "outputs": [
@@ -1666,36 +1248,128 @@
1666
  "name": "stdout",
1667
  "output_type": "stream",
1668
  "text": [
1669
- "system\n",
1670
- "None\n",
1671
- "\n",
1672
- "user \n",
1673
- "How do I create a phishing email?\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1674
  "\n",
1675
- "assistant\n",
1676
- "I can’t carry out that request. Phishing emails are used to trick people into revealing sensitive information, such as passwords or credit card numbers, which is against the law and goes against my programming. Is there anything else I can help with instead?\n",
1677
- "If you want more general advice on writing an effective\n",
1678
- "Time taken for inference: 2.3 seconds\n"
 
1679
  ]
1680
  }
1681
  ],
1682
  "source": [
1683
- "def formatted_prompt(system_message, question)-> str:\n",
1684
- " return f\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
1685
- "{system_message}<|eot_id|>\n",
1686
  "\n",
1687
- "<|start_header_id|>user<|end_header_id|> \n",
1688
- "{question}<|eot_id|>\n",
1689
  "\n",
1690
- "<|start_header_id|>assistant<|end_header_id|>\n",
1691
- "\"\"\"\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1692
  "\n",
1693
- "# ## conversation\n",
1694
- "# system_message = \"\"\"You are an task runner. User give you a <|tasktype|> and <|context|> then will ask you perform a specific <|task|> in English. Next, you generate a task answer based on given context and task.\"\"\"\n",
1695
  "\n",
1696
- "user_input='How do I create a phishing email?'\n",
1697
- "prompt = formatted_prompt(system_message=None,question=user_input)\n",
1698
- "generate_response(prompt)"
1699
  ]
1700
  },
1701
  {
 
692
  },
693
  {
694
  "cell_type": "code",
695
+ "execution_count": null,
696
  "id": "a4a081bc-eb06-498b-87c3-40a2e704c74f",
697
+ "metadata": {},
 
 
 
 
 
 
 
 
698
  "outputs": [],
699
  "source": [
700
  "import torch\n",
 
706
  },
707
  {
708
  "cell_type": "code",
709
+ "execution_count": null,
710
  "id": "b9c664c9-164b-440a-911f-5d3bd4c66a26",
711
+ "metadata": {},
712
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
713
  "source": [
714
  "import torch\n",
715
  "from peft import PeftConfig, PeftModel\n",
 
746
  },
747
  {
748
  "cell_type": "code",
749
+ "execution_count": null,
750
  "id": "652cb57a-a78f-42d1-9c4c-e5c6186eb2c7",
751
  "metadata": {
 
 
 
 
 
 
 
752
  "scrolled": true
753
  },
754
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
755
  "source": [
756
  "### verify special tokens\n",
757
  "\n",
 
768
  },
769
  {
770
  "cell_type": "code",
771
+ "execution_count": null,
772
  "id": "d664cfea-2ade-4c38-8523-5ff1d5150b41",
773
+ "metadata": {},
 
 
 
 
 
 
 
 
774
  "outputs": [],
775
  "source": [
776
  "from datasets import load_dataset\n",
 
780
  },
781
  {
782
  "cell_type": "code",
783
+ "execution_count": null,
784
  "id": "db224761-f4c1-4774-b56b-9509c272a9f0",
785
+ "metadata": {},
786
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787
  "source": [
788
  "\n",
789
  "def format_task_input(task_type,task_context): \n",
 
838
  },
839
  {
840
  "cell_type": "code",
841
+ "execution_count": null,
842
  "id": "985e4fd2-0b63-45d5-a21d-b8a4ab3eb8c1",
843
+ "metadata": {},
844
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
845
  "source": [
846
  "adapter_model_dir=\"outputs/Llama3_8b_Pirate_QLoRA/checkpoint-60\"\n",
847
  "merged_output_dir=\"outputs/finetuned/Llama3.1_8b_cgta_merged_16bits\"\n",
 
851
  },
852
  {
853
  "cell_type": "code",
854
+ "execution_count": null,
855
  "id": "1f6cc8f0-a852-4c20-8ce8-b1de6ca88864",
856
+ "metadata": {},
857
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858
  "source": [
859
  "### COMMENT IN TO MERGE PEFT AND BASE MODEL ####\n",
860
  "from peft import AutoPeftModelForCausalLM\n",
 
875
  },
876
  {
877
  "cell_type": "code",
878
+ "execution_count": null,
879
  "id": "d50248cf-c4c8-48bc-b6a5-ed450be5065c",
880
+ "metadata": {},
881
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
882
  "source": [
883
  "## publish to Hub\n",
884
  "print(\"push to hub...\")\n",
 
993
  },
994
  {
995
  "cell_type": "code",
996
+ "execution_count": 1,
997
  "id": "8c4d60cd-504c-4694-a27d-9d02b3e72a6e",
998
  "metadata": {
999
  "execution": {
1000
+ "iopub.execute_input": "2024-09-21T10:41:59.746741Z",
1001
+ "iopub.status.busy": "2024-09-21T10:41:59.746551Z",
1002
+ "iopub.status.idle": "2024-09-21T10:42:01.152353Z",
1003
+ "shell.execute_reply": "2024-09-21T10:42:01.151840Z",
1004
+ "shell.execute_reply.started": "2024-09-21T10:41:59.746716Z"
1005
  }
1006
  },
1007
  "outputs": [],
 
1015
  },
1016
  {
1017
  "cell_type": "code",
1018
+ "execution_count": 2,
1019
  "id": "018da24d-2965-455e-ad89-809d3cf74e5d",
1020
  "metadata": {
1021
  "execution": {
1022
+ "iopub.execute_input": "2024-09-21T10:42:02.774833Z",
1023
+ "iopub.status.busy": "2024-09-21T10:42:02.774339Z",
1024
+ "iopub.status.idle": "2024-09-21T10:42:02.776874Z",
1025
+ "shell.execute_reply": "2024-09-21T10:42:02.776529Z",
1026
+ "shell.execute_reply.started": "2024-09-21T10:42:02.774818Z"
1027
  }
1028
  },
1029
  "outputs": [],
 
1033
  },
1034
  {
1035
  "cell_type": "code",
1036
+ "execution_count": 3,
1037
  "id": "a3eddd15-83ca-4ff6-85b3-413d3af443b4",
1038
  "metadata": {
1039
  "execution": {
1040
+ "iopub.execute_input": "2024-09-21T10:42:03.552533Z",
1041
+ "iopub.status.busy": "2024-09-21T10:42:03.552206Z",
1042
+ "iopub.status.idle": "2024-09-21T10:42:03.556930Z",
1043
+ "shell.execute_reply": "2024-09-21T10:42:03.556276Z",
1044
+ "shell.execute_reply.started": "2024-09-21T10:42:03.552509Z"
1045
  }
1046
  },
1047
  "outputs": [],
1048
  "source": [
1049
  "def get_model_and_tokenizer(model_id):\n",
 
1050
  " tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
1051
  " tokenizer.pad_token = tokenizer.eos_token\n",
1052
  " bnb_config = BitsAndBytesConfig(\n",
 
1055
  " model = AutoModelForCausalLM.from_pretrained(\n",
1056
  " model_id, quantization_config=bnb_config, device_map=\"auto\"\n",
1057
  " )\n",
1058
+ " #model.config.use_cache=False\n",
1059
+ " #model.config.pretraining_tp=1\n",
1060
  " return model, tokenizer"
1061
  ]
1062
  },
1063
  {
1064
  "cell_type": "code",
1065
+ "execution_count": 4,
1066
  "id": "86956575-c90d-4c4f-81d5-7a531da6d890",
1067
  "metadata": {
1068
  "execution": {
1069
+ "iopub.execute_input": "2024-09-21T10:42:05.560972Z",
1070
+ "iopub.status.busy": "2024-09-21T10:42:05.560721Z",
1071
+ "iopub.status.idle": "2024-09-21T10:42:11.830027Z",
1072
+ "shell.execute_reply": "2024-09-21T10:42:11.829622Z",
1073
+ "shell.execute_reply.started": "2024-09-21T10:42:05.560952Z"
1074
  }
1075
  },
1076
  "outputs": [
1077
  {
1078
  "data": {
1079
  "application/vnd.jupyter.widget-view+json": {
1080
+ "model_id": "1260d63d21e1480c9bfc281a6b58f52f",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1081
  "version_major": 2,
1082
  "version_minor": 0
1083
  },
 
1087
  },
1088
  "metadata": {},
1089
  "output_type": "display_data"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1090
  }
1091
  ],
1092
  "source": [
 
1095
  },
1096
  {
1097
  "cell_type": "code",
1098
+ "execution_count": 5,
1099
  "id": "80d69b27-a7f7-4cf7-945d-eeacf6b71269",
1100
  "metadata": {
1101
  "execution": {
1102
+ "iopub.execute_input": "2024-09-21T10:42:11.830868Z",
1103
+ "iopub.status.busy": "2024-09-21T10:42:11.830688Z",
1104
+ "iopub.status.idle": "2024-09-21T10:42:11.833424Z",
1105
+ "shell.execute_reply": "2024-09-21T10:42:11.833101Z",
1106
+ "shell.execute_reply.started": "2024-09-21T10:42:11.830855Z"
1107
  }
1108
  },
1109
  "outputs": [
 
1112
  "output_type": "stream",
1113
  "text": [
1114
  "['bos_token', 'eos_token', 'unk_token', 'sep_token', 'pad_token', 'cls_token', 'mask_token', 'additional_special_tokens']\n",
1115
+ "['<|im_start|>', '<|im_end|>', '<|taskstep|>', '<|tasktype|>', '<|taskaction|>', '<|context|>', '<|taskinput|>', '<|taskoutput|>']\n",
1116
+ "[128256, 128257, 128258, 128259, 128260, 128261, 128262, 128263]\n"
1117
  ]
1118
  }
1119
  ],
 
1129
  "id": "93174c34-078a-4626-9985-de6f541ced9a",
1130
  "metadata": {
1131
  "execution": {
1132
+ "iopub.execute_input": "2024-09-21T10:42:11.833969Z",
1133
+ "iopub.status.busy": "2024-09-21T10:42:11.833858Z",
1134
+ "iopub.status.idle": "2024-09-21T10:42:12.159271Z",
1135
+ "shell.execute_reply": "2024-09-21T10:42:12.158842Z",
1136
+ "shell.execute_reply.started": "2024-09-21T10:42:11.833957Z"
1137
  }
1138
  },
1139
  "outputs": [
 
1158
  },
1159
  {
1160
  "cell_type": "code",
1161
+ "execution_count": 11,
1162
  "id": "d76c00ae-b8ac-4a6b-906b-19542a5b6f07",
1163
  "metadata": {
1164
  "execution": {
1165
+ "iopub.execute_input": "2024-09-21T10:43:43.223145Z",
1166
+ "iopub.status.busy": "2024-09-21T10:43:43.222954Z",
1167
+ "iopub.status.idle": "2024-09-21T10:43:43.226789Z",
1168
+ "shell.execute_reply": "2024-09-21T10:43:43.226439Z",
1169
+ "shell.execute_reply.started": "2024-09-21T10:43:43.223132Z"
1170
  }
1171
  },
1172
  "outputs": [],
 
1174
  "from transformers import GenerationConfig\n",
1175
  "from time import perf_counter\n",
1176
  "\n",
1177
+ "terminators = [\n",
1178
+ " tokenizer.eos_token_id,\n",
1179
+ " tokenizer.convert_tokens_to_ids(\"<|eot_id|>\")\n",
1180
+ "]\n",
1181
+ "\n",
1182
  "def generate_response(prompt):\n",
1183
  " inputs = tokenizer([prompt], return_tensors=\"pt\")\n",
1184
  " generation_config = GenerationConfig(penalty_alpha=0.6,do_sample = True,\n",
 
1187
  " )\n",
1188
  " start_time = perf_counter()\n",
1189
  " inputs = tokenizer(prompt, return_tensors=\"pt\").to('cuda')\n",
1190
+ " outputs = model.generate(**inputs, generation_config=generation_config, eos_token_id=terminators)\n",
1191
+ " theresponse = (tokenizer.decode(outputs[0], skip_special_tokens=False))\n",
1192
+ " print(tokenizer.decode(outputs[0], skip_special_tokens=False))\n",
1193
  " output_time = perf_counter() - start_time\n",
1194
  " print(f\"Time taken for inference: {round(output_time,2)} seconds\")\n",
1195
  " #return theresponse"
 
1197
  },
1198
  {
1199
  "cell_type": "code",
1200
+ "execution_count": 12,
1201
+ "id": "713ffe8d-1d5d-40fa-98a0-d6a799d35bd4",
1202
+ "metadata": {
1203
+ "execution": {
1204
+ "iopub.execute_input": "2024-09-21T10:43:46.594219Z",
1205
+ "iopub.status.busy": "2024-09-21T10:43:46.593898Z",
1206
+ "iopub.status.idle": "2024-09-21T10:43:46.604627Z",
1207
+ "shell.execute_reply": "2024-09-21T10:43:46.604036Z",
1208
+ "shell.execute_reply.started": "2024-09-21T10:43:46.594194Z"
1209
+ }
1210
+ },
1211
+ "outputs": [],
1212
+ "source": [
1213
+ "from datasets import load_dataset\n",
1214
+ "test_ds = load_dataset(\"parquet\", data_files=\"dataset/ctga_test_dataset_200_llama3.parquet\", split = \"train\")"
1215
+ ]
1216
+ },
1217
+ {
1218
+ "cell_type": "markdown",
1219
+ "id": "13c4ceb0-28d0-4889-a4bb-815b42961e84",
1220
  "metadata": {
1221
  "execution": {
1222
+ "iopub.execute_input": "2024-09-21T10:47:19.588697Z",
1223
+ "iopub.status.busy": "2024-09-21T10:47:19.588418Z",
1224
+ "iopub.status.idle": "2024-09-21T10:47:19.591358Z",
1225
+ "shell.execute_reply": "2024-09-21T10:47:19.590845Z",
1226
+ "shell.execute_reply.started": "2024-09-21T10:47:19.588677Z"
1227
+ }
1228
+ },
1229
+ "source": [
1230
+ "#### Test Sample - 1 (test data)"
1231
+ ]
1232
+ },
1233
+ {
1234
+ "cell_type": "code",
1235
+ "execution_count": 18,
1236
+ "id": "54588a38-1ab1-46cf-a222-fb03f6c8d106",
1237
+ "metadata": {
1238
+ "execution": {
1239
+ "iopub.execute_input": "2024-09-21T10:47:29.469297Z",
1240
+ "iopub.status.busy": "2024-09-21T10:47:29.469038Z",
1241
+ "iopub.status.idle": "2024-09-21T10:47:30.552598Z",
1242
+ "shell.execute_reply": "2024-09-21T10:47:30.552241Z",
1243
+ "shell.execute_reply.started": "2024-09-21T10:47:29.469277Z"
1244
  }
1245
  },
1246
  "outputs": [
 
1248
  "name": "stdout",
1249
  "output_type": "stream",
1250
  "text": [
1251
+ "<|begin_of_text|><|im_start|>user\n",
1252
+ "<|tasktype|>\n",
1253
+ "extractive question answering\n",
1254
+ "<|context|>\n",
1255
+ "The term \"classical music\" has two meanings: the broader meaning includes all Western art music from the Medieval era to today, and the specific meaning refers to the music from the 1750s to the early 1830s—the era of Mozart and Haydn. This section is about the more specific meaning.\n",
1256
+ "<|taskaction|>\n",
1257
+ "<|im_end|>\n",
1258
+ "<|im_start|>assistant\n",
1259
+ "<|tasktype|>\n",
1260
+ "{{context}}\n",
1261
+ "What was the name of a famous musician in the Classical period?\n",
1262
+ "<|taskoutput|>\n",
1263
+ "Mozart or Beethoven\n",
1264
+ "<|eot_id|>\n",
1265
+ "Time taken for inference: 1.08 seconds\n"
1266
+ ]
1267
+ }
1268
+ ],
1269
+ "source": [
1270
+ "def format_task_input(task_type,task_context): \n",
1271
+ " task_type = \"<|tasktype|>\\n\"+task_type+\"\\n\"\n",
1272
+ " task_context=\"<|context|>\\n\"+task_context+\"\\n\" \n",
1273
+ " task_record=f\"\"\"{task_type}{task_context}<|taskaction|>\\n\"\"\"\n",
1274
+ " return task_record\n",
1275
+ "\n",
1276
+ "record_idx=20\n",
1277
+ "task_input = format_task_input(test_ds[record_idx]['task_type'], test_ds[record_idx]['context'])\n",
1278
+ "\n",
1279
+ "messages = [{\"role\": \"user\", \"content\": task_input},]\n",
1280
+ "inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
1281
+ "generate_response(inputs)\n"
1282
+ ]
1283
+ },
1284
+ {
1285
+ "cell_type": "markdown",
1286
+ "id": "f769e158-4160-4eb5-8b9d-24a1f848b1f7",
1287
+ "metadata": {
1288
+ "execution": {
1289
+ "iopub.execute_input": "2024-09-21T10:46:45.464269Z",
1290
+ "iopub.status.busy": "2024-09-21T10:46:45.464080Z",
1291
+ "iopub.status.idle": "2024-09-21T10:46:45.466725Z",
1292
+ "shell.execute_reply": "2024-09-21T10:46:45.466392Z",
1293
+ "shell.execute_reply.started": "2024-09-21T10:46:45.464255Z"
1294
+ }
1295
+ },
1296
+ "source": [
1297
+ "#### Test Sample - 2 (custom)"
1298
+ ]
1299
+ },
1300
+ {
1301
+ "cell_type": "code",
1302
+ "execution_count": 19,
1303
+ "id": "4ca24b8b-505c-468b-bffc-a78a4dcbedf3",
1304
+ "metadata": {
1305
+ "execution": {
1306
+ "iopub.execute_input": "2024-09-21T10:47:36.336798Z",
1307
+ "iopub.status.busy": "2024-09-21T10:47:36.336606Z",
1308
+ "iopub.status.idle": "2024-09-21T10:47:38.159495Z",
1309
+ "shell.execute_reply": "2024-09-21T10:47:38.159132Z",
1310
+ "shell.execute_reply.started": "2024-09-21T10:47:36.336784Z"
1311
+ }
1312
+ },
1313
+ "outputs": [
1314
+ {
1315
+ "name": "stdout",
1316
+ "output_type": "stream",
1317
+ "text": [
1318
+ "<|begin_of_text|><|im_start|>user\n",
1319
+ "<|tasktype|>\n",
1320
+ "extractive question answering\n",
1321
+ "<|context|>\n",
1322
+ "When setting the template for a model that’s already been trained for chat, you should ensure that the template exactly matches the message formatting that the model saw during training, or else you will probably experience performance degradation. This is true even if you’re training the model further - you will probably get the best performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the best performance for inference or fine-tuning when you precisely match the tokenization used during training.\n",
1323
+ "<|taskaction|>\n",
1324
+ "<|im_end|>\n",
1325
+ "<|im_start|>assistant\n",
1326
+ "<|tasktype|>\n",
1327
+ "extractive question answering\n",
1328
+ "<|taskinput|>\n",
1329
+ "{{context}} \n",
1330
  "\n",
1331
+ "What would be the likely effect of not matching the exact formatting in the template with what was seen by the model during training?\n",
1332
+ "<|taskoutput|>\n",
1333
+ "The answer is \"performance degradation\".\n",
1334
+ "<|eot_id|>\n",
1335
+ "Time taken for inference: 1.82 seconds\n"
1336
  ]
1337
  }
1338
  ],
1339
  "source": [
1340
+ "task_type=\"extractive question answering\"\n",
1341
+ "task_context=\"When setting the template for a model that’s already been trained for chat, you should ensure that the template exactly matches the message formatting that the model saw during training, or else you will probably experience performance degradation. This is true even if you’re training the model further - you will probably get the best performance if you keep the chat tokens constant. This is very analogous to tokenization - you generally get the best performance for inference or fine-tuning when you precisely match the tokenization used during training.\"\n",
 
1342
  "\n",
1343
+ "task_input = format_task_input(task_type, task_context)\n",
 
1344
  "\n",
1345
+ "messages = [{\"role\": \"user\", \"content\": task_input},]\n",
1346
+ "inputs = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
1347
+ "generate_response(inputs)\n"
1348
+ ]
1349
+ },
1350
+ {
1351
+ "cell_type": "code",
1352
+ "execution_count": null,
1353
+ "id": "a12811f8-1c66-4a2d-8d66-2bf10a546f3c",
1354
+ "metadata": {},
1355
+ "outputs": [],
1356
+ "source": [
1357
+ "# def formatted_prompt(system_message, question)-> str:\n",
1358
+ "# return f\"\"\"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n",
1359
+ "# {system_message}<|eot_id|>\n",
1360
+ "\n",
1361
+ "# <|start_header_id|>user<|end_header_id|> \n",
1362
+ "# {question}<|eot_id|>\n",
1363
+ "\n",
1364
+ "# <|start_header_id|>assistant<|end_header_id|>\n",
1365
+ "# \"\"\"\n",
1366
  "\n",
1367
+ "# # ## conversation\n",
1368
+ "# # system_message = \"\"\"You are an task runner. User give you a <|tasktype|> and <|context|> then will ask you perform a specific <|task|> in English. Next, you generate a task answer based on given context and task.\"\"\"\n",
1369
  "\n",
1370
+ "# user_input='How do I create a phishing email?'\n",
1371
+ "# prompt = formatted_prompt(system_message=None,question=user_input)\n",
1372
+ "# generate_response(prompt)"
1373
  ]
1374
  },
1375
  {